Last active
July 17, 2024 18:41
-
-
Save eddking/330ca6f304b2f7293cf4031af123a328 to your computer and use it in GitHub Desktop.
Streaming Function Calling with OpenAI assistants API
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { Assistant } from "openai/resources/beta/assistants/assistants"; | |
import { Message } from "openai/resources/beta/threads/messages/messages"; | |
import { RequiredActionFunctionToolCall } from "openai/resources/beta/threads/runs/runs"; | |
import { Thread } from "openai/resources/beta/threads/threads"; | |
import { searchInput } from "src/handlers/search"; | |
import { adminUpdateAiThreadExtra } from "src/operations/ai_thread"; | |
import { openai } from "../lib/openai"; | |
import { | |
getMessageContent, | |
parsePlain, | |
recordAssistantMessage, | |
updateAssistantMessage, | |
} from "./message"; | |
import { InternalThread } from "./threads"; | |
import { | |
handleLinkToolCall, | |
handleSearchToolCall, | |
linkInput, | |
ToolCallContext, | |
} from "./tools"; | |
import { IdMapper } from "./idMapper"; | |
import { InternalAssistant } from "./assistant"; | |
import { AiRunEvent } from "models"; | |
/**
 * Streams an assistant run on `thread` and keeps it going until it settles.
 *
 * Each pass of the loop attaches listeners to the current stream, waits for it
 * to finish, and then either returns (error, or no further action required) or
 * executes the requested tool calls and opens a follow-up stream with their
 * outputs. Message and run lifecycle events are forwarded to `recordEvent`.
 * Whatever the outcome, the id mapper state is persisted on the internal
 * thread before this function returns.
 *
 * @param thread - OpenAI thread the run is created on.
 * @param assistant - OpenAI assistant that executes the run.
 * @param internalAssistant - Internal assistant record; its id is stored with recorded messages.
 * @param internalThread - Internal thread record; supplies the persisted id map and receives the updated one.
 * @param toolCallContext - Context handed to the tool handlers; its `orgId` is also stored with recorded messages.
 * @param recordEvent - Sink for run events.
 */
export const runThread = async (
  thread: Thread,
  assistant: Assistant,
  internalAssistant: InternalAssistant,
  internalThread: InternalThread,
  toolCallContext: ToolCallContext,
  recordEvent: (event: AiRunEvent) => void
) => {
  const idMapper = new IdMapper(internalThread.extra.id_map);

  // OpenAI message id -> promise of the internal message being created for it.
  // Later events for the same message await this to learn the internal id.
  const inProgressMessages: Record<
    string,
    Promise<{ id: string } | null | undefined>
  > = {};

  // The stream does not await its listeners, so async work started from them is
  // collected here: failures are logged instead of surfacing as unhandled
  // rejections, and everything is awaited before the id map is persisted.
  const pendingWork: Promise<void>[] = [];
  const track = (label: string, work: Promise<unknown>): void => {
    pendingWork.push(
      work.then(
        () => undefined,
        (e: unknown) => {
          console.error(`Error in ${label} handler: `, e);
        }
      )
    );
  };

  const currentTime = () => new Date().toISOString();

  // OpenAI reports `created_at` in Unix seconds; `Date` expects milliseconds.
  const fromUnixSeconds = (seconds: number) =>
    new Date(seconds * 1000).toISOString();

  const describeError = (e: unknown): string =>
    e instanceof Error && e.message ? e.message : String(e);

  // Writes the final content of a message to its internal record (best effort).
  const updateMessage = async (message: Message) => {
    try {
      const result = await inProgressMessages[message.id];
      if (!result) {
        console.log("No result for message, skipping: ", message.id);
        return;
      }
      await updateAssistantMessage(message, idMapper, result.id);
    } catch (e) {
      console.error("Error updating message: ", e);
    }
  };

  const onMessageCreated = async (
    message: Message,
    runId: string | undefined
  ) => {
    if (runId === undefined) {
      throw new Error("messageCreated received without a current run");
    }
    const createPromise = recordAssistantMessage(
      message,
      idMapper,
      internalThread.id,
      internalAssistant.id,
      runId,
      toolCallContext.orgId
    );
    // Registered before the first await so delta/done handlers can find it.
    inProgressMessages[message.id] = createPromise;
    const currentContent = getMessageContent(message);
    const structured = await parsePlain(currentContent, idMapper);
    const internalMessage = await createPromise;
    recordEvent({
      type: "messageCreated",
      content: currentContent,
      structured: structured,
      internal_id: internalMessage?.id!,
      external_id: message.id,
      role: message.role,
      created_at: fromUnixSeconds(message.created_at),
    });
  };

  const onMessageDelta = async (snapshot: Message) => {
    const currentContent = getMessageContent(snapshot);
    const structured = await parsePlain(currentContent, idMapper);
    const createResult = await inProgressMessages[snapshot.id];
    recordEvent({
      type: "messageDelta",
      content: currentContent,
      structured: structured,
      external_id: snapshot.id,
      internal_id: createResult?.id!,
      created_at: fromUnixSeconds(snapshot.created_at),
      role: snapshot.role,
    });
  };

  const onMessageDone = async (message: Message) => {
    const finalContent = getMessageContent(message);
    const structured = await parsePlain(finalContent, idMapper);
    const createResult = await inProgressMessages[message.id];
    // Update the internal message with the final content; the event below is
    // not delayed by the write.
    track("updateMessage", updateMessage(message));
    recordEvent({
      type: "messageDone",
      content: finalContent,
      structured: structured,
      external_id: message.id,
      internal_id: createResult?.id!,
      created_at: fromUnixSeconds(message.created_at),
      role: message.role,
    });
  };

  // Parses the model-supplied JSON arguments and dispatches to the tool handler.
  // Throws on malformed JSON, schema violations, or an unknown tool name; the
  // caller turns those into an "Error: ..." tool output.
  const executeToolCall = async (toolCall: RequiredActionFunctionToolCall) => {
    const name = toolCall.function.name;
    const args: unknown = JSON.parse(toolCall.function.arguments);
    switch (name) {
      case "search": {
        // NOTE(review): this validates against the full `searchInput` schema;
        // confirm whether a narrower, AI-specific schema should be used here.
        const searchArgs = searchInput.parse(args);
        const results = await handleSearchToolCall(
          searchArgs,
          toolCallContext,
          idMapper
        );
        return JSON.stringify(results);
      }
      case "link_record": {
        const linkArgs = linkInput.parse(args);
        return await handleLinkToolCall(linkArgs, toolCallContext, idMapper);
      }
      default:
        throw new Error("Unknown tool call: " + name);
    }
  };

  try {
    let stream = openai.beta.threads.runs.createAndStream(thread.id, {
      assistant_id: assistant.id,
    });

    // Every exit from this loop is an explicit `return` below.
    while (true) {
      stream.on("messageCreated", (message) => {
        track(
          "messageCreated",
          onMessageCreated(message, stream.currentRun()?.id)
        );
      });
      stream.on("messageDelta", (_messageDelta, snapshot) => {
        track("messageDelta", onMessageDelta(snapshot));
      });
      stream.on("messageDone", (message) => {
        track("messageDone", onMessageDone(message));
      });
      stream.on("event", (event) => {
        console.log("Event: ", event.event);
        switch (event.event) {
          // Dont log these
          case "thread.run.step.delta":
          case "thread.message.created":
          case "thread.message.in_progress":
          case "thread.message.delta":
          case "thread.message.completed":
          case "thread.message.incomplete":
          case "thread.run.queued":
          case "thread.run.in_progress":
          case "thread.run.cancelling":
          case "thread.run.step.created":
          case "thread.run.step.in_progress":
          case "thread.run.step.completed":
          case "thread.run.step.failed":
          case "thread.run.step.cancelled":
          case "thread.run.step.expired":
          case "thread.created":
            return;
          case "thread.run.requires_action": // Logged separately below
            return;
          // Log these for posterity
          case "thread.run.created":
            recordEvent({
              type: "runCreated",
              created_at: fromUnixSeconds(event.data.created_at),
            });
            break;
          case "thread.run.completed":
            recordEvent({ type: "runCompleted", created_at: currentTime() });
            break;
          case "thread.run.failed":
            recordEvent({ type: "runFailed", created_at: currentTime() });
            break;
          case "thread.run.cancelled":
            recordEvent({ type: "runCancelled", created_at: currentTime() });
            break;
          case "thread.run.expired":
            recordEvent({ type: "runExpired", created_at: currentTime() });
            break;
          case "error":
            recordEvent({
              type: "runError",
              error: event.data,
              created_at: currentTime(),
            });
            break;
          default: {
            // Compile-time check that every event type is handled above.
            const _exhaustiveCheck: never = event;
          }
        }
      });

      // I was planning to start executing tool calls as they stream in, storing the result promises in a map
      // until the requires_action event comes in, then submit the results of all the tool calls.
      // But it seems like this isn't easily supported at the moment: the data is there internally somewhere,
      // but the 'toolCallDone' event seems to give empty '' arguments for the tool call.
      await stream.done();

      const currentRun = stream.currentRun();
      if (!currentRun) {
        recordEvent({
          type: "runError",
          error: "No Current Run",
          created_at: currentTime(),
        });
        return;
      }

      const lastError = currentRun.last_error;
      if (lastError) {
        recordEvent({
          type: "runError",
          error: lastError,
          created_at: currentTime(),
        });
        return;
      }

      if (!currentRun.required_action) {
        // This happens at the end of every successful run
        recordEvent({ type: "noActionRequired", created_at: currentTime() });
        return;
      }

      // The assumption at the moment is that the only required action is to submit tool outputs
      // This may change in the future
      if (!currentRun.required_action.submit_tool_outputs) {
        recordEvent({
          type: "runError",
          error: "No tool outputs required",
          created_at: currentTime(),
        });
        return;
      }

      const toolCalls =
        currentRun.required_action.submit_tool_outputs.tool_calls;
      recordEvent({
        type: "requiredAction",
        toolCalls,
        created_at: currentTime(),
      });

      // Run all tool calls concurrently. A failing call is reported back to the
      // model as an "Error: ..." output rather than aborting the run.
      const allToolResults = await Promise.all(
        toolCalls.map(async (toolCall) => {
          const toolCallId = toolCall.id;
          try {
            const result = await executeToolCall(toolCall);
            console.log("Tool call result: ", toolCallId, result);
            return { tool_call_id: toolCallId, output: result };
          } catch (e: unknown) {
            console.error("Error executing tool call: ", e);
            return {
              tool_call_id: toolCallId,
              output: "Error: " + describeError(e),
            };
          }
        })
      );
      recordEvent({
        type: "toolCallResults",
        results: allToolResults,
        created_at: currentTime(),
      });

      // Set up the next stream and loop back up to the top
      stream = openai.beta.threads.runs.submitToolOutputsStream(
        thread.id,
        currentRun.id,
        {
          stream: true,
          tool_outputs: allToolResults,
        }
      );
    }
  } finally {
    // Let in-flight listeners finish first: they may still be using the id mapper.
    await Promise.all(pendingWork);
    await adminUpdateAiThreadExtra(
      { id_map: idMapper.getState() },
      internalThread.id
    );
  }
};
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { | |
AssistantTool, | |
FunctionTool, | |
} from "openai/resources/beta/assistants/assistants"; | |
import { zodToJsonSchema } from "zod-to-json-schema"; | |
import { searchInput, performSearch } from "src/handlers/search"; | |
import { z } from "zod"; | |
import { Language_Enum } from "src/generated/gql/graphql"; | |
import { IdMapper } from "./idMapper"; | |
import { isMissing } from "@util"; | |
/**
 * Context passed to every tool-call handler alongside the model-supplied input.
 * `handleSearchToolCall` forwards all three fields to `performSearch`.
 */
export interface ToolCallContext {
// Organization id (presumably the org the request is scoped to; confirm with callers).
orgId: string;
// User id (presumably the user the assistant is acting for; confirm with callers).
userId: string;
// Language setting of the organization.
orgLanguage: Language_Enum;
}
// Input schema for the AI-facing search tool: `searchInput` with the fields
// below removed. Don't allow the AI to set these params; they are meant for
// UI-specific behaviour.
const searchToolInput = searchInput.omit({
includeFacets: true,
typeahead: true,
filter: true, // Filters are complicated, maybe we can test later
tags: true,
});
// JSON Schema form of `searchToolInput`, used as the `parameters` of `searchTool`.
export const searchSchema = zodToJsonSchema(searchToolInput);
/**
 * Function-tool definition for the "search" tool. The description is the text
 * shown to the model; the accepted parameters come from `searchSchema`.
 */
export const searchTool: FunctionTool = {
type: "function",
function: {
name: "search",
description:
"Search everything within the organization. If a search result has a target_id, " +
" and this differs from it's own formatted_id, it means it is embedded in another page. ",
parameters: searchSchema,
},
};
/**
 * Executes the "search" tool: runs `performSearch` with a non-admin,
 * AI-flagged context built from `context`, then passes every hit through
 * `processSearchResult`.
 *
 * @param input - Search parameters supplied by the model.
 * @param context - Supplies the org id, user id and org language for the search.
 * @param idMapper - Mapper handed to `processSearchResult` for each hit.
 * @returns The processed hits, in the order `performSearch` returned them.
 */
export const handleSearchToolCall = async (
  input: z.infer<typeof searchToolInput>,
  context: ToolCallContext,
  idMapper: IdMapper
) => {
  const searchContext = {
    orgId: context.orgId,
    userId: context.userId,
    orgLanguage: context.orgLanguage,
    admin: false,
    ai: true,
  };
  const { results: hits } = await performSearch(input, searchContext);
  return hits.map((hit) => processSearchResult(hit, idMapper));
};
// Document fields whose values are passed through the id mapper, in the order
// they are mapped.
const MAPPED_ID_FIELDS = [
  "formatted_id",
  "target_id",
  "org_id",
  "created_by",
  "updated_by",
  "id",
] as const;

// Timestamp fields dropped from the output when they hold a value.
const DROPPED_TIMESTAMP_FIELDS = ["created_at", "updated_at"] as const;

/**
 * Prepares one search hit for the model: maps id fields through `idMapper`,
 * drops timestamp fields, and makes sure a `target_id` is present when one is
 * available on either the document or the hit itself.
 *
 * Works on a shallow copy, so the `result` passed in is left untouched.
 *
 * NOTE(review): the declared return type does not match the runtime shape (the
 * document's fields are spread at the top level; there is no `document` key).
 * It is kept as-is so the inferred type of `handleSearchToolCall` is unchanged.
 *
 * @param result - Raw search hit; `result.document` holds the indexed fields.
 * @param idMapper - Mapper applied to every id value that is present.
 * @returns The document's fields plus `target_id`.
 */
const processSearchResult = (
  result: any,
  idMapper: IdMapper
): { document: unknown; target_id: string | null | undefined } => {
  const document = { ...result.document };

  // Map ids into shortened ids
  for (const field of MAPPED_ID_FIELDS) {
    if (document[field]) {
      document[field] = idMapper.get(document[field]);
    }
  }

  for (const field of DROPPED_TIMESTAMP_FIELDS) {
    if (document[field]) {
      delete document[field];
    }
  }
  if ("deleted_at" in document) {
    // search results wont return deleted records anyway, no point paying for the tokens
    delete document["deleted_at"];
  }

  // Sometimes the target_id is not in the document, but it is in the result
  // if it has been mapped from a different column during indexing. e.g. it comes from an association
  return {
    ...document,
    target_id:
      document["target_id"] ||
      (result.target_id ? idMapper.get(result.target_id) : undefined),
  };
};
// Input schema for the "link_record" tool: a single id string. The `.describe`
// text tells the model which fields the id may be taken from.
export const linkInput = z.object({
formatted_id: z
.string()
.describe("The formatted_id or target_id of the record"),
});
// JSON Schema form of `linkInput`, used as the `parameters` of `linkTool`.
export const linkSchema = zodToJsonSchema(linkInput);
/**
 * Function-tool definition for the "link_record" tool. The description is the
 * text shown to the model; the accepted parameters come from `linkSchema`.
 */
export const linkTool: FunctionTool = {
  type: "function",
  function: {
    name: "link_record",
    // The two sentences are separate string fragments; the first must end with
    // ". " so they do not run together in the text the model sees.
    description:
      "Returns a url that you can use to display a record to the user, you must provide a formatted_id. " +
      "The returned url can be turned into a link with markdown syntax e.g. [testing](https://example.com).",
    parameters: linkSchema,
  },
};
// Base url that record links returned to the model start with.
export const RECORD_LINK_PREFIX = "https://nascent.com/link/";

/**
 * Executes the "link_record" tool: checks that the supplied id resolves,
 * through the mapper's inverse lookup, to a value containing ":", and returns
 * the link url for it. The url embeds the id exactly as the model supplied it,
 * not the resolved value.
 *
 * @param input - Tool input holding the id to link to.
 * @param context - Tool-call context; not used by this handler.
 * @param idMapper - Mapper used for the inverse lookup.
 * @returns `RECORD_LINK_PREFIX` followed by the supplied id.
 * @throws Error when the id has no inverse mapping or the mapped value has no ":".
 */
export const handleLinkToolCall = async (
  input: z.infer<typeof linkInput>,
  context: ToolCallContext,
  idMapper: IdMapper
) => {
  const { formatted_id: suppliedId } = input;
  const resolved = idMapper.getInverse(suppliedId);
  const isRecordId = !isMissing(resolved) && resolved.includes(":");
  if (!isRecordId) {
    throw new Error(
      "Invalid id. it must be from a formatted_id or target_id field"
    );
  }
  return `${RECORD_LINK_PREFIX}${suppliedId}`;
};
// All tool definitions declared in this file: search and link_record.
export const tools: AssistantTool[] = [searchTool, linkTool];
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment