fastapi-mlx
import asyncio
import json
from typing import List, Optional

import mlx.core as mx
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from mlx_lm import generate, load
from mlx_lm.utils import generate_step
from pydantic import BaseModel
app = FastAPI()

# Enable CORS for all origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
class UserInput(BaseModel):
    message: str


# Global variables for model and tokenizer
model, tokenizer = None, None
@app.on_event("startup")
async def load_model():
    global model, tokenizer
    # upload_repo = "mlx-community/Llama-3-8B-Instruct-1048k-4bit"
    # upload_repo = "mlx-community/Mistral-7B-Instruct-v0.2-4bit"
    upload_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
    loop = asyncio.get_event_loop()
    model, tokenizer = await loop.run_in_executor(None, load, upload_repo)
    tokenizer.eos_token = "<|eot_id|>"
class Message(BaseModel):
    role: str
    content: str


class CreateChatCompletionRequest(BaseModel):
    model: str
    messages: List[dict]
    max_tokens: int = 300
    temperature: float = 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[List[str]] = None
    logprobs: Optional[int] = None
    user: Optional[str] = None
    api_key: Optional[str] = None  # Added API key field
    apiKey: Optional[str] = None


# Stop sequences checked incrementally while streaming tokens
stop_words = ["<|im_start|>", "<|im_end|>", "<s>", "</s>"]
@app.post("/chat/completions")
async def create_chat_completion(
    request: CreateChatCompletionRequest,
    http_request: Request,
):
    headers = http_request.headers
    try:
        print(request.messages)
        prompt = tokenizer.apply_chat_template(
            request.messages, tokenize=False, add_generation_prompt=True
        )
        stream = request.stream
        print(f"Stream: {stream}")
        if stream:
            print("Running in stream")

            async def event_generator():
                try:
                    async for step in async_generate_steps(prompt, model, tokenizer):
                        json_step = json.dumps(
                            {"choices": [{"delta": {"content": step}}]}
                        )
                        print(f"Yielding step: {json_step}")
                        yield f"data: {json_step}\n\n"
                        await asyncio.sleep(0)  # Yield control to the event loop
                    yield "data: [DONE]\n\n"  # Indicate the end of the stream
                except Exception as e:
                    print(f"Error in event_generator: {e}")
                    yield f"data: {json.dumps({'error': str(e)})}\n\n"
                    raise

            return StreamingResponse(event_generator(), media_type="text/event-stream")
        # Non-streaming path: generate the full completion and wrap it in an
        # OpenAI-style chat.completion response
        response_content = await get_conversation_response(request.messages)
        response = {
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response_content,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": len(request.messages),
                "completion_tokens": len(response_content.split()),
                "total_tokens": len(request.messages) + len(response_content.split()),
            },
            "api_key": request.api_key,
        }
        return response
    except Exception as e:
        print(f"Error in create_chat_completion: {e}")
        raise HTTPException(status_code=500, detail=str(e))
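# Usage sketch (assumptions: the module is saved as main.py and served with
# `uvicorn main:app --port 8000`; the filename and port are illustrative):
#
#   curl http://localhost:8000/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "local", "messages": [{"role": "user", "content": "Hello"}]}'
#
# Set "stream": true in the request body to receive Server-Sent Events instead
# of a single JSON response.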
async def async_generate_steps(prompt, the_model, tokenizer):
    tokens = []
    skip = 0
    try:
        for (token, prob), n in zip(
            generate_step(mx.array(tokenizer.encode(prompt)), the_model, 1),
            range(2000),  # Cap generation at 2000 tokens
        ):
            try:
                if token == tokenizer.eos_token_id:
                    break
                tokens.append(token)
                text = tokenizer.decode(tokens)
                trim = None
                for sw in stop_words:
                    if text[-len(sw):].lower() == sw:
                        # A full stop word was generated; end the stream
                        return
                    else:
                        # Hold back any partial stop-word match at the end of the text
                        for i, _ in enumerate(sw, start=1):
                            if text[-i:].lower() == sw[:i]:
                                trim = -i
                print(text)
                yield text[skip:trim] or " "  # Ensure something is yielded
                skip = len(text)
            except asyncio.CancelledError:
                print("Client disconnected")
                return
            except Exception as e:
                print(f"Error in token processing: {e}")
                continue
    except asyncio.CancelledError:
        print("Client disconnected")
        return
    except Exception as e:
        print(f"Error in async_generate_steps: {e}")
        raise
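# Note on the generator above: generate_step yields one token at a time; the
# loop re-decodes the accumulated tokens each step and emits only the new
# suffix (text[skip:trim]), holding back any partial stop-word match so it is
# never sent to the client.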
async def get_conversation_response(messages):
    print("Generating response...")
    loop = asyncio.get_event_loop()
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    print(prompt)
    tokenizer.eos_token = "<|eot_id|>"
    try:
        # Run the synchronous generate call in the default executor so it
        # doesn't block the event loop
        response = await loop.run_in_executor(
            None,  # Uses the default executor (ThreadPoolExecutor)
            generate,  # The synchronous function to run
            model,
            tokenizer,
            prompt,
            1,  # temp
            30000,  # max_tokens
            True,  # verbose
        )
    except Exception as e:
        print(f"Error generating response: {e}")
        raise  # Re-raise the exception or handle it as needed
    return response
async def get_response_async(user_input):
    print("Generating response...")
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant collaborating with a knowledge worker",
        },
        {"role": "user", "content": user_input},
    ]
    loop = asyncio.get_event_loop()
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    tokenizer.eos_token = "<|eot_id|>"
    try:
        # Run the synchronous generate call in the default executor so it
        # doesn't block the event loop
        response = await loop.run_in_executor(
            None,  # Uses the default executor (ThreadPoolExecutor)
            generate,  # The synchronous function to run
            model,
            tokenizer,
            prompt,
            1,  # temp
            30000,  # max_tokens
            True,  # verbose
        )
    except Exception as e:
        print(f"Error generating response: {e}")
        raise  # Re-raise the exception or handle it as needed
    return response
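# Minimal streaming-client sketch (not part of the original gist; it assumes
# the third-party `requests` package and a server running on localhost:8000).
# It reads the Server-Sent Events emitted by /chat/completions:
#
#   import json
#   import requests
#
#   with requests.post(
#       "http://localhost:8000/chat/completions",
#       json={
#           "model": "local",
#           "messages": [{"role": "user", "content": "Hi"}],
#           "stream": True,
#       },
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines():
#           if not line or not line.startswith(b"data: "):
#               continue
#           payload = line[len(b"data: "):]
#           if payload == b"[DONE]":
#               break
#           chunk = json.loads(payload)
#           print(chunk["choices"][0]["delta"]["content"], end="", flush=True)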