Created
November 5, 2024 13:55
-
-
Save lunks/fcceaeb13e57dc4e1711003c1907b809 to your computer and use it in GitHub Desktop.
Switching between models with Nx/Bumblebee
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
defmodule App.ModelManager do
  @moduledoc """
  Manages which Bumblebee model serving is currently running, ensuring at
  most one model is loaded at a time (useful when GPU/RAM cannot hold both).

  Supported models: `:whisper` (speech-to-text, registered as
  `:whisper_serving`) and `:mistral` (text generation, registered as
  `:mistral_serving`). Swapping stops the previous serving before starting
  the new one.

  NOTE(review): `Nx.Serving.start_link/1` links the serving to this
  GenServer, so a serving crash will take the manager down with it —
  presumably intentional (supervisor restarts with a clean state), but
  worth confirming.
  """

  use GenServer

  require Logger

  ## Client API

  @doc "Starts the manager, registered under the module name."
  def start_link(_) do
    GenServer.start_link(__MODULE__, %{}, name: __MODULE__)
  end

  @doc """
  Loads `model_name` (`:whisper` or `:mistral`), stopping any currently
  running model first. A no-op if that model is already loaded.

  Blocks the caller until the serving is up — `:infinity` timeout, since
  downloading/compiling a model can take minutes.
  """
  def load_model(model_name) do
    GenServer.call(__MODULE__, {:load_model, model_name}, :infinity)
  end

  ## Server callbacks

  @impl true
  def init(_) do
    Logger.info("Starting ModelManager...")
    # No model is loaded at startup; the first `load_model/1` call loads one.
    {:ok, %{current_model: nil, model_name: nil}}
  end

  # Nothing loaded yet: start the requested model.
  @impl true
  def handle_call({:load_model, model_name}, _from, %{current_model: nil} = state) do
    {:ok, pid} = start_model(model_name)
    {:reply, :ok, %{state | current_model: pid, model_name: model_name}}
  end

  # Requested model is already the loaded one (head pattern match on
  # `model_name` in both the message and the state): do nothing.
  def handle_call({:load_model, model_name}, _from, %{model_name: model_name} = state) do
    {:reply, :ok, state}
  end

  # A different model is loaded: stop it, then start the requested one.
  def handle_call({:load_model, model_name}, _from, %{current_model: current_pid} = state) do
    stop_model(current_pid)
    {:ok, pid} = start_model(model_name)
    {:reply, :ok, %{state | current_model: pid, model_name: model_name}}
  end

  ## Model startup

  # Starts the Whisper speech-to-text serving. Returns `{:ok, pid}`.
  defp start_model(:whisper) do
    repo = {:hf, "openai/whisper-medium"}

    {:ok, model_info} = Bumblebee.load_model(repo)
    {:ok, featurizer} = Bumblebee.load_featurizer(repo)
    {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
    {:ok, generation_config} = Bumblebee.load_generation_config(repo)
    generation_config = Bumblebee.configure(generation_config, max_new_tokens: 100)

    serving =
      Bumblebee.Audio.speech_to_text_whisper(
        model_info,
        featurizer,
        tokenizer,
        generation_config,
        compile: [batch_size: 4],
        chunk_num_seconds: 30,
        timestamps: :segments,
        stream: true,
        defn_options: [compiler: EXLA]
      )

    Nx.Serving.start_link(serving: serving, name: :whisper_serving)
  end

  # Starts the Mistral text-generation serving. Requires a Hugging Face
  # auth token in the `:app, :hf_auth_token` application env (gated repo).
  # Returns `{:ok, pid}`.
  defp start_model(:mistral) do
    repo =
      {:hf, "mistralai/Mistral-7B-Instruct-v0.1",
       auth_token: Application.get_env(:app, :hf_auth_token)}

    {:ok, model_info} = Bumblebee.load_model(repo)
    {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
    {:ok, generation_config} = Bumblebee.load_generation_config(repo)

    serving =
      Bumblebee.Text.generation(
        model_info,
        tokenizer,
        generation_config,
        compile: [batch_size: 1, sequence_length: 6000],
        defn_options: [compiler: EXLA]
      )

    Nx.Serving.start_link(serving: serving, name: :mistral_serving)
  end

  # Stops a running serving process. Crashes the manager if `pid` is
  # already dead ("let it crash" — the supervisor restarts us clean).
  #
  # NOTE(review): `:erlang.garbage_collect/0` collects THIS process's heap,
  # not the stopped serving's memory; the model's memory is reclaimed when
  # its process exits. This call is likely ineffective — confirm intent.
  defp stop_model(pid) do
    GenServer.stop(pid)
    :erlang.garbage_collect()
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment