{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "eef3ff4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip3.10 install huggingface-hub==0.23\n",
    "# !pip3.10 install git+https://github.com/mesolitica/whisper-static-cache\n",
    "# !pip3.10 uninstall torch -y; pip3.10 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121"
   ]
  },
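  {
   "cell_type": "markdown",
   "id": "c0ffee01",
   "metadata": {},
   "source": [
    "Benchmark greedy decoding for `openai/whisper-large-v3` with a static KV cache (`WhisperStaticCache` from the patched `transformers.cache_utils`) plus `torch.compile` CUDA-graph capture, against the same model decoding eagerly with the same cache."
   ]
  },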
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bb0ff692",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f35b7d2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from transformers import WhisperForConditionalGeneration, AutoProcessor, pipeline\n",
    "from transformers.cache_utils import WhisperStaticCache\n",
    "import torch\n",
    "import requests\n",
    "from datasets import Audio\n",
    "from transformers import AutoProcessor\n",
    "from tqdm import tqdm\n",
    "\n",
    "sr = 16000\n",
    "audio = Audio(sampling_rate=sr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "faaa46ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_id = \"openai/whisper-large-v3\"\n",
    "compute_dtype = torch.bfloat16\n",
    "device = \"cuda:0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "70367dfe",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Instantiating WhisperSdpaAttention without passing a `layer_idx` is not recommended and will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` when creating this class.\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype = compute_dtype) \n",
    "processor = AutoProcessor.from_pretrained(model_id)\n",
    "_ = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0a077da5",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_normal = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype = compute_dtype) \n",
    "_ = model_normal.cuda()"
   ]
  },
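  {
   "cell_type": "markdown",
   "id": "c0ffee02",
   "metadata": {},
   "source": [
    "Compile only the encoder forward. Whisper always sees a fixed-size 30-second mel spectrogram, so input shapes never change: `mode='reduce-overhead'` can capture the call as a CUDA graph, and `fullgraph=True` errors out early if anything would cause a graph break."
   ]
  },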
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "deb50faa",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.model.encoder.forward = torch.compile(model.model.encoder.forward, mode='reduce-overhead', fullgraph=True)"
   ]
  },
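  {
   "cell_type": "markdown",
   "id": "c0ffee03",
   "metadata": {},
   "source": [
    "The decoder is compiled as a single-token step function instead: each call feeds exactly one new token plus a preallocated static cache and explicit `position_ids`/`cache_position`, so every step has identical tensor shapes and the whole step can be replayed as a CUDA graph."
   ]
  },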
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bd05399d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_one_tokens(\n",
    "    model, \n",
    "    proj_out, \n",
    "    cur_token, \n",
    "    past_key_values, \n",
    "    position_ids, \n",
    "    cache_position, \n",
    "    out_encoder,\n",
    "):\n",
    "    \n",
    "    out_decoder = model(\n",
    "        cur_token, \n",
    "        encoder_hidden_states=out_encoder,\n",
    "        past_key_values = past_key_values,\n",
    "        position_ids=position_ids,\n",
    "        use_cache = True,\n",
    "        return_dict = False,\n",
    "        cache_position = cache_position\n",
    "    )\n",
    "    new_token = torch.argmax(proj_out(out_decoder[0][:,-1:]), dim=-1)\n",
    "    return new_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3ea5b98d",
   "metadata": {},
   "outputs": [],
   "source": [
    "decode_one_tokens = torch.compile(decode_one_tokens, mode=\"reduce-overhead\", fullgraph=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6fa3307b",
   "metadata": {},
   "outputs": [],
   "source": [
    "r = requests.get('https://huggingface.co/datasets/huseinzol05/malaya-speech-stt-test-set/resolve/main/test.mp3')\n",
    "y = audio.decode_example(audio.encode_example(r.content))['array']\n",
    "r = requests.get('https://github.com/mesolitica/malaya-speech/raw/master/speech/singlish/singlish0.wav')\n",
    "y2 = audio.decode_example(audio.encode_example(r.content))['array']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9b20029a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n"
     ]
    }
   ],
   "source": [
    "inputs = processor([y], return_tensors = 'pt').to('cuda')\n",
    "inputs['input_features'] = inputs['input_features'].type(torch.bfloat16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "be3af0b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n"
     ]
    }
   ],
   "source": [
    "inputs2 = processor([y2], return_tensors = 'pt').to('cuda')\n",
    "inputs2['input_features'] = inputs2['input_features'].type(torch.bfloat16)"
   ]
  },
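  {
   "cell_type": "markdown",
   "id": "c0ffee04",
   "metadata": {},
   "source": [
    "The first few compiled calls are slow while Inductor compiles kernels and records the CUDA graph, so run the encoder a few times before timing. The `torch.backends.cuda.sdp_kernel` context manager is deprecated in recent PyTorch in favour of `torch.nn.attention.sdpa_kernel`, hence the warnings below."
   ]
  },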
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "96e732a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:393: FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.\n",
      "  warnings.warn(\n",
      "/home/husein/.local/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:393: FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.\n",
      "  warnings.warn(\n",
      "/home/husein/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py:2176: UserWarning: Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() before each model invocation\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# warming up\n",
    "\n",
    "for _ in range(3):\n",
    "    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n",
    "        out_encoder = model.model.encoder(inputs['input_features'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "50cc3d9f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 38.2 ms, sys: 324 µs, total: 38.5 ms\n",
      "Wall time: 38 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n",
    "    out_encoder = model.model.encoder(inputs['input_features'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "89edf30b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 44.1 ms, sys: 4.01 ms, total: 48.1 ms\n",
      "Wall time: 47.4 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n",
    "    out_encoder = model_normal.model.encoder(inputs['input_features'])"
   ]
  },
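  {
   "cell_type": "markdown",
   "id": "c0ffee05",
   "metadata": {},
   "source": [
    "Warm up the decoder path: prefill the forced prompt `<|startoftranscript|><|en|><|transcribe|>` eagerly, seed a `WhisperStaticCache` from the resulting KV, then run the compiled one-token loop once end-to-end to trigger compilation and graph capture."
   ]
  },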
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "20ff4a3f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py:2176: UserWarning: Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() before each model invocation\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# warming up\n",
    "\n",
    "with torch.no_grad():\n",
    "    language = 'en'\n",
    "    initial_strings = [\n",
    "        '<|startoftranscript|>',\n",
    "        f'<|{language}|>',\n",
    "        '<|transcribe|>'\n",
    "    ]\n",
    "\n",
    "    labels = processor.tokenizer(\n",
    "        ''.join(initial_strings), \n",
    "        add_special_tokens = False,\n",
    "        return_tensors = 'pt',\n",
    "    ).to('cuda')['input_ids']\n",
    "    out_decoder = model.model.decoder(\n",
    "        labels, \n",
    "        encoder_hidden_states=out_encoder[0],\n",
    "        past_key_values = None,\n",
    "        use_cache = True\n",
    "    )\n",
    "    past_key_values = out_decoder.past_key_values\n",
    "    proj = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)\n",
    "    out_encoder = out_encoder[0].clone()\n",
    "    \n",
    "    cache = WhisperStaticCache(model.config, compute_dtype, device, past_key_values)\n",
    "    seq_length = past_key_values[0][0].shape[2]\n",
    "    cache_position = torch.tensor([seq_length], device=device)\n",
    "    position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device = device)\n",
    "    \n",
    "    for i in range(model.config.max_target_positions - len(initial_strings)):\n",
    "        proj = decode_one_tokens(\n",
    "            model.model.decoder, \n",
    "            model.proj_out, \n",
    "            proj.clone(), \n",
    "            cache, \n",
    "            position_ids,\n",
    "            cache_position, \n",
    "            out_encoder\n",
    "        )\n",
    "        labels = torch.concat([labels, proj], axis = -1)\n",
    "        position_ids += 1\n",
    "        cache_position += 1\n",
    "\n",
    "        if proj == model.config.eos_token_id:\n",
    "            break"
   ]
  },
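  {
   "cell_type": "markdown",
   "id": "c0ffee06",
   "metadata": {},
   "source": [
    "Timed run with the compiled model: encoder plus prompt prefill here, then the greedy decode loop timed separately below."
   ]
  },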
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "f6724dfa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 50.5 ms, sys: 56 µs, total: 50.5 ms\n",
      "Wall time: 49.7 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "with torch.no_grad():\n",
    "    \n",
    "    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n",
    "        out_encoder = model.model.encoder(inputs['input_features'])\n",
    "        \n",
    "    language = 'en'\n",
    "    initial_strings = [\n",
    "        '<|startoftranscript|>',\n",
    "        f'<|{language}|>',\n",
    "        '<|transcribe|>'\n",
    "    ]\n",
    "\n",
    "    labels = processor.tokenizer(\n",
    "        ''.join(initial_strings), \n",
    "        add_special_tokens = False,\n",
    "        return_tensors = 'pt',\n",
    "    ).to('cuda')['input_ids']\n",
    "    out_decoder = model.model.decoder(\n",
    "        labels, \n",
    "        encoder_hidden_states=out_encoder[0],\n",
    "        past_key_values = None,\n",
    "        use_cache = True\n",
    "    )\n",
    "    past_key_values = out_decoder.past_key_values\n",
    "    proj = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)\n",
    "    out_encoder = out_encoder[0].clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "74ded2b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "cache.reset(existing_cache = past_key_values)\n",
    "seq_length = past_key_values[0][0].shape[2]\n",
    "cache_position = torch.tensor([seq_length], device=device)\n",
    "position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device = device)"
   ]
  },
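  {
   "cell_type": "markdown",
   "id": "c0ffee07",
   "metadata": {},
   "source": [
    "Greedy decode loop with the compiled step function: about 73 tokens before `<|endoftext|>`, at roughly 186 steps/s here."
   ]
  },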
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "4014b7fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|████████████▉                                                                  | 73/445 [00:00<00:01, 186.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 396 ms, sys: 3.8 ms, total: 400 ms\n",
      "Wall time: 398 ms\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i in tqdm(range(model.config.max_target_positions - len(initial_strings))):\n",
    "        proj = decode_one_tokens(\n",
    "            model.model.decoder, \n",
    "            model.proj_out, \n",
    "            proj.clone(), \n",
    "            cache, \n",
    "            position_ids,\n",
    "            cache_position, \n",
    "            out_encoder\n",
    "        )\n",
    "        labels = torch.concat([labels, proj], axis = -1)\n",
    "        position_ids += 1\n",
    "        cache_position += 1\n",
    "\n",
    "        if proj == model.config.eos_token_id:\n",
    "            break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "62dc7314",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<|startoftranscript|><|en|><|transcribe|> Assembly on Aging in Vienna, Australia, which was held in 1982. Based on the evidence, the Ministry of Industry of Malaysia estimated that by 2035, 15% of our population will be from the gold community. For your information, Mr. President and Mr. President, the development of the gold community registration system,<|endoftext|>'"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "processor.tokenizer.decode(labels[0])"
   ]
  },
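  {
   "cell_type": "markdown",
   "id": "c0ffee08",
   "metadata": {},
   "source": [
    "Baseline: repeat the same prefill and greedy loop with the uncompiled `model_normal`, still using a `WhisperStaticCache`, so compilation is the only difference."
   ]
  },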
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "5d7cc231",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 54.5 ms, sys: 97 µs, total: 54.6 ms\n",
      "Wall time: 54.3 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "with torch.no_grad():\n",
    "    \n",
    "    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n",
    "        out_encoder = model_normal.model.encoder(inputs['input_features'])\n",
    "        \n",
    "    language = 'en'\n",
    "    initial_strings = [\n",
    "        '<|startoftranscript|>',\n",
    "        f'<|{language}|>',\n",
    "        '<|transcribe|>'\n",
    "    ]\n",
    "\n",
    "    labels = processor.tokenizer(\n",
    "        ''.join(initial_strings), \n",
    "        add_special_tokens = False,\n",
    "        return_tensors = 'pt',\n",
    "    ).to('cuda')['input_ids']\n",
    "    out_decoder = model.model.decoder(\n",
    "        labels, \n",
    "        encoder_hidden_states=out_encoder[0],\n",
    "        past_key_values = None,\n",
    "        use_cache = True\n",
    "    )\n",
    "    past_key_values = out_decoder.past_key_values\n",
    "    proj = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)\n",
    "    out_encoder = out_encoder[0].clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "202bc451",
   "metadata": {},
   "outputs": [],
   "source": [
    "cache_normal = WhisperStaticCache(model.config, compute_dtype, device, past_key_values)\n",
    "seq_length = past_key_values[0][0].shape[2]\n",
    "cache_position = torch.tensor([seq_length], device=device)\n",
    "position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device = device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "b783a51f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|████████████▉                                                                  | 73/445 [00:00<00:02, 150.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 491 ms, sys: 0 ns, total: 491 ms\n",
      "Wall time: 490 ms\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i in tqdm(range(model_normal.config.max_target_positions - len(initial_strings))):\n",
    "        out_decoder = model_normal.model.decoder(\n",
    "            proj, \n",
    "            encoder_hidden_states=out_encoder,\n",
    "            past_key_values = cache_normal,\n",
    "            position_ids=position_ids,\n",
    "            use_cache = True,\n",
    "            return_dict = False,\n",
    "            cache_position = cache_position\n",
    "        )\n",
    "        proj = torch.argmax(model_normal.proj_out(out_decoder[0][:,-1:]), dim=-1)\n",
    "        labels = torch.concat([labels, proj], axis = -1)\n",
    "        position_ids += 1\n",
    "        cache_position += 1\n",
    "\n",
    "        if proj == model.config.eos_token_id:\n",
    "            break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "19456d76",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<|startoftranscript|><|en|><|transcribe|> Assembly on Aging in Vienna, Australia, which was held in 1982. Based on the evidence, the Ministry of Industry of Malaysia estimated that by 2035, 15% of our population will be from the gold community. For your information, Mr. President and Mr. President, the development of the gold community registration system,<|endoftext|>'"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "processor.tokenizer.decode(labels[0])"
   ]
  },
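  {
   "cell_type": "markdown",
   "id": "c0ffee09",
   "metadata": {},
   "source": [
    "Both variants produce the identical transcript. A rough throughput comparison, taking the wall times from the `%%time` outputs above (73 decoded tokens per run); the numbers are indicative only:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ffee10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# wall times copied from the %%time outputs of the two decode loops above\n",
    "n_tokens = 73\n",
    "compiled_s, eager_s = 0.398, 0.490\n",
    "print(f'compiled: {n_tokens / compiled_s:.1f} tokens/s')\n",
    "print(f'eager:    {n_tokens / eager_s:.1f} tokens/s')\n",
    "print(f'speedup:  {eager_s / compiled_s:.2f}x')"
   ]
  },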
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0742e47d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3.10",
   "language": "python",
   "name": "python3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}