Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save AnnoyingTechnology/8349bb5a7855cd852b0a233b416344d4 to your computer and use it in GitHub Desktop.
Save AnnoyingTechnology/8349bb5a7855cd852b0a233b416344d4 to your computer and use it in GitHub Desktop.
config.ts (Bedrock Knowledge Base provider, for Continue.dev)
// the reranking is broken
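// Note: this provider shells out to the AWS CLI (via extras.ide.subprocess) instead of using the AWS SDK, which makes passing payloads and reading responses awkward.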
import { promises as fs } from "fs";
import * as path from "path";
import { tmpdir } from "os";
// crypto.randomUUID requires Node.js v14.17.0 or later; uncomment the import below if you need it.
// import { randomUUID } from "crypto";
/** Per-knowledge-base settings used when querying Amazon Bedrock. */
interface KnowledgeBaseConfig {
  /** Identifier of the Bedrock knowledge base to query. */
  knowledgeBaseId: string;
  /** How many chunks to request from the vector search. */
  numberOfResults: number;
  // Guardrails are not wired up; to enable them, add these fields back:
  // guardrailsId?: string;
  // guardrailsVersion?: string;
}
/**
 * Continue.dev "query" context provider that retrieves chunks from one or more
 * Amazon Bedrock knowledge bases and reranks them with a Bedrock rerank model.
 *
 * All AWS calls go through the AWS CLI via `extras.ide.subprocess`. JSON
 * payloads are passed through temporary files, which are removed after use.
 * Shell quoting uses POSIX single quotes (see `escapeShellArg`).
 */
const BedrockKBContextProvider = {
  title: "REDACTED-kb",
  displayTitle: "REDACTED's knowledge bases",
  description: "Query REDACTED's knowledge bases for relevant information",
  renderInlineAs: "",
  type: "query" as const,
  /** Knowledge bases queried in parallel for every request. */
  kbConfigs: [
    { knowledgeBaseId: "_", numberOfResults: 48 },
    { knowledgeBaseId: "_", numberOfResults: 48 },
    { knowledgeBaseId: "_", numberOfResults: 48 },
    { knowledgeBaseId: "_", numberOfResults: 48 },
  ] as KnowledgeBaseConfig[],
  region: "eu-central-1",
  /** Named AWS CLI profile passed as `--profile` to every `aws` call. */
  awsProfile: "bedrock",
  /** Bedrock model used to rerank retrieved chunks. */
  rerankModelId: "cohere.rerank-v3-5:0",

  /**
   * Quote a string for a POSIX shell: wrap it in single quotes and turn each
   * embedded `'` into `'\''`. This is not valid quoting for Windows cmd.exe.
   */
  escapeShellArg(arg: string): string {
    return `'${arg.replace(/'/g, "'\\''")}'`;
  },

  /**
   * Build a fresh `.json` path in the OS temp directory.
   * `Date.now()` alone collides when several payloads are written in the same
   * millisecond, which happens because knowledge bases are queried in
   * parallel. The PID and a random suffix make the name unique.
   */
  makeTempPath(prefix: string): string {
    const suffix = Math.random().toString(36).slice(2, 12);
    return path.join(
      tmpdir(),
      `${prefix}-${process.pid}-${Date.now()}-${suffix}.json`
    );
  },

  /**
   * Write `content` as JSON to a new temp file.
   *
   * @returns A `file://` reference suitable for AWS CLI parameters.
   */
  async writeTempJsonFile(content: unknown): Promise<string> {
    const fileName = this.makeTempPath("bedrock-payload");
    // "wx" fails rather than silently overwriting another request's payload.
    await fs.writeFile(fileName, JSON.stringify(content), {
      encoding: "utf8",
      flag: "wx",
    });
    return `file://${fileName}`;
  },

  /**
   * Best-effort removal of a temp file.
   * Accepts a plain path or a `file://` / `fileb://` reference; `undefined`
   * is ignored.
   */
  async removeTempFile(fileRef: string | undefined): Promise<void> {
    if (!fileRef) {
      return;
    }
    try {
      await fs.unlink(fileRef.replace(/^fileb?:\/\//, ""));
    } catch {
      // Cleanup is best-effort; a leftover temp file is not worth failing over.
    }
  },

  /**
   * Render a Bedrock retrieval `location` object as a short readable string.
   * The object's shape depends on the data source type, so the first
   * URI/URL/id field found is used. Falls back to the JSON text, or to
   * "Unknown location" when the location is missing.
   */
  describeLocation(location: any): string {
    if (!location) {
      return "Unknown location";
    }
    if (typeof location === "string") {
      return location;
    }
    const candidates = [
      location.s3Location?.uri,
      location.webLocation?.url,
      location.confluenceLocation?.url,
      location.salesforceLocation?.url,
      location.sharePointLocation?.url,
      location.kendraDocumentLocation?.uri,
      location.customDocumentLocation?.id,
    ];
    const found = candidates.find(
      (c) => typeof c === "string" && c.length > 0
    );
    return found ?? JSON.stringify(location);
  },

  /**
   * Entry point called by Continue.
   *
   * Flow:
   * 1. Verify the AWS CLI profile has credentials.
   * 2. Query every configured knowledge base in parallel.
   * 3. Return the retrieved chunks as context items, ordered by score
   *    (rerank score when reranking succeeded, otherwise the retrieval score).
   *
   * @param query - Query typed for the provider; used only when the chat
   *   input yields no text after the provider title is stripped.
   * @param extras - Continue extras; `fullInput` and `ide.subprocess` are used.
   * @returns Context items, or `[]` on any top-level failure (errors are logged).
   */
  async getContextItems(
    query: string | null,
    extras: ContextProviderExtras
  ): Promise<ContextItem[]> {
    console.log("[BedrockKB] GetContextItems called");
    // The chat input still contains the provider's title, so strip it.
    const actualQuery =
      extras.fullInput?.replace(this.displayTitle, "").trim() || query?.trim();
    if (!actualQuery) {
      console.log(
        "[BedrockKB] No valid query found, skipping knowledge base search"
      );
      return [];
    }
    try {
      console.log("[BedrockKB] Checking AWS credentials...");
      const stsCommand = [
        "aws",
        "sts",
        "get-caller-identity",
        "--profile",
        this.escapeShellArg(this.awsProfile),
      ].join(" ");
      const [identityOutput] = await extras.ide.subprocess(stsCommand);
      if (!identityOutput) {
        throw new Error(
          `Failed to get AWS credentials. Please ensure you're logged in with 'aws sso login --profile ${this.awsProfile}'`
        );
      }

      const allResults = await Promise.all(
        this.kbConfigs.map((kbConfig) =>
          this.retrieveFromKb(kbConfig, actualQuery, extras)
        )
      );

      // Flatten into context items, best score first. Without this sort the
      // rerank scores would only ever show up in the description text.
      // Array.prototype.sort is stable, so ties keep knowledge-base order.
      const contextItems = allResults
        .flatMap(({ kbId, results }) =>
          results
            .filter((r: any) => r.content?.text)
            .map((r: any) => ({
              score: Number(r.score) || 0,
              item: {
                name: `${kbId} - ${this.describeLocation(r.location)}`,
                description: `Relevance score: ${r.score || 0}`,
                content: r.content?.text ?? "",
              },
            }))
        )
        .sort((a, b) => b.score - a.score)
        .map(({ item }) => item);

      console.log("[BedrockKB] Final context items:", contextItems.length);
      return contextItems;
    } catch (err) {
      console.error("[BedrockKB] Error in getContextItems:", err);
      return [];
    }
  },

  /**
   * Query a single knowledge base and, when it returns anything, rerank the
   * results.
   *
   * @returns The raw Bedrock `retrievalResults`. When reranking succeeded,
   *   each result's `score` is replaced by its rerank score. Any failure
   *   yields an empty result list; errors are logged.
   */
  async retrieveFromKb(
    kbConfig: KnowledgeBaseConfig,
    actualQuery: string,
    extras: ContextProviderExtras
  ): Promise<{ kbId: string; results: any[] }> {
    const kbId = kbConfig.knowledgeBaseId;
    let retrievalQueryFile: string | undefined;
    let retrievalConfigFile: string | undefined;
    try {
      console.log(`[BedrockKB] Querying KB: ${kbId} => '${actualQuery}'`);
      retrievalQueryFile = await this.writeTempJsonFile({ text: actualQuery });
      retrievalConfigFile = await this.writeTempJsonFile({
        vectorSearchConfiguration: {
          numberOfResults: kbConfig.numberOfResults || 32,
        },
      });
      const commandParts = [
        "aws bedrock-agent-runtime retrieve",
        `--knowledge-base-id ${this.escapeShellArg(kbId)}`,
        `--retrieval-query ${this.escapeShellArg(retrievalQueryFile)}`,
        `--retrieval-configuration ${this.escapeShellArg(retrievalConfigFile)}`,
        `--region ${this.escapeShellArg(this.region)}`,
        `--profile ${this.escapeShellArg(this.awsProfile)}`,
      ];
      // If you have guardrails, add guardrailsId/guardrailsVersion to
      // KnowledgeBaseConfig and uncomment (and clean up the file afterwards):
      /*
      if (kbConfig.guardrailsId) {
        const guardrailsFile = await this.writeTempJsonFile({
          guardrailsId: kbConfig.guardrailsId,
          guardrailsVersion: kbConfig.guardrailsVersion
        });
        commandParts.push(`--guardrails-configuration ${this.escapeShellArg(guardrailsFile)}`);
      }
      */
      const command = commandParts.join(" ");
      console.log("[BedrockKB] Executing retrieve command:", command);
      const [stdout, stderr] = await extras.ide.subprocess(command);
      if (stderr) {
        // The CLI also prints harmless warnings on stderr, so stderr alone is
        // not treated as failure; only missing output is.
        console.warn(`[BedrockKB] stderr for KB ${kbId}:`, stderr);
      }
      if (!stdout?.trim()) {
        console.error(`[BedrockKB] Error for KB ${kbId}:`, stderr);
        return { kbId, results: [] };
      }

      const response = JSON.parse(stdout);
      const results: any[] = response.retrievalResults || [];
      if (results.length === 0) {
        return { kbId, results };
      }

      const chunks = results.map((r) => ({
        content: r.content?.text ?? "",
        metadata: {
          location: r.location || "Unknown location",
          kbId,
        },
      }));
      try {
        const rerankScores = await this.rerankResults(
          actualQuery,
          chunks,
          extras
        );
        return {
          kbId,
          results: results.map((r, idx) => ({
            ...r,
            score: rerankScores[idx] || 0,
          })),
        };
      } catch (err) {
        // Fall back to Bedrock's own retrieval scores.
        console.error("[BedrockKB] Reranking error:", err);
        return { kbId, results };
      }
    } catch (err) {
      console.error(`[BedrockKB] Error querying KB ${kbId}:`, err);
      return { kbId, results: [] };
    } finally {
      await this.removeTempFile(retrievalQueryFile);
      await this.removeTempFile(retrievalConfigFile);
    }
  },

  /**
   * Score `chunks` against `query` with the rerank model on Bedrock.
   *
   * @param query - The user's question.
   * @param chunks - Items whose `content` string is sent as a rerank document.
   * @param extras - Continue extras; only `ide.subprocess` is used.
   * @returns One score per chunk, in the same order as `chunks`; 0 for any
   *   chunk the model did not score.
   * @throws When the CLI call fails, writes no response, or the response is
   *   not valid JSON (the error is logged before rethrowing).
   */
  async rerankResults(
    query: string,
    chunks: any[],
    extras: ContextProviderExtras
  ): Promise<number[]> {
    const outFile = this.makeTempPath("bedrock-rerank");
    let payloadFile: string | undefined;
    try {
      payloadFile = await this.writeTempJsonFile({
        query,
        documents: chunks.map((c) => c.content),
        // Cohere Rerank 3.5 on Bedrock requires api_version 2.
        api_version: 2,
      });
      const commandParts = [
        "aws bedrock-runtime invoke-model",
        `--model-id ${this.escapeShellArg(this.rerankModelId)}`,
        "--accept application/json",
        "--content-type application/json",
        `--region ${this.escapeShellArg(this.region)}`,
        `--profile ${this.escapeShellArg(this.awsProfile)}`,
        // --body is a blob parameter: fileb:// sends the file's bytes as-is,
        // whereas file:// makes AWS CLI v2 expect base64-encoded text.
        `--body ${this.escapeShellArg(
          payloadFile.replace(/^file:\/\//, "fileb://")
        )}`,
        // Required positional outfile. The model's response body is written
        // here, while stdout only receives the CLI's metadata
        // ({"contentType": ...}), so the body is read back from this file.
        this.escapeShellArg(outFile),
      ];
      const command = commandParts.join(" ");
      console.log("[BedrockKB] Reranking command:", command);
      const [, stderr] = await extras.ide.subprocess(command);
      if (stderr) {
        console.warn("[BedrockKB] Reranking stderr:", stderr);
      }

      let raw: string;
      try {
        raw = await fs.readFile(outFile, "utf8");
      } catch {
        throw new Error(
          `Reranking error: ${stderr || "no response body was written"}`
        );
      }
      const decoded = JSON.parse(raw);

      // Response shape: { results: [ { index, relevance_score }, ... ] }.
      // It is ordered by relevance, not input position, and may be partial,
      // so scores are placed by index instead of by sorting.
      const scores: number[] = new Array(chunks.length).fill(0);
      for (const r of decoded?.results ?? []) {
        if (
          Number.isInteger(r?.index) &&
          r.index >= 0 &&
          r.index < scores.length
        ) {
          scores[r.index] = Number(r.relevance_score) || 0;
        }
      }
      return scores;
    } catch (err) {
      console.error("[BedrockKB] Reranking error:", err);
      throw err;
    } finally {
      await this.removeTempFile(payloadFile);
      await this.removeTempFile(outFile);
    }
  },
};
/**
 * Continue.dev config hook: registers the Bedrock knowledge-base provider.
 *
 * Creates `contextProviders` when it is missing, then appends the provider
 * to that list.
 *
 * @param config - The config object; it is mutated in place.
 * @returns The same `config` object.
 */
export function modifyConfig(config: Config): Config {
  console.log("[BedrockKB] modifyConfig called, adding provider");
  config.contextProviders = config.contextProviders || [];
  config.contextProviders.push(BedrockKBContextProvider);
  return config;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment