@sebastianschramm
Created August 28, 2024 08:45
LitServe API for toxicity classifier
import torch
from litserve import LitAPI, LitServer
from pydantic import BaseModel, conint
from pydantic_settings import BaseSettings
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class ToxicitySettings(BaseSettings):
    model_id: str = "s-nlp/roberta_toxicity_classifier"
    port: conint(ge=1024, le=65535) = 8000  # type: ignore
    device: str = "cpu"
    endpoint_path: str = "/toxicity"


class InputObject(BaseModel):
    text: str


class Response(BaseModel):
    label: str
    probability: float


class ToxicityLitAPI(LitAPI):
    def __init__(self, model_id: str) -> None:
        super().__init__()
        self.model_id = model_id

    def setup(self, device):
        # Runs once per worker: load tokenizer and model, move the model
        # to the device LitServe assigned, and switch to inference mode.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
        self.model.to(device)
        self.model.eval()

    def decode_request(self, request: InputObject):
        # Turn the validated request body into model-ready tensors.
        inputs = self.tokenizer(
            request.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        return inputs

    def predict(self, inputs):
        with torch.no_grad():
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            outputs = self.model(**inputs)
        return outputs.logits

    def encode_response(self, logits) -> Response:
        # Convert logits into a probability distribution and report the top class.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        max_ind = torch.argmax(probabilities).item()
        probability = probabilities[0][max_ind].item()
        label = self.model.config.id2label[max_ind]
        return Response(label=label, probability=probability)


def run_server(settings):
    api = ToxicityLitAPI(model_id=settings.model_id)
    server = LitServer(api, accelerator=settings.device, api_path=settings.endpoint_path)
    server.run(port=settings.port, generate_client_file=False)


if __name__ == "__main__":
    settings = ToxicitySettings()
    run_server(settings)
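Because ToxicitySettings extends pydantic's BaseSettings, every field can be overridden without editing the file, either at construction time or through environment variables (pydantic-settings matches field names case-insensitively, so MODEL_ID, PORT, DEVICE, and ENDPOINT_PATH all work). A minimal sketch; the device, port, and filename below are illustrative, not part of the gist:

# Programmatic override of the defaults (values here are placeholders).
settings = ToxicitySettings(device="cuda", port=9000)
run_server(settings)

# Equivalent override from the shell, assuming the file is saved as server.py:
#   DEVICE=cuda PORT=9000 python server.py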
@sebastianschramm (Author):

To query the API from Python:

import requests

response = requests.post(
    "http://0.0.0.0:8000/toxicity", json={"text": "Have a great day!"}
)
print(response.json())

>> {"label": "neutral", "probability": 0.9999544620513916}
