{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "eef3ff4e", "metadata": {}, "outputs": [], "source": [ "# !pip3.10 install huggingface-hub==0.23\n", "# !pip3.10 install git+https://github.com/mesolitica/whisper-static-cache\n", "# !pip3.10 uninstall torch -y; pip3.10 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121" ] }, { "cell_type": "code", "execution_count": 2, "id": "bb0ff692", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" ] }, { "cell_type": "code", "execution_count": 3, "id": "f35b7d2a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/husein/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from transformers import WhisperForConditionalGeneration, AutoProcessor, pipeline\n", "from transformers.cache_utils import WhisperStaticCache\n", "import torch\n", "import requests\n", "from datasets import Audio\n", "from transformers import AutoProcessor\n", "from tqdm import tqdm\n", "\n", "sr = 16000\n", "audio = Audio(sampling_rate=sr)" ] }, { "cell_type": "code", "execution_count": 4, "id": "faaa46ca", "metadata": {}, "outputs": [], "source": [ "model_id = \"openai/whisper-large-v3\"\n", "compute_dtype = torch.bfloat16\n", "device = \"cuda:0\"" ] }, { "cell_type": "code", "execution_count": 5, "id": "70367dfe", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Instantiating WhisperSdpaAttention without passing a `layer_idx` is not recommended and will lead to errors during the forward call if caching is used. 
Please make sure to provide a `layer_idx` when creating this class.\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } ], "source": [ "model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype = compute_dtype) \n", "processor = AutoProcessor.from_pretrained(model_id)\n", "_ = model.cuda()" ] }, { "cell_type": "code", "execution_count": 6, "id": "0a077da5", "metadata": {}, "outputs": [], "source": [ "model_normal = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype = compute_dtype) \n", "_ = model_normal.cuda()" ] }, { "cell_type": "code", "execution_count": 7, "id": "deb50faa", "metadata": {}, "outputs": [], "source": [ "model.model.encoder.forward = torch.compile(model.model.encoder.forward, mode='reduce-overhead', fullgraph=True)" ] }, { "cell_type": "code", "execution_count": 8, "id": "bd05399d", "metadata": {}, "outputs": [], "source": [ "def decode_one_tokens(\n", " model, \n", " proj_out, \n", " cur_token, \n", " past_key_values, \n", " position_ids, \n", " cache_position, \n", " out_encoder,\n", "):\n", " # one greedy decoding step: run the decoder on the current token against the\n", " # static KV cache, then take the argmax logit as the next token\n", " out_decoder = model(\n", " cur_token, \n", " encoder_hidden_states=out_encoder,\n", " past_key_values = past_key_values,\n", " position_ids=position_ids,\n", " use_cache = True,\n", " return_dict = False,\n", " cache_position = cache_position\n", " )\n", " new_token = torch.argmax(proj_out(out_decoder[0][:,-1:]), dim=-1)\n", " return new_token" ] }, { "cell_type": "code", "execution_count": 9, "id": "3ea5b98d", "metadata": {}, "outputs": [], "source": [ "decode_one_tokens = torch.compile(decode_one_tokens, mode=\"reduce-overhead\", fullgraph=True)" ] }, { "cell_type": "code", "execution_count": 10, "id": "6fa3307b", "metadata": {}, "outputs": [], "source": [ "r = requests.get('https://huggingface.co/datasets/huseinzol05/malaya-speech-stt-test-set/resolve/main/test.mp3')\n", "y = audio.decode_example(audio.encode_example(r.content))['array']\n", "r = requests.get('https://github.com/mesolitica/malaya-speech/raw/master/speech/singlish/singlish0.wav')\n", "y2 = audio.decode_example(audio.encode_example(r.content))['array']" ] }, { "cell_type": "code", "execution_count": 11, "id": "9b20029a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n" ] } ], "source": [ "inputs = processor([y], return_tensors = 'pt').to('cuda')\n", "inputs['input_features'] = inputs['input_features'].type(torch.bfloat16)" ] }, { "cell_type": "code", "execution_count": 12, "id": "be3af0b9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n" ] } ], "source": [ "inputs2 = processor([y2], return_tensors = 'pt').to('cuda')\n", "inputs2['input_features'] = inputs2['input_features'].type(torch.bfloat16)" ] },
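{ "cell_type": "markdown", "id": "ad0c0de1", "metadata": {}, "source": [ "The `sampling_rate` warnings above are harmless here because both clips were already resampled to 16 kHz by `datasets.Audio`, but they can be silenced by passing the rate explicitly. A minimal sketch, assuming `processor`, `y`, `sr`, and `compute_dtype` from the cells above (`inputs_explicit` is a hypothetical name, not used later):" ] }, { "cell_type": "code", "execution_count": null, "id": "ad0c0de2", "metadata": {}, "outputs": [], "source": [ "# same features as `inputs`, but without the sampling_rate warning\n", "inputs_explicit = processor([y], sampling_rate=sr, return_tensors='pt').to('cuda')\n", "inputs_explicit['input_features'] = inputs_explicit['input_features'].type(compute_dtype)" ] }, { "cell_type": "code", "execution_count": 13, "id": "96e732a6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/husein/.local/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:393: FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. 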
Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.\n", " warnings.warn(\n", "/home/husein/.local/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:393: FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.\n", " warnings.warn(\n", "/home/husein/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py:2176: UserWarning: Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() before each model invocation\n", " warnings.warn(\n" ] } ], "source": [ "# warming up\n", "\n", "for _ in range(3):\n", " with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n", " out_encoder = model.model.encoder(inputs['input_features'])" ] }, { "cell_type": "code", "execution_count": 14, "id": "50cc3d9f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 38.2 ms, sys: 324 µs, total: 38.5 ms\n", "Wall time: 38 ms\n" ] } ], "source": [ "%%time\n", "\n", "with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n", " out_encoder = model.model.encoder(inputs['input_features'])" ] }, { "cell_type": "code", "execution_count": 15, "id": "89edf30b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 44.1 ms, sys: 4.01 ms, total: 48.1 ms\n", "Wall time: 47.4 ms\n" ] } ], "source": [ "%%time\n", "\n", "with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n", " out_encoder = model_normal.model.encoder(inputs['input_features'])" ] }, { "cell_type": "code", "execution_count": 16, "id": "20ff4a3f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/husein/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py:2176: UserWarning: Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. 
Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() before each model invocation\n", " warnings.warn(\n" ] } ], "source": [ "# warming up\n", "\n", "with torch.no_grad():\n", " language = 'en'\n", " initial_strings = [\n", " '<|startoftranscript|>',\n", " f'<|{language}|>',\n", " '<|transcribe|>'\n", " ]\n", "\n", " labels = processor.tokenizer(\n", " ''.join(initial_strings), \n", " add_special_tokens = False,\n", " return_tensors = 'pt',\n", " ).to('cuda')['input_ids']\n", " out_decoder = model.model.decoder(\n", " labels, \n", " encoder_hidden_states=out_encoder[0],\n", " past_key_values = None,\n", " use_cache = True\n", " )\n", " past_key_values = out_decoder.past_key_values\n", " proj = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)\n", " out_encoder = out_encoder[0].clone()\n", " \n", " cache = WhisperStaticCache(model.config, compute_dtype, device, past_key_values)\n", " seq_length = past_key_values[0][0].shape[2]\n", " cache_position = torch.tensor([seq_length], device=device)\n", " position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device = device)\n", " \n", " for i in range(model.config.max_target_positions - len(initial_strings)):\n", " proj = decode_one_tokens(\n", " model.model.decoder, \n", " model.proj_out, \n", " proj.clone(), \n", " cache, \n", " position_ids,\n", " cache_position, \n", " out_encoder\n", " )\n", " labels = torch.concat([labels, proj], axis = -1)\n", " position_ids += 1\n", " cache_position += 1\n", "\n", " if proj == model.config.eos_token_id:\n", " break" ] }, { "cell_type": "code", "execution_count": 33, "id": "f6724dfa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 50.5 ms, sys: 56 µs, total: 50.5 ms\n", "Wall time: 49.7 ms\n" ] } ], "source": [ "%%time\n", "\n", "with torch.no_grad():\n", " \n", " with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n", " out_encoder = model.model.encoder(inputs['input_features'])\n", " \n", " language = 'en'\n", " initial_strings = [\n", " '<|startoftranscript|>',\n", " f'<|{language}|>',\n", " '<|transcribe|>'\n", " ]\n", "\n", " labels = processor.tokenizer(\n", " ''.join(initial_strings), \n", " add_special_tokens = False,\n", " return_tensors = 'pt',\n", " ).to('cuda')['input_ids']\n", " out_decoder = model.model.decoder(\n", " labels, \n", " encoder_hidden_states=out_encoder[0],\n", " past_key_values = None,\n", " use_cache = True\n", " )\n", " past_key_values = out_decoder.past_key_values\n", " proj = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)\n", " out_encoder = out_encoder[0].clone()" ] }, { "cell_type": "code", "execution_count": 34, "id": "74ded2b3", "metadata": {}, "outputs": [], "source": [ "cache.reset(existing_cache = past_key_values)\n", "seq_length = past_key_values[0][0].shape[2]\n", "cache_position = torch.tensor([seq_length], device=device)\n", "position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device = device)" ] }, { "cell_type": "code", "execution_count": 35, "id": "4014b7fb", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 16%|████████████▉ | 73/445 [00:00<00:01, 186.26it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 396 ms, sys: 3.8 ms, total: 400 ms\n", "Wall time: 398 ms\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "%%time\n", "\n", "with torch.no_grad():\n", " for 
i in tqdm(range(model.config.max_target_positions - len(initial_strings))):\n", " proj = decode_one_tokens(\n", " model.model.decoder, \n", " model.proj_out, \n", " proj.clone(), \n", " cache, \n", " position_ids,\n", " cache_position, \n", " out_encoder\n", " )\n", " labels = torch.concat([labels, proj], axis = -1)\n", " position_ids += 1\n", " cache_position += 1\n", "\n", " if proj == model.config.eos_token_id:\n", " break" ] }, { "cell_type": "code", "execution_count": 36, "id": "62dc7314", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'<|startoftranscript|><|en|><|transcribe|> Assembly on Aging in Vienna, Australia, which was held in 1982. Based on the evidence, the Ministry of Industry of Malaysia estimated that by 2035, 15% of our population will be from the gold community. For your information, Mr. President and Mr. President, the development of the gold community registration system,<|endoftext|>'" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "processor.tokenizer.decode(labels[0])" ] },
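{ "cell_type": "markdown", "id": "ad0c0de3", "metadata": {}, "source": [ "For comparison, the cells below repeat the same encoder pass, prefill, and greedy loop with the uncompiled `model_normal`. Note that `%%time` only measures wall time from the CPU side while CUDA launches are asynchronous, so for anything beyond a rough comparison it is safer to synchronize around the timed region. A minimal sketch, assuming `model_normal` and `inputs` from above (`benchmark_encoder` is a hypothetical helper, not used in the runs below):" ] }, { "cell_type": "code", "execution_count": null, "id": "ad0c0de4", "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "def benchmark_encoder(encoder, features, n=10):\n", "    # synchronize so previously queued CUDA work does not leak into the measurement\n", "    torch.cuda.synchronize()\n", "    start = time.perf_counter()\n", "    with torch.no_grad():\n", "        for _ in range(n):\n", "            encoder(features)\n", "    torch.cuda.synchronize()\n", "    return (time.perf_counter() - start) / n\n", "\n", "# e.g. benchmark_encoder(model_normal.model.encoder, inputs['input_features'])" ] },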
{ "cell_type": "code", "execution_count": 41, "id": "5d7cc231", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 54.5 ms, sys: 97 µs, total: 54.6 ms\n", "Wall time: 54.3 ms\n" ] } ], "source": [ "%%time\n", "\n", "with torch.no_grad():\n", " \n", " with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n", " out_encoder = model_normal.model.encoder(inputs['input_features'])\n", " \n", " language = 'en'\n", " initial_strings = [\n", " '<|startoftranscript|>',\n", " f'<|{language}|>',\n", " '<|transcribe|>'\n", " ]\n", "\n", " labels = processor.tokenizer(\n", " ''.join(initial_strings), \n", " add_special_tokens = False,\n", " return_tensors = 'pt',\n", " ).to('cuda')['input_ids']\n", " out_decoder = model_normal.model.decoder(\n", " labels, \n", " encoder_hidden_states=out_encoder[0],\n", " past_key_values = None,\n", " use_cache = True\n", " )\n", " past_key_values = out_decoder.past_key_values\n", " proj = model_normal.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)\n", " out_encoder = out_encoder[0].clone()" ] }, { "cell_type": "code", "execution_count": 42, "id": "202bc451", "metadata": {}, "outputs": [], "source": [ "cache_normal = WhisperStaticCache(model_normal.config, compute_dtype, device, past_key_values)\n", "seq_length = past_key_values[0][0].shape[2]\n", "cache_position = torch.tensor([seq_length], device=device)\n", "position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device = device)" ] }, { "cell_type": "code", "execution_count": 43, "id": "b783a51f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 16%|████████████▉ | 73/445 [00:00<00:02, 150.20it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 491 ms, sys: 0 ns, total: 491 ms\n", "Wall time: 490 ms\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "%%time\n", "\n", "with torch.no_grad():\n", " for i in tqdm(range(model_normal.config.max_target_positions - len(initial_strings))):\n", " out_decoder = model_normal.model.decoder(\n", " proj, \n", " encoder_hidden_states=out_encoder,\n", " past_key_values = cache_normal,\n", " position_ids=position_ids,\n", " use_cache = True,\n", " return_dict = False,\n", " cache_position = cache_position\n", " )\n", " proj = torch.argmax(model_normal.proj_out(out_decoder[0][:,-1:]), dim=-1)\n", " labels = torch.concat([labels, proj], axis = -1)\n", " position_ids += 1\n", " cache_position += 1\n", "\n", " if proj == model_normal.config.eos_token_id:\n", " break" ] }, { "cell_type": "code", "execution_count": 44, "id": "19456d76", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'<|startoftranscript|><|en|><|transcribe|> Assembly on Aging in Vienna, Australia, which was held in 1982. Based on the evidence, the Ministry of Industry of Malaysia estimated that by 2035, 15% of our population will be from the gold community. For your information, Mr. President and Mr. President, the development of the gold community registration system,<|endoftext|>'" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "processor.tokenizer.decode(labels[0])" ] },
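{ "cell_type": "markdown", "id": "ad0c0de5", "metadata": {}, "source": [ "Both paths produce the identical transcription, but the compiled static-cache step runs at roughly 186 it/s against roughly 150 it/s for the eager decoder (tqdm rates above). A minimal sketch that ties the pieces together into one helper, assuming `model`, `processor`, `cache`, `decode_one_tokens`, `sr`, `compute_dtype`, and `device` defined above (`transcribe` is a hypothetical name, not part of the runs above):" ] }, { "cell_type": "code", "execution_count": null, "id": "ad0c0de6", "metadata": {}, "outputs": [], "source": [ "def transcribe(y, language='en'):\n", "    with torch.no_grad():\n", "        features = processor([y], sampling_rate=sr, return_tensors='pt').to(device)['input_features'].type(compute_dtype)\n", "        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n", "            out_encoder = model.model.encoder(features)\n", "        initial_strings = ['<|startoftranscript|>', f'<|{language}|>', '<|transcribe|>']\n", "        labels = processor.tokenizer(''.join(initial_strings), add_special_tokens=False, return_tensors='pt').to(device)['input_ids']\n", "        # prefill with the dynamic cache, then copy it into the static cache\n", "        out_decoder = model.model.decoder(labels, encoder_hidden_states=out_encoder[0], past_key_values=None, use_cache=True)\n", "        past_key_values = out_decoder.past_key_values\n", "        proj = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)\n", "        out_encoder = out_encoder[0].clone()\n", "        cache.reset(existing_cache=past_key_values)\n", "        seq_length = past_key_values[0][0].shape[2]\n", "        cache_position = torch.tensor([seq_length], device=device)\n", "        position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device=device)\n", "        # greedy loop over the compiled single-token step\n", "        for _ in range(model.config.max_target_positions - len(initial_strings)):\n", "            proj = decode_one_tokens(model.model.decoder, model.proj_out, proj.clone(), cache, position_ids, cache_position, out_encoder)\n", "            labels = torch.concat([labels, proj], axis=-1)\n", "            position_ids += 1\n", "            cache_position += 1\n", "            if proj == model.config.eos_token_id:\n", "                break\n", "    return processor.tokenizer.decode(labels[0], skip_special_tokens=True)\n", "\n", "# e.g. transcribe(y2) for the Singlish sample loaded earlier" ] } ], "metadata": { "kernelspec": { "display_name": "python3.10", "language": "python", "name": "python3.10" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }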