Example using a custom LLM with FastAPI
import torch
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "name of model"

# Alpaca-style template: the user prompt fills the Input slot; the Response
# slot is left empty so the model completes it.
prompt_template = """ ### Input: {} ### Response: {}"""

stop_token_ids = [0]  # Candidate stop token ids (not used below)

app = FastAPI()
class InputData(BaseModel):
    prompt: str


class OutputData(BaseModel):
    response: str
def load_quantized_model(model_name: str):
    """Load the model quantized to 4-bit with bitsandbytes.

    :param model_name: Name or path of the model to be loaded.
    :return: Loaded quantized model.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map="auto",
        # 4-bit quantization; drop this argument to load in full precision.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
    return model
def initialize_tokenizer(model_name: str):
    """Initialize the tokenizer with the specified model_name.

    :param model_name: Name or path of the model for tokenizer initialization.
    :return: Initialized tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.bos_token_id = 1  # Set beginning of sentence token id
    return tokenizer


model = load_quantized_model(model_name)
tokenizer = initialize_tokenizer(model_name)
@app.post("/generate", response_model=OutputData)
def generate(request: Request, input_data: InputData):
    inputs = tokenizer(
        [
            prompt_template.format(
                input_data.prompt,  # Input
                "",  # Response - leave this blank for generation!
            )
        ],
        return_tensors="pt",
    ).to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=4096, use_cache=True)
    response = tokenizer.batch_decode(outputs)[0].replace("\n", "")
    return OutputData(response=response)
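
Assuming the file above is saved as app.py and served with an ASGI server such as uvicorn (for example, uvicorn app:app --port 8000 — the module name and port are assumptions, not part of the gist), a minimal client sketch using the requests library could call the endpoint like this:

import requests

# Hypothetical client for the /generate endpoint defined above.
resp = requests.post(
    "http://localhost:8000/generate",  # assumed host and port
    json={"prompt": "Summarize FastAPI in one sentence."},  # any prompt text
    timeout=120,  # generation of up to 4096 new tokens can take a while
)
resp.raise_for_status()
print(resp.json()["response"])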