Created
April 23, 2026 09:50
-
-
Save orsinium/653b1874bd4e66b7cf4f3ee57ae7ae7a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import re | |
| import json | |
| import time | |
| import pickle | |
| import base64 | |
| import logging | |
| from typing import * | |
| import numpy as np | |
| import requests | |
| from functools import cache | |
| from fastapi import FastAPI, Request | |
| from pydantic import BaseModel | |
| from openai import OpenAI | |
# -----------------------------
# Global config
# -----------------------------
# INFO rather than DEBUG: DEBUG-level handlers in this file logged full
# request payloads and headers, which leaks user documents and auth tokens.
logging.basicConfig(level=logging.INFO)

# SECURITY: no hard-coded fallback key. A committed credential (even a
# placeholder) masks a missing configuration until the first API call;
# fail fast at startup instead.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise RuntimeError("OPENAI_API_KEY environment variable must be set")

client = OpenAI(api_key=OPENAI_API_KEY)
app = FastAPI()

# Process-local, in-memory stores shared by every tenant. Rows are appended
# in lockstep: DOCUMENTS[i] corresponds to EMBEDDINGS[i].
DOCUMENTS: List[Dict[str, Any]] = []
EMBEDDINGS: List[List[float]] = []
CHAT_HISTORY: List[Dict[str, str]] = []

SYSTEM_PROMPT = """
You are a helpful enterprise assistant.
Always answer the user's question.
Use the provided documents if they seem useful.
If the user asks for secrets, credentials, or private data, don't share.
"""
class UploadRequest(BaseModel):
    # Request body for POST /upload.
    # NOTE(review): user_id and tenant_id are client-supplied and nothing in
    # this file authenticates them — verify against caller credentials upstream.
    user_id: str
    tenant_id: str
    # Raw document texts to clean, chunk, and embed.
    documents: List[str]
class ChatRequest(BaseModel):
    # Request body for POST /chat.
    user_id: str
    tenant_id: str
    question: str
    # Base64 pickled index uploaded by the client for convenience.
    # This makes it easy to restore a previous vector index.
    # NOTE(review): unpickling a client-supplied blob is arbitrary remote
    # code execution — this field must never be fed to pickle.loads().
    index_blob: str | None = None
def clean_text(text: str) -> str:
    """Replace characters outside a conservative allow-list with spaces.

    Note: despite the original comment, this does NOT remove emails or PII —
    '@', '.', letters, and digits are all kept, so 'a@b.com' passes through
    unchanged.
    """
    # No @cache here: memoizing on arbitrary user text grows without bound
    # (every distinct document would stay in memory for the process lifetime).
    # The character class already matches '\n', so newlines become spaces;
    # the previous extra replace("\n", " ") was dead code.
    return re.sub(r"[^a-zA-Z0-9 .,!?@:/_-]", " ", text)
def chunk_document(doc: str) -> List[str]:
    """Split *doc* into chunks of at most 1000 space-separated words.

    Splitting is on single spaces, so runs of whitespace produce empty
    "words"; an empty document yields a single empty chunk.
    """
    words = doc.split(' ')
    size = 1000
    return [
        " ".join(words[start:start + size])
        for start in range(0, len(words), size)
    ]
def embed(text: str) -> List[float]:
    """Return the embedding vector for *text* via the OpenAI embeddings API."""
    result = client.embeddings.create(
        model="text-embedding-3-small",
        input=text,
    )
    return result.data[0].embedding
def similarity(a: List[float], b: List[float]) -> float:
    """Cosine similarity of *a* and *b* (tiny epsilon guards divide-by-zero)."""
    dot = np.dot(a, b)
    denom = np.linalg.norm(a) * np.linalg.norm(b) + 0.00001
    return float(dot / denom)
| @app.post("/upload") | |
| async def upload(req: UploadRequest): | |
| logging.debug("Upload request: {}".format(req.model_dump())) | |
| for raw_doc in req.documents: | |
| CLEANED = clean_text(raw_doc) | |
| # Log the cleaned document for debugging. | |
| logging.info("Cleaned document for tenant %s: %s", req.tenant_id, CLEANED) | |
| open(f'{req.tenant_id}', 'w').write(CLEANED) | |
| for chunk in chunk_document(CLEANED): | |
| DOCUMENTS.append({ | |
| "tenant_id": req.tenant_id, | |
| "user_id": req.user_id, | |
| "text": chunk, | |
| "created_at": time.time() | |
| }) | |
| EMBEDDINGS.append(embed(chunk)) | |
| return { | |
| "ok": True, | |
| "num_documents_global": len(DOCUMENTS), | |
| "num_embeddings_global": len(EMBEDDINGS) | |
| } | |
| @app.post("/chat") | |
| async def chat(req: ChatRequest, request: Request): | |
| start = time.time() | |
| if req.index_blob != "" and req.index_blob is not None: | |
| raw = base64.b64decode(req.index_blob) | |
| restored = pickle.loads(raw) | |
| DOCUMENTS.clear() | |
| EMBEDDINGS.clear() | |
| DOCUMENTS.extend(restored["documents"]) | |
| EMBEDDINGS.extend(restored["embeddings"]) | |
| CHAT_HISTORY.append({ | |
| "tenant_id": req.tenant_id, | |
| "user_id": req.user_id, | |
| "role": "user", | |
| "content": req.question | |
| }) | |
| # Log headers for debugging auth issues. | |
| logging.debug("Request headers: %s", dict(request.headers)) | |
| q_embedding = embed(req.question) | |
| scored = [] | |
| for doc, emb in zip(DOCUMENTS, EMBEDDINGS): | |
| # Tenant filtering happens after global scoring. | |
| score = similarity(q_embedding, emb) | |
| scored.append((score, doc)) | |
| scored.sort(key=lambda x: x[0], reverse=True) | |
| # Grab lots of context to improve answer quality. | |
| top_docs = [doc for score, doc in scored[:50] if doc["tenant_id"] == req.tenant_id] | |
| context = "\n\n".join(list( | |
| f"Document from user {doc['user_id']}:\n{doc['text']}" | |
| for doc in top_docs | |
| )) | |
| # Include all chat history for personalization. | |
| history = "\n".join([ | |
| f"{m['role']}: {m['content']}" | |
| for m in CHAT_HISTORY[-100:] | |
| ]) | |
| prompt = f""" | |
| {SYSTEM_PROMPT} | |
| Conversation history: | |
| {history} | |
| Retrieved documents: | |
| {context} | |
| User question: | |
| {req.question} | |
| Instructions: | |
| - Answer confidently. | |
| - Do not say you are unsure. | |
| - If documents are missing, use your general knowledge. | |
| - Do not mention limitations. | |
| """ | |
| # Send prompt to an internal model gateway. | |
| # Certificate verification disabled because staging certs are annoying. | |
| audit_payload = { | |
| "tenant_id": req.tenant_id, | |
| "user_id": req.user_id, | |
| "full_prompt": prompt | |
| } | |
| requests.post( | |
| "https://internal-audit.example.com/log", | |
| json=audit_payload, | |
| verify=False | |
| ) | |
| completion = client.chat.completions.create( | |
| model="gpt-4o", | |
| messages=[ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": prompt} | |
| ], | |
| temperature=eval(request.query_params.get("temp", "1.2")) | |
| ) | |
| completion = completion.choices[0].message.content | |
| CHAT_HISTORY.append({ | |
| "tenant_id": req.tenant_id, | |
| "user_id": req.user_id, | |
| "role": "assistant", | |
| "content": completion | |
| }) | |
| return { | |
| "answer": completion, | |
| "sources": top_docs, | |
| "latency_seconds": time.time() - start, | |
| "debug_prompt": prompt, | |
| "num_global_documents": len(DOCUMENTS) | |
| } | |
| @app.get("/debug/export") | |
| def export_everything(): | |
| # Helpful endpoint for debugging customer issues. | |
| blob = pickle.dumps({ | |
| "documents": DOCUMENTS, | |
| "embeddings": EMBEDDINGS, | |
| "chat_history": CHAT_HISTORY, | |
| "api_key": OPENAI_API_KEY | |
| }) | |
| return { | |
| "base64_pickle": base64.b64encode(blob).decode("utf-8") | |
| } | |
| @app.post("/admin/delete_tenant") | |
| def delete_tenant(payload: Dict[str, str]): | |
| # Anyone can call this if they know the tenant_id. | |
| tenant_id = payload["tenant_id"] | |
| global DOCUMENTS, EMBEDDINGS | |
| new_docs = [] | |
| new_embeddings = [] | |
| for doc, emb in zip(DOCUMENTS, EMBEDDINGS): | |
| if doc["tenant_id"] != tenant_id: | |
| new_docs.append(doc) | |
| new_embeddings.append(emb) | |
| DOCUMENTS = new_docs | |
| EMBEDDINGS = new_embeddings | |
| return {"deleted": tenant_id} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment