-
-
Save NickyDark1/c656638ada0cb1f46c179c23790c8da3 to your computer and use it in GitHub Desktop.
example of whisper static cache
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "eef3ff4e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# !pip3.10 install huggingface-hub==0.23\n", | |
"# !pip3.10 install git+https://github.com/mesolitica/whisper-static-cache\n", | |
"# !pip3.10 uninstall torch -y; pip3.10 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "bb0ff692", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"\n", | |
"os.environ['CUDA_VISIBLE_DEVICES'] = '0'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "f35b7d2a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/husein/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | |
" from .autonotebook import tqdm as notebook_tqdm\n" | |
] | |
} | |
], | |
"source": [ | |
"from transformers import WhisperForConditionalGeneration, AutoProcessor, pipeline\n", | |
"from transformers.cache_utils import WhisperStaticCache\n", | |
"import torch\n", | |
"import requests\n", | |
"from datasets import Audio\n", | |
"from transformers import AutoProcessor\n", | |
"from tqdm import tqdm\n", | |
"\n", | |
"sr = 16000\n", | |
"audio = Audio(sampling_rate=sr)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "faaa46ca", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model_id = \"openai/whisper-large-v3\"\n", | |
"compute_dtype = torch.bfloat16\n", | |
"device = \"cuda:0\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "70367dfe", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Instantiating WhisperSdpaAttention without passing a `layer_idx` is not recommended and will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` when creating this class.\n", | |
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" | |
] | |
} | |
], | |
"source": [ | |
"model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype = compute_dtype) \n", | |
"processor = AutoProcessor.from_pretrained(model_id)\n", | |
"_ = model.cuda()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "0a077da5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model_normal = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype = compute_dtype) \n", | |
"_ = model_normal.cuda()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "deb50faa", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model.model.encoder.forward = torch.compile(model.model.encoder.forward, mode='reduce-overhead', fullgraph=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "bd05399d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def decode_one_tokens(\n", | |
" model, \n", | |
" proj_out, \n", | |
" cur_token, \n", | |
" past_key_values, \n", | |
" position_ids, \n", | |
" cache_position, \n", | |
" out_encoder,\n", | |
"):\n", | |
" \n", | |
" out_decoder = model(\n", | |
" cur_token, \n", | |
" encoder_hidden_states=out_encoder,\n", | |
" past_key_values = past_key_values,\n", | |
" position_ids=position_ids,\n", | |
" use_cache = True,\n", | |
" return_dict = False,\n", | |
" cache_position = cache_position\n", | |
" )\n", | |
" new_token = torch.argmax(proj_out(out_decoder[0][:,-1:]), dim=-1)\n", | |
" return new_token" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "3ea5b98d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"decode_one_tokens = torch.compile(decode_one_tokens, mode=\"reduce-overhead\", fullgraph=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "6fa3307b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"r = requests.get('https://huggingface.co/datasets/huseinzol05/malaya-speech-stt-test-set/resolve/main/test.mp3')\n", | |
"y = audio.decode_example(audio.encode_example(r.content))['array']\n", | |
"r = requests.get('https://github.com/mesolitica/malaya-speech/raw/master/speech/singlish/singlish0.wav')\n", | |
"y2 = audio.decode_example(audio.encode_example(r.content))['array']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "9b20029a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n" | |
] | |
} | |
], | |
"source": [ | |
"inputs = processor([y], return_tensors = 'pt').to('cuda')\n", | |
"inputs['input_features'] = inputs['input_features'].type(torch.bfloat16)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "be3af0b9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n" | |
] | |
} | |
], | |
"source": [ | |
"inputs2 = processor([y2], return_tensors = 'pt').to('cuda')\n", | |
"inputs2['input_features'] = inputs2['input_features'].type(torch.bfloat16)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "96e732a6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/husein/.local/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:393: FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.\n", | |
" warnings.warn(\n", | |
"/home/husein/.local/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:393: FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.\n", | |
" warnings.warn(\n", | |
"/home/husein/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py:2176: UserWarning: Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() before each model invocation\n", | |
" warnings.warn(\n" | |
] | |
} | |
], | |
"source": [ | |
"# warming up\n", | |
"\n", | |
"for _ in range(3):\n", | |
" with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n", | |
" out_encoder = model.model.encoder(inputs['input_features'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "50cc3d9f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 38.2 ms, sys: 324 µs, total: 38.5 ms\n", | |
"Wall time: 38 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n", | |
" out_encoder = model.model.encoder(inputs['input_features'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "89edf30b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 44.1 ms, sys: 4.01 ms, total: 48.1 ms\n", | |
"Wall time: 47.4 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n", | |
" out_encoder = model_normal.model.encoder(inputs['input_features'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "20ff4a3f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/husein/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py:2176: UserWarning: Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() before each model invocation\n", | |
" warnings.warn(\n" | |
] | |
} | |
], | |
"source": [ | |
"# warming up\n", | |
"\n", | |
"with torch.no_grad():\n", | |
" language = 'en'\n", | |
" initial_strings = [\n", | |
" '<|startoftranscript|>',\n", | |
" f'<|{language}|>',\n", | |
" '<|transcribe|>'\n", | |
" ]\n", | |
"\n", | |
" labels = processor.tokenizer(\n", | |
" ''.join(initial_strings), \n", | |
" add_special_tokens = False,\n", | |
" return_tensors = 'pt',\n", | |
" ).to('cuda')['input_ids']\n", | |
" out_decoder = model.model.decoder(\n", | |
" labels, \n", | |
" encoder_hidden_states=out_encoder[0],\n", | |
" past_key_values = None,\n", | |
" use_cache = True\n", | |
" )\n", | |
" past_key_values = out_decoder.past_key_values\n", | |
" proj = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)\n", | |
" out_encoder = out_encoder[0].clone()\n", | |
" \n", | |
" cache = WhisperStaticCache(model.config, compute_dtype, device, past_key_values)\n", | |
" seq_length = past_key_values[0][0].shape[2]\n", | |
" cache_position = torch.tensor([seq_length], device=device)\n", | |
" position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device = device)\n", | |
" \n", | |
" for i in range(model.config.max_target_positions - len(initial_strings)):\n", | |
" proj = decode_one_tokens(\n", | |
" model.model.decoder, \n", | |
" model.proj_out, \n", | |
" proj.clone(), \n", | |
" cache, \n", | |
" position_ids,\n", | |
" cache_position, \n", | |
" out_encoder\n", | |
" )\n", | |
" labels = torch.concat([labels, proj], axis = -1)\n", | |
" position_ids += 1\n", | |
" cache_position += 1\n", | |
"\n", | |
" if proj == model.config.eos_token_id:\n", | |
" break" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"id": "f6724dfa", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 50.5 ms, sys: 56 µs, total: 50.5 ms\n", | |
"Wall time: 49.7 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"with torch.no_grad():\n", | |
" \n", | |
" with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n", | |
" out_encoder = model.model.encoder(inputs['input_features'])\n", | |
" \n", | |
" language = 'en'\n", | |
" initial_strings = [\n", | |
" '<|startoftranscript|>',\n", | |
" f'<|{language}|>',\n", | |
" '<|transcribe|>'\n", | |
" ]\n", | |
"\n", | |
" labels = processor.tokenizer(\n", | |
" ''.join(initial_strings), \n", | |
" add_special_tokens = False,\n", | |
" return_tensors = 'pt',\n", | |
" ).to('cuda')['input_ids']\n", | |
" out_decoder = model.model.decoder(\n", | |
" labels, \n", | |
" encoder_hidden_states=out_encoder[0],\n", | |
" past_key_values = None,\n", | |
" use_cache = True\n", | |
" )\n", | |
" past_key_values = out_decoder.past_key_values\n", | |
" proj = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)\n", | |
" out_encoder = out_encoder[0].clone()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"id": "74ded2b3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"cache.reset(existing_cache = past_key_values)\n", | |
"seq_length = past_key_values[0][0].shape[2]\n", | |
"cache_position = torch.tensor([seq_length], device=device)\n", | |
"position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device = device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"id": "4014b7fb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 16%|████████████▉ | 73/445 [00:00<00:01, 186.26it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 396 ms, sys: 3.8 ms, total: 400 ms\n", | |
"Wall time: 398 ms\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"with torch.no_grad():\n", | |
" for i in tqdm(range(model.config.max_target_positions - len(initial_strings))):\n", | |
" proj = decode_one_tokens(\n", | |
" model.model.decoder, \n", | |
" model.proj_out, \n", | |
" proj.clone(), \n", | |
" cache, \n", | |
" position_ids,\n", | |
" cache_position, \n", | |
" out_encoder\n", | |
" )\n", | |
" labels = torch.concat([labels, proj], axis = -1)\n", | |
" position_ids += 1\n", | |
" cache_position += 1\n", | |
"\n", | |
" if proj == model.config.eos_token_id:\n", | |
" break" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"id": "62dc7314", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'<|startoftranscript|><|en|><|transcribe|> Assembly on Aging in Vienna, Australia, which was held in 1982. Based on the evidence, the Ministry of Industry of Malaysia estimated that by 2035, 15% of our population will be from the gold community. For your information, Mr. President and Mr. President, the development of the gold community registration system,<|endoftext|>'" | |
] | |
}, | |
"execution_count": 36, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"processor.tokenizer.decode(labels[0])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"id": "5d7cc231", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 54.5 ms, sys: 97 µs, total: 54.6 ms\n", | |
"Wall time: 54.3 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"with torch.no_grad():\n", | |
" \n", | |
" with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n", | |
" out_encoder = model_normal.model.encoder(inputs['input_features'])\n", | |
" \n", | |
" language = 'en'\n", | |
" initial_strings = [\n", | |
" '<|startoftranscript|>',\n", | |
" f'<|{language}|>',\n", | |
" '<|transcribe|>'\n", | |
" ]\n", | |
"\n", | |
" labels = processor.tokenizer(\n", | |
" ''.join(initial_strings), \n", | |
" add_special_tokens = False,\n", | |
" return_tensors = 'pt',\n", | |
" ).to('cuda')['input_ids']\n", | |
" out_decoder = model.model.decoder(\n", | |
" labels, \n", | |
" encoder_hidden_states=out_encoder[0],\n", | |
" past_key_values = None,\n", | |
" use_cache = True\n", | |
" )\n", | |
" past_key_values = out_decoder.past_key_values\n", | |
" proj = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)\n", | |
" out_encoder = out_encoder[0].clone()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"id": "202bc451", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"cache_normal = WhisperStaticCache(model.config, compute_dtype, device, past_key_values)\n", | |
"seq_length = past_key_values[0][0].shape[2]\n", | |
"cache_position = torch.tensor([seq_length], device=device)\n", | |
"position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device = device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 43, | |
"id": "b783a51f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 16%|████████████▉ | 73/445 [00:00<00:02, 150.20it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 491 ms, sys: 0 ns, total: 491 ms\n", | |
"Wall time: 490 ms\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"with torch.no_grad():\n", | |
" for i in tqdm(range(model_normal.config.max_target_positions - len(initial_strings))):\n", | |
" out_decoder = model_normal.model.decoder(\n", | |
" proj, \n", | |
" encoder_hidden_states=out_encoder,\n", | |
" past_key_values = cache_normal,\n", | |
" position_ids=position_ids,\n", | |
" use_cache = True,\n", | |
" return_dict = False,\n", | |
" cache_position = cache_position\n", | |
" )\n", | |
" proj = torch.argmax(model_normal.proj_out(out_decoder[0][:,-1:]), dim=-1)\n", | |
" labels = torch.concat([labels, proj], axis = -1)\n", | |
" position_ids += 1\n", | |
" cache_position += 1\n", | |
"\n", | |
" if proj == model.config.eos_token_id:\n", | |
" break" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 44, | |
"id": "19456d76", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'<|startoftranscript|><|en|><|transcribe|> Assembly on Aging in Vienna, Australia, which was held in 1982. Based on the evidence, the Ministry of Industry of Malaysia estimated that by 2035, 15% of our population will be from the gold community. For your information, Mr. President and Mr. President, the development of the gold community registration system,<|endoftext|>'" | |
] | |
}, | |
"execution_count": 44, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"processor.tokenizer.decode(labels[0])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "0742e47d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "python3.10", | |
"language": "python", | |
"name": "python3.10" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.14" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment