Skip to content

Instantly share code, notes, and snippets.

@lunks
Created November 5, 2024 13:55
Show Gist options
  • Save lunks/fcceaeb13e57dc4e1711003c1907b809 to your computer and use it in GitHub Desktop.
Save lunks/fcceaeb13e57dc4e1711003c1907b809 to your computer and use it in GitHub Desktop.
Switching between models with Nx/Bumblebee
defmodule App.ModelManager do
  @moduledoc """
  Manages which Bumblebee model serving is loaded at any given time.

  Only one model is kept in memory at once: loading a new model stops the
  currently running `Nx.Serving` process before starting the next one. The
  started servings are registered under well-known names
  (`:whisper_serving` / `:mistral_serving`) for callers to use.
  """

  require Logger
  use GenServer

  ## Client API

  def start_link(_opts) do
    GenServer.start_link(__MODULE__, %{}, name: __MODULE__)
  end

  @doc """
  Loads `model_name` (`:whisper` or `:mistral`), stopping any other model
  that is currently running.

  Returns `:ok` on success (including when the model is already loaded),
  or `{:error, reason}` for an unknown model name. Uses an `:infinity`
  timeout because downloading/compiling a model can take minutes.
  """
  @spec load_model(atom()) :: :ok | {:error, term()}
  def load_model(model_name) do
    GenServer.call(__MODULE__, {:load_model, model_name}, :infinity)
  end

  ## Server callbacks

  @impl true
  def init(_init_arg) do
    Logger.info("Starting ModelManager...")
    # monitor_ref tracks the running serving so we can detect crashes
    # and avoid reporting a dead model as "loaded".
    {:ok, %{current_model: nil, model_name: nil, monitor_ref: nil}}
  end

  @impl true
  def handle_call({:load_model, model_name}, _from, %{model_name: model_name} = state)
      when not is_nil(model_name) do
    # The requested model is already loaded (and, thanks to the :DOWN
    # handler below, known to be alive) — nothing to do.
    {:reply, :ok, state}
  end

  def handle_call({:load_model, model_name}, _from, state) do
    # Stop the currently running model, if any, before starting the new one.
    if state.current_model, do: stop_model(state.current_model, state.monitor_ref)

    case start_model(model_name) do
      {:ok, pid} ->
        ref = Process.monitor(pid)
        {:reply, :ok, %{state | current_model: pid, model_name: model_name, monitor_ref: ref}}

      {:error, _reason} = error ->
        # Previously an unknown model name raised FunctionClauseError and
        # crashed this server; now the caller gets a tagged error instead.
        {:reply, error, %{state | current_model: nil, model_name: nil, monitor_ref: nil}}
    end
  end

  @impl true
  def handle_info({:DOWN, ref, :process, _pid, reason}, %{monitor_ref: ref} = state) do
    # The serving died on its own — clear state so the next load_model/1
    # call for the same name actually restarts it.
    Logger.warning("Model serving for #{inspect(state.model_name)} exited: #{inspect(reason)}")
    {:noreply, %{state | current_model: nil, model_name: nil, monitor_ref: nil}}
  end

  def handle_info(_msg, state), do: {:noreply, state}

  ## Model definitions

  # Whisper medium speech-to-text serving, registered as :whisper_serving.
  defp start_model(:whisper) do
    repo = {:hf, "openai/whisper-medium"}
    {:ok, model_info} = Bumblebee.load_model(repo)
    {:ok, featurizer} = Bumblebee.load_featurizer(repo)
    {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
    {:ok, generation_config} = Bumblebee.load_generation_config(repo)
    generation_config = Bumblebee.configure(generation_config, max_new_tokens: 100)

    serving =
      Bumblebee.Audio.speech_to_text_whisper(
        model_info,
        featurizer,
        tokenizer,
        generation_config,
        compile: [batch_size: 4],
        chunk_num_seconds: 30,
        timestamps: :segments,
        stream: true,
        defn_options: [compiler: EXLA]
      )

    Nx.Serving.start_link(serving: serving, name: :whisper_serving)
  end

  # Mistral 7B instruct text-generation serving, registered as :mistral_serving.
  # Requires a HuggingFace auth token in the :app application env.
  defp start_model(:mistral) do
    repo =
      {:hf, "mistralai/Mistral-7B-Instruct-v0.1",
       auth_token: Application.get_env(:app, :hf_auth_token)}

    {:ok, model_info} = Bumblebee.load_model(repo)
    {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
    {:ok, generation_config} = Bumblebee.load_generation_config(repo)

    serving =
      Bumblebee.Text.generation(
        model_info,
        tokenizer,
        generation_config,
        compile: [batch_size: 1, sequence_length: 6000],
        defn_options: [compiler: EXLA]
      )

    Nx.Serving.start_link(serving: serving, name: :mistral_serving)
  end

  defp start_model(other), do: {:error, {:unknown_model, other}}

  # Stops the running serving, flushing its monitor so no stale :DOWN
  # message arrives for a deliberate shutdown.
  defp stop_model(pid, monitor_ref) do
    if monitor_ref, do: Process.demonitor(monitor_ref, [:flush])
    GenServer.stop(pid)
    # NOTE: this only collects the manager process itself; the model's
    # memory is released when the serving process terminates above.
    :erlang.garbage_collect()
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment