import {
  Message,
  StreamingTextResponse,
  Message as VercelChatMessage,
} from "ai";
import { SupabaseVectorStore } from "langchain/vectorstores/supabase";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { supabaseAdmin } from "../../../../utils/supabaseAdminClient";
import { PromptTemplate } from "langchain/prompts";
import {
  BytesOutputParser,
  StringOutputParser,
} from "langchain/schema/output_parser";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { RunnableSequence } from "langchain/schema/runnable";
import { Document } from "langchain/document";
import { type IterableReadableStream } from "langchain/dist/util/stream";
import { createSupabaseServerClinet } from "../../../../utils/createSupabaseServerClinet";
import { NextResponse } from "next/server";
import { type SupabaseClient } from "@supabase/supabase-js";

export const runtime = "edge";
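// Joins the page content of the retrieved documents into a single context string.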
const combineDocumentsFn = (docs: Document[], separator = "\n\n") => {
  const serializedDocs = docs.map((doc) => doc.pageContent);
  return serializedDocs.join(separator);
};
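// Formats chat messages as "Human:"/"Assistant:" dialogue turns and keeps only
// the most recent `limitConversation` exchanges.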
const formatVercelMessages = (
  chatHistory: VercelChatMessage[],
  limitConversation: number
) => {
  console.log({ chatHistory });
  console.log({ limitConversation });
  const formattedDialogueTurns = chatHistory.map((message) => {
    if (message.role === "user") {
      return `Human: ${message.content}`;
    } else if (message.role === "assistant") {
      return `Assistant: ${message.content}`;
    } else {
      return `${message.role}: ${message.content}`;
    }
  });
  // Keep only the most recent `limitConversation` question/answer pairs.
  const slicedDialog = formattedDialogueTurns.slice(-(limitConversation * 2));
  return slicedDialog.join("\n");
};
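// Prompt that rewrites a follow-up question as a standalone question in its original language.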
const CONDENSE_QUESTION_TEMPLATE = `Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
<chat_history>
{chat_history}
</chat_history>
Follow Up Input: {question}
Standalone question:`;

const condenseQuestionPrompt = PromptTemplate.fromTemplate(
  CONDENSE_QUESTION_TEMPLATE
);
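// Prompt that answers the question using only the retrieved context and chat history.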
const ANSWER_TEMPLATE = `You are a helpful assistant named Penelope for researchers and writers, and must answer all questions in a positive tone.
Keep your answer within 500 tokens.
Answer the question based only on the following context and chat history:
<context>
{context}
</context>
<chat_history>
{chat_history}
</chat_history>
Question: {question}
`;

const answerPrompt = PromptTemplate.fromTemplate(ANSWER_TEMPLATE);
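// Re-yields the model's streamed chunks to the caller while accumulating the
// full response text, then saves the user message and the assistant reply to
// the "chats" table once the stream ends.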
async function* handleStreamEnd({
  stream,
  supabase,
  supabaseAdmin,
  resourceId,
  currentMessage,
}: {
  stream: IterableReadableStream<Uint8Array>;
  supabase: SupabaseClient;
  supabaseAdmin: SupabaseClient;
  resourceId: string;
  currentMessage: Record<string, string>;
}) {
  const decoder = new TextDecoder("utf-8");
  let responseText = "";
  for await (const chunk of stream) {
    // Decode in streaming mode so multi-byte characters split across chunks
    // are not corrupted.
    responseText += decoder.decode(chunk, { stream: true });
    yield chunk;
  }
  responseText += decoder.decode();
  console.log("Stream ended");
  console.log("save chats on the database");
  const {
    data: { user },
  } = await supabase.auth.getUser();
  if (!user || !resourceId) return;
  await Promise.all([
    supabaseAdmin.from("chats").insert({
      role: currentMessage.role,
      content: currentMessage.content,
      user_id: user.id,
      resource_id: resourceId,
      created_at: currentMessage.created_at,
    }),
    supabaseAdmin.from("chats").insert({
      role: "assistant",
      content: responseText,
      user_id: user.id,
      resource_id: resourceId,
    }),
  ]);
}
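// Chat route: validates the requested resource, runs a conversational
// retrieval chain over its vectors, and streams the answer back with the
// source documents encoded in a response header.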
export async function POST(req: Request) {
  const body = await req.json();
  const messages = body.messages ?? [];
  const resourceId = body.resourceId as null | string;
  const previousMessages = messages.slice(0, -1) as Message[];
  const currentMessage = messages[messages.length - 1];
  const currentMessageContent = currentMessage.content;
  const supabase = createSupabaseServerClinet();
  currentMessage.created_at = new Date();

  if (!resourceId)
    return NextResponse.json(
      { result: "error", message: "resourceId is not in the request body" },
      { status: 400 }
    );

  const { data: resourceFromDatabase } = await supabase
    .from("resources")
    .select("youtube_id")
    .eq("id", resourceId);

  if (!resourceFromDatabase || resourceFromDatabase.length === 0)
    return NextResponse.json(
      { result: "error", message: "the requested resource doesn't exist" },
      { status: 400 }
    );
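  // Chat model shared by the condense-question and answer chains.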
  const model = new ChatOpenAI({
    modelName: "gpt-3.5-turbo-1106",
    // modelName: "gpt-4-1106-preview",
    temperature: 0,
    maxTokens: 500,
  });
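  // Supabase vector store filtered to the embeddings for this resource's YouTube video.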
  const vectorstore = new SupabaseVectorStore(new OpenAIEmbeddings(), {
    client: supabaseAdmin,
    tableName: "vectors",
    queryName: "match_vectors",
    filter: { youtube_id: resourceFromDatabase[0].youtube_id },
  });
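  // Condenses the follow-up question plus chat history into a standalone question.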
  const standaloneQuestionChain = RunnableSequence.from([
    condenseQuestionPrompt,
    model,
    new StringOutputParser(),
  ]);
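  // Resolves with the retrieved documents so they can be reported as sources
  // alongside the streamed answer.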
  let resolveWithDocuments: (value: Document[]) => void;
  const documentPromise = new Promise<Document[]>((resolve) => {
    resolveWithDocuments = resolve;
  });
  const retriever = vectorstore.asRetriever({
    callbacks: [
      {
        handleRetrieverEnd(documents) {
          // @ts-ignore
          resolveWithDocuments(documents);
        },
      },
    ],
  });
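  // Retrieves documents for the standalone question and joins them into a single context string.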
  // @ts-ignore
  const retrievalChain = retriever.pipe(combineDocumentsFn);
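  // Answers the question from the retrieved context, the chat history, and the question itself.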
  const answerChain = RunnableSequence.from([
    {
      context: RunnableSequence.from([
        (input) => input.question,
        retrievalChain,
      ]),
      chat_history: (input) => input.chat_history,
      question: (input) => input.question,
    },
    answerPrompt,
    model,
  ]);
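  // End-to-end chain: condense the question, answer it, and emit the result as a byte stream.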
  const conversationalRetrievalQAChain = RunnableSequence.from([
    {
      question: standaloneQuestionChain,
      chat_history: (input) => input.chat_history,
    },
    answerChain,
    new BytesOutputParser(),
  ]);
  const stream = await conversationalRetrievalQAChain.stream({
    question: currentMessageContent,
    chat_history: formatVercelMessages(previousMessages, 3),
  });

  const handledStream = handleStreamEnd({
    stream,
    supabase,
    supabaseAdmin,
    resourceId,
    currentMessage,
  });
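  // Wait for the retriever callback, then base64-encode a short preview of
  // each source document for the "x-sources" response header.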
  const documents = await documentPromise;
  const serializedSources = Buffer.from(
    JSON.stringify(
      documents.map((doc) => {
        return {
          pageContent: doc.pageContent.slice(0, 50) + "...",
          metadata: doc.metadata,
        };
      })
    )
  ).toString("base64");
  // @ts-ignore
  const response = new StreamingTextResponse(handledStream, {
    headers: {
      "x-message-index": (previousMessages.length + 1).toString(),
      "x-sources": serializedSources,
    },
  });
  return response;
}