@jcollingj
Created May 17, 2024 13:17
fastapi-mlx
import asyncio
import json
from typing import List, Optional

import mlx.core as mx
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from mlx_lm import load, generate
from mlx_lm.utils import generate_step
from pydantic import BaseModel

app = FastAPI()

# Enable CORS for all origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

class UserInput(BaseModel):
    message: str


# Global variables for model and tokenizer
model, tokenizer = None, None

@app.on_event("startup")
async def load_model():
global model, tokenizer
# upload_repo = "mlx-community/Llama-3-8B-Instruct-1048k-4bit"
# upload_repo = "mlx-community/Mistral-7B-Instruct-v0.2-4bit"
upload_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
loop = asyncio.get_event_loop()
model, tokenizer = await loop.run_in_executor(None, load, upload_repo)
tokenizer.eos_token = "<|eot_id|>"
class Message(BaseModel):
    role: str
    content: str


class CreateChatCompletionRequest(BaseModel):
    model: str
    messages: List[dict]
    max_tokens: int = 300
    temperature: float = 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[List[str]] = None
    logprobs: Optional[int] = None
    user: Optional[str] = None
    api_key: Optional[str] = None  # Added API key field
    apiKey: Optional[str] = None


stop_words = ["<|im_start|>", "<|im_end|>", "<s>", "</s>"]

@app.post("/chat/completions")
async def create_chat_completion(
request: CreateChatCompletionRequest,
http_request: Request,
):
headers = http_request.headers
try:
print(request.messages)
prompt = tokenizer.apply_chat_template(
request.messages, tokenize=False, add_generation_prompt=True
)
stream = request.stream
print(f"Stream: {stream}")
if stream:
print("Running in stream")
async def event_generator():
try:
async for step in async_generate_steps(prompt, model, tokenizer):
json_step = json.dumps(
{"choices": [{"delta": {"content": step}}]}
)
print(f"Yielding step: {json_step}")
yield f"data: {json_step}\n\n"
await asyncio.sleep(0) # Yield control to the event loop
yield "data: [DONE]\n\n" # Indicate the end of the stream
except Exception as e:
print(f"Error in event_generator: {e}")
yield f"data: {json.dumps({'error': str(e)})}\n\n"
raise
return StreamingResponse(event_generator(), media_type="text/event-stream")
response_content = await get_conversation_response(prompt)
response = {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": request.model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": response_content,
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": len(request.messages),
"completion_tokens": len(response_content.split()),
"total_tokens": len(request.messages) + len(response_content.split()),
},
"api_key": request.api_key,
}
return response
except Exception as e:
print(f"Error in create_chat_completion: {e}")
raise HTTPException(status_code=500, detail=str(e))
async def async_generate_steps(prompt, the_model, tokenizer):
    tokens = []
    skip = 0
    try:
        for (token, prob), n in zip(
            generate_step(mx.array(tokenizer.encode(prompt)), the_model, 1),
            range(2000),
        ):
            try:
                if token == tokenizer.eos_token_id:
                    break
                tokens.append(token)  # Remove .item() here
                text = tokenizer.decode(tokens)
                trim = None
                for sw in stop_words:
                    if text[-len(sw):].lower() == sw:
                        return
                    else:
                        for i, _ in enumerate(sw, start=1):
                            if text[-i:].lower() == sw[:i]:
                                trim = -i
                print(text)
                yield text[skip:trim] or " "  # Ensure something is yielded
                skip = len(text)
            except asyncio.CancelledError:
                print("Client disconnected")
                return
            except Exception as e:
                print(f"Error in token processing: {e}")
                continue
    except asyncio.CancelledError:
        print("Client disconnected")
        return
    except Exception as e:
        print(f"Error in async_generate_steps: {e}")
        raise

async def get_conversation_response(prompt):
    # The caller already applied the chat template, so this takes the formatted prompt directly
    print("Generating response...")
    loop = asyncio.get_event_loop()
    print(prompt)
    tokenizer.eos_token = "<|eot_id|>"
    try:
        # Run the blocking generate call in an executor so it doesn't stall the event loop
        response = await loop.run_in_executor(
            None,  # Uses the default executor (ThreadPoolExecutor)
            generate,  # The synchronous function to run
            model,
            tokenizer,
            prompt,  # Prompt string for the generate function
            1,  # temp
            30000,  # max_tokens
            True,  # verbose
        )
    except Exception as e:
        print(f"Error generating response: {e}")
        raise  # Re-raise the exception or handle it as needed
    return response

async def get_response_async(user_input):
    print("Generating response...")
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant collaborating with a knowledge worker",
        },
        {"role": "user", "content": user_input},
    ]
    loop = asyncio.get_event_loop()
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    tokenizer.eos_token = "<|eot_id|>"
    try:
        # Run the blocking generate call in an executor so it doesn't stall the event loop
        response = await loop.run_in_executor(
            None,  # Uses the default executor (ThreadPoolExecutor)
            generate,  # The synchronous function to run
            model,
            tokenizer,
            prompt,  # Prompt string for the generate function
            1,  # temp
            30000,  # max_tokens
            True,  # verbose
        )
    except Exception as e:
        print(f"Error generating response: {e}")
        raise  # Re-raise the exception or handle it as needed
    return response
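

# Minimal sketch for running and calling this server. It assumes the file is
# saved as main.py and that uvicorn is installed; the port, model name, and
# request payload below are illustrative choices, not fixed by the code above.
#
#   uvicorn main:app --host 127.0.0.1 --port 8000
#
#   curl -N http://127.0.0.1:8000/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "mlx", "messages": [{"role": "user", "content": "Hi"}], "stream": true}'
#
# Or launch the app directly:
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)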