Created
November 17, 2023 05:01
-
-
Save aleksandr-smechov/f9a223ff62521c1644469b3505308d45 to your computer and use it in GitHub Desktop.
Server-side distil-whisper streaming code
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import List

import numpy as np
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
# --- Voice-activity detection ------------------------------------------------
# Silero VAD is fetched through torch.hub in its ONNX form. The call returns a
# pair; only the first element (the model) is kept.
# NOTE(review): force_reload=True requests a fresh download of the hub repo on
# every start-up -- confirm that is intended for a long-running server.
vad_model, _ = torch.hub.load(
    repo_or_dir="snakers4/silero-vad",
    model="silero_vad",
    force_reload=True,
    onnx=True,
)

# --- Device / precision ------------------------------------------------------
# Half precision on the first CUDA device when one is available, otherwise
# full precision on the CPU.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# --- Speech-to-text model ----------------------------------------------------
model_id = "distil-whisper/distil-medium.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=False,
    use_safetensors=True,
)
model.to(device)

# The processor supplies both the tokenizer and the feature extractor that the
# pipeline below is assembled from.
processor = AutoProcessor.from_pretrained(model_id)

# NOTE(review): max_new_tokens=64 is the pipeline default here, but the
# WebSocket endpoint passes max_new_tokens=128 per call -- confirm which limit
# is actually wanted.
transcriber = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=64,
    torch_dtype=torch_dtype,
    device=device,
)

# ASGI application object that the WebSocket route is registered on.
app = FastAPI()
class ConnectionManager:
    """Keep track of the WebSocket connections the server has accepted.

    The manager only stores the socket objects and forwards text to them; it
    does not own their lifecycle beyond calling ``accept()`` in ``connect``.
    """

    def __init__(self):
        # Sockets that have been accepted and not yet disconnected, in the
        # order they connected.
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept ``websocket`` and start tracking it.

        The socket is registered only after ``accept()`` succeeds, so a failed
        handshake never leaves a stale entry behind.
        """
        await websocket.accept()
        self.active_connections.append(websocket)

    async def disconnect(self, websocket: WebSocket):
        """Stop tracking ``websocket``.

        Safe to call more than once, or for a socket that was never
        registered: a missing entry is ignored instead of raising
        ``ValueError``, so cleanup paths can call this unconditionally.
        """
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            # Already removed (or never added) -- nothing left to clean up.
            pass

    async def send_text(self, message: str, websocket: WebSocket):
        """Send ``message`` as a text frame to a single ``websocket``."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send ``message`` to every tracked connection concurrently.

        An exception raised by any individual send propagates to the caller.
        """
        await asyncio.gather(
            *(connection.send_text(message) for connection in self.active_connections)
        )


# Single shared manager used by the WebSocket endpoint.
manager = ConnectionManager()
# All model inference runs on this single worker thread. That keeps the event
# loop free to service other sockets while a chunk is being transcribed, and
# still executes one inference at a time, as the previous in-loop calls did.
_INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)
_logger = logging.getLogger(__name__)


def _transcribe_if_speech(samples):
    """Return the transcript of ``samples`` when the VAD scores it as speech.

    ``samples`` is the float32 array decoded from one WebSocket frame. The VAD
    model is called with ``16000`` as its second argument (presumably the
    sample rate in Hz -- confirm against the client). Returns the pipeline's
    ``"text"`` field when the speech probability exceeds 0.2, otherwise
    ``None``.
    """
    speech_prob = vad_model(torch.tensor(samples), 16000).item()
    if speech_prob > 0.2:
        result = transcriber(samples, generate_kwargs={"max_new_tokens": 128})
        return result["text"]
    return None


@app.websocket("/ws/transcribe")
async def websocket_transcribe(websocket: WebSocket):
    """Stream transcriptions back for audio frames received on ``websocket``.

    Each binary frame is interpreted as raw float32 samples. Frames the VAD
    scores above the speech threshold are transcribed and the text is sent
    back on the same socket; all other frames are dropped silently.

    The connection is always removed from ``manager`` on exit -- on a normal
    client disconnect, on an unexpected error, and on task cancellation.
    """
    await manager.connect(websocket)
    loop = asyncio.get_running_loop()
    try:
        while True:
            audio_data = await websocket.receive_bytes()
            if not audio_data:
                # An empty frame carries no audio; skip it rather than hand
                # the models a zero-length array.
                continue
            np_data = np.frombuffer(audio_data, dtype=np.float32)
            # Inference is blocking, so it is awaited on the worker thread
            # instead of being run inside the event loop.
            text = await loop.run_in_executor(
                _INFERENCE_EXECUTOR, _transcribe_if_speech, np_data
            )
            if text is not None:
                await manager.send_text(text, websocket)
    except WebSocketDisconnect:
        # Client went away; cleanup happens in ``finally``.
        pass
    except Exception:
        _logger.exception("Unexpected error in /ws/transcribe; closing connection")
        # NOTE(review): 1001 means "going away"; 1011 (internal error) may be
        # the more accurate close code here -- left unchanged for clients.
        try:
            await websocket.close(code=1001)
        except Exception:
            # The socket may already be closed or broken; closing is best
            # effort and must not prevent the cleanup below.
            _logger.debug("Closing the WebSocket failed", exc_info=True)
    finally:
        await manager.disconnect(websocket)
# Script entry point: serve the app with uvicorn when this file is executed
# directly. uvicorn is imported here, so merely importing this module does not
# pull it in.
if __name__ == "__main__":
    import uvicorn

    # NOTE(review): 0.0.0.0 listens on every network interface -- confirm that
    # exposure is intended outside of local development.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment