Created
December 27, 2024 13:13
-
-
Save AnnoyingTechnology/8349bb5a7855cd852b0a233b416344d4 to your computer and use it in GitHub Desktop.
config.ts (Bedrock Knowledge Base provider, for Continue.dev)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// the reranking is broken | |
// having to use aws-cli instead of the sdk is really painful | |
import { promises as fs } from "fs";
import * as path from "path";
import { tmpdir } from "os";
// randomUUID requires Node v14.17+.
import { randomUUID } from "crypto";
/**
 * Settings for a single Bedrock Knowledge Base queried by the provider.
 */
interface KnowledgeBaseConfig {
  /** ID passed to `aws bedrock-agent-runtime retrieve --knowledge-base-id`. */
  knowledgeBaseId: string;
  /** Value sent as `vectorSearchConfiguration.numberOfResults` in the retrieve call. */
  numberOfResults: number;
  // Uncomment if you actually have guardrails
  // guardrailsId?: string;
  // guardrailsVersion?: string;
}
/**
 * Continue.dev custom context provider that queries Amazon Bedrock Knowledge
 * Bases through the AWS CLI (run via `extras.ide.subprocess`) and reranks each
 * knowledge base's hits with a Bedrock rerank model.
 *
 * JSON payloads are handed to the CLI through uniquely named temp files, which
 * are deleted after every call.
 */
const BedrockKBContextProvider = {
  title: "REDACTED-kb",
  displayTitle: "REDACTED's knowledge bases",
  description: "Query REDACTED's knowledge bases for relevant information",
  renderInlineAs: "",
  type: "query" as const,
  kbConfigs: [
    { knowledgeBaseId: "_", numberOfResults: 48 },
    { knowledgeBaseId: "_", numberOfResults: 48 },
    { knowledgeBaseId: "_", numberOfResults: 48 },
    { knowledgeBaseId: "_", numberOfResults: 48 },
  ] as KnowledgeBaseConfig[],
  region: "eu-central-1",
  /** AWS CLI profile passed to every command via `--profile`. */
  awsProfile: "bedrock",
  /** Bedrock model ID used by `rerankResults`. */
  rerankModelId: "cohere.rerank-v3-5:0",

  /**
   * Quote `arg` for a POSIX shell: wrap it in single quotes and rewrite each
   * embedded `'` as `'\''`. Not valid quoting for Windows `cmd.exe`.
   */
  escapeShellArg(arg: string): string {
    return `'${arg.replace(/'/g, "'\\''")}'`;
  },

  /**
   * Write `content` as JSON to a new file in the OS temp directory and return
   * its plain filesystem path. The name embeds a random UUID because payloads
   * for all knowledge bases are written concurrently; a bare `Date.now()`
   * suffix let parallel writes land on (and overwrite) the same file.
   */
  async createTempJsonFile(content: unknown): Promise<string> {
    const fileName = path.join(tmpdir(), `bedrock-payload-${randomUUID()}.json`);
    await fs.writeFile(fileName, JSON.stringify(content), "utf8");
    return fileName;
  },

  /**
   * Same as `createTempJsonFile`, but returns a `file://` reference for AWS CLI
   * structure arguments. The caller owns the file and should delete it.
   */
  async writeTempJsonFile(content: any): Promise<string> {
    return `file://${await this.createTempJsonFile(content)}`;
  },

  /** Best-effort deletion of temp files; missing files and errors are ignored. */
  async removeTempFiles(filePaths: string[]): Promise<void> {
    await Promise.all(
      filePaths.map((filePath) => fs.unlink(filePath).catch(() => undefined))
    );
  },

  /**
   * Render a Bedrock retrieval `location` as text. For object shapes such as
   * `{ type: "S3", s3Location: { uri } }` the first nested `uri`/`url`/`id`
   * string is returned, then `type`, then "Unknown location".
   * (Interpolating the object directly produced "[object Object]".)
   */
  formatLocation(location: any): string {
    if (typeof location === "string" && location) {
      return location;
    }
    if (location && typeof location === "object") {
      for (const value of Object.values(location)) {
        if (value && typeof value === "object") {
          const { uri, url, id } = value as Record<string, unknown>;
          const ref = uri ?? url ?? id;
          if (typeof ref === "string" && ref) {
            return ref;
          }
        }
      }
      if (typeof location.type === "string" && location.type) {
        return location.type;
      }
    }
    return "Unknown location";
  },

  /**
   * Run an AWS CLI command through the IDE and return its stdout.
   *
   * As used here, `ide.subprocess` yields only `[stdout, stderr]` (no exit
   * status), and the CLI may print warnings to stderr on success, so stderr
   * is only logged; the command counts as failed when stdout is empty.
   *
   * @throws Error carrying stderr (or a generic message) when stdout is empty.
   */
  async runAwsCommand(
    command: string,
    extras: ContextProviderExtras
  ): Promise<string> {
    const [stdout, stderr] = await extras.ide.subprocess(command);
    if (!stdout || !stdout.trim()) {
      throw new Error(
        (stderr && stderr.trim()) || `No output from AWS CLI command: ${command}`
      );
    }
    if (stderr && stderr.trim()) {
      console.warn("[BedrockKB] AWS CLI stderr:", stderr);
    }
    return stdout;
  },

  /**
   * Entry point called by Continue. The search text is the full chat input
   * with the provider's display title removed (the `query` argument is not
   * used). All configured knowledge bases are queried in parallel.
   *
   * @returns One context item per non-empty chunk; within each knowledge base
   *          items are ordered by rerank score when reranking succeeded.
   *          Returns [] on any top-level failure (errors are logged).
   */
  async getContextItems(
    query: string | null,
    extras: ContextProviderExtras
  ): Promise<ContextItem[]> {
    console.log("[BedrockKB] GetContextItems called");
    const actualQuery = extras.fullInput?.replace(this.displayTitle, "").trim();
    if (!actualQuery) {
      console.log(
        "[BedrockKB] No valid query found, skipping knowledge base search"
      );
      return [];
    }
    try {
      console.log("[BedrockKB] Checking AWS credentials...");
      const stsCommand = [
        "aws",
        "sts",
        "get-caller-identity",
        "--profile",
        this.escapeShellArg(this.awsProfile),
      ].join(" ");
      const [identityOutput, identityError] = await extras.ide.subprocess(
        stsCommand
      );
      if (!identityOutput) {
        const detail =
          identityError && identityError.trim() ? ` (${identityError.trim()})` : "";
        throw new Error(
          `Failed to get AWS credentials. Please ensure you're logged in with 'aws sso login --profile ${this.awsProfile}'${detail}`
        );
      }

      const allResults = await Promise.all(
        this.kbConfigs.map((kbConfig) =>
          this.queryKnowledgeBase(kbConfig, actualQuery, extras)
        )
      );

      const contextItems = allResults.flatMap(({ kbId, results }) =>
        results.map((r: any) => ({
          name: `${kbId} - ${this.formatLocation(r.location)}`,
          description: `Relevance score: ${r.score ?? 0}`,
          content: r.content?.text ?? "",
        }))
      );
      console.log("[BedrockKB] Final context items:", contextItems.length);
      return contextItems;
    } catch (err) {
      console.error("[BedrockKB] Error in getContextItems:", err);
      return [];
    }
  },

  /**
   * Retrieve chunks from one knowledge base, drop empty ones, and rerank the
   * rest. Never throws: on failure the error is logged and `results` is empty;
   * if only reranking fails, Bedrock's own results are returned unchanged.
   *
   * @returns The knowledge base ID plus its results, sorted by rerank score
   *          (highest first) when reranking succeeded.
   */
  async queryKnowledgeBase(
    kbConfig: KnowledgeBaseConfig,
    actualQuery: string,
    extras: ContextProviderExtras
  ): Promise<{ kbId: string; results: any[] }> {
    const kbId = kbConfig.knowledgeBaseId;
    const tempFiles: string[] = [];
    try {
      console.log(`[BedrockKB] Querying KB: ${kbId} => '${actualQuery}'`);
      const queryFile = await this.createTempJsonFile({ text: actualQuery });
      tempFiles.push(queryFile);
      const configFile = await this.createTempJsonFile({
        vectorSearchConfiguration: {
          // `||` kept deliberately so that 0 also falls back to 32.
          numberOfResults: kbConfig.numberOfResults || 32,
        },
      });
      tempFiles.push(configFile);

      const commandParts = [
        "aws bedrock-agent-runtime retrieve",
        `--knowledge-base-id ${this.escapeShellArg(kbId)}`,
        `--retrieval-query ${this.escapeShellArg(`file://${queryFile}`)}`,
        `--retrieval-configuration ${this.escapeShellArg(`file://${configFile}`)}`,
        `--region ${this.escapeShellArg(this.region)}`,
        `--profile ${this.escapeShellArg(this.awsProfile)}`,
      ];
      // If you have guardrails, uncomment and define them in the type:
      /*
      if (kbConfig.guardrailsId) {
        const guardrailsFile = await this.createTempJsonFile({
          guardrailsId: kbConfig.guardrailsId,
          guardrailsVersion: kbConfig.guardrailsVersion,
        });
        tempFiles.push(guardrailsFile);
        commandParts.push(
          `--guardrails-configuration ${this.escapeShellArg(`file://${guardrailsFile}`)}`
        );
      }
      */
      const command = commandParts.join(" ");
      console.log("[BedrockKB] Executing retrieve command:", command);
      const stdout = await this.runAwsCommand(command, extras);

      const response = JSON.parse(stdout);
      // Empty chunks are dropped up front so they are never sent to the reranker.
      const results: any[] = (response.retrievalResults ?? []).filter(
        (r: any) => r.content?.text
      );
      if (results.length === 0) {
        return { kbId, results };
      }

      try {
        const scores = await this.rerankResults(
          actualQuery,
          results.map((r) => ({ content: r.content.text })),
          extras
        );
        const reranked = results.map((r, idx) => ({
          ...r,
          score: scores[idx] ?? 0,
        }));
        reranked.sort((a, b) => b.score - a.score);
        return { kbId, results: reranked };
      } catch (err) {
        console.error("[BedrockKB] Reranking error:", err);
        return { kbId, results };
      }
    } catch (err) {
      console.error(`[BedrockKB] Error querying KB ${kbId}:`, err);
      return { kbId, results: [] };
    } finally {
      await this.removeTempFiles(tempFiles);
    }
  },

  /**
   * Score `chunks` against `query` with the Bedrock rerank model.
   *
   * `invoke-model` writes the model response to its positional outfile
   * argument, so the response goes to a temp file that is read back here
   * (previously `/dev/stdout` mixed it with the CLI's own output). The body
   * is passed with `fileb://` so the CLI sends the JSON bytes unchanged.
   *
   * @param query  Text the chunks are ranked against.
   * @param chunks Items whose `content` string is sent as a document.
   * @returns One score per chunk, in input order; unscored chunks get 0.
   * @throws When the CLI call fails or the response has no `results` array.
   */
  async rerankResults(
    query: string,
    chunks: any[],
    extras: ContextProviderExtras
  ): Promise<number[]> {
    const tempFiles: string[] = [];
    try {
      const documents = chunks.map((c) => c.content);
      const payloadFile = await this.createTempJsonFile({
        api_version: 2,
        query,
        documents,
        top_n: documents.length,
      });
      tempFiles.push(payloadFile);
      const outFile = path.join(tmpdir(), `bedrock-rerank-${randomUUID()}.json`);
      tempFiles.push(outFile);

      const command = [
        "aws bedrock-runtime invoke-model",
        `--model-id ${this.escapeShellArg(this.rerankModelId)}`,
        "--accept application/json",
        "--content-type application/json",
        `--region ${this.escapeShellArg(this.region)}`,
        `--profile ${this.escapeShellArg(this.awsProfile)}`,
        `--body ${this.escapeShellArg(`fileb://${payloadFile}`)}`,
        this.escapeShellArg(outFile),
      ].join(" ");
      console.log("[BedrockKB] Reranking command:", command);
      await this.runAwsCommand(command, extras);

      const decoded = JSON.parse(await fs.readFile(outFile, "utf8"));
      if (!Array.isArray(decoded?.results)) {
        throw new Error(`Unexpected rerank response: ${JSON.stringify(decoded)}`);
      }
      // Place each score at the position of the document it refers to (`index`),
      // so a partial or reordered result list cannot misalign scores.
      const scores = new Array<number>(documents.length).fill(0);
      for (const r of decoded.results) {
        if (
          Number.isInteger(r.index) &&
          r.index >= 0 &&
          r.index < scores.length &&
          typeof r.relevance_score === "number"
        ) {
          scores[r.index] = r.relevance_score;
        }
      }
      return scores;
    } catch (err) {
      console.error("[BedrockKB] Reranking error:", err);
      throw err;
    } finally {
      await this.removeTempFiles(tempFiles);
    }
  },
};
/**
 * Continue.dev config hook: registers the Bedrock knowledge-base provider,
 * creating the `contextProviders` list first when it is missing.
 * Mutates and returns the same `config` object.
 */
export function modifyConfig(config: Config): Config {
  console.log("[BedrockKB] modifyConfig called, adding provider");
  const providers = config.contextProviders || (config.contextProviders = []);
  providers.push(BedrockKBContextProvider);
  return config;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment