Word Pooling
{
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"id": "e7ce4d79-6fe7-4044-b9f2-efbabb524d27",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import transformers\n",
"from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn import CrossEntropyLoss\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"from torch.utils.data import DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "837cc158-fa5c-402f-aa7c-22e9e130a1ee",
"metadata": {},
"outputs": [],
"source": [
"model_name = \"bert-base-uncased\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "9e2b99e8-8988-4723-87f4-0528d561292d",
"metadata": {},
"outputs": [],
"source": [
"def build_input(\n",
" tokenizer,\n",
" text,\n",
" label=None,\n",
" max_len=512,\n",
" ignore_index=-100\n",
"):\n",
" \"\"\"\n",
" Converting each text into a dictionary of `input_ids`, and `labels` if it exists\n",
"\n",
" Example:\n",
" >>> print(datamodule.build_input(datamodule.tokenizer, \"It is good to see you.\", 1))\n",
" {'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
" 'input_ids': [101, 1135, 1110, 1363, 1106, 1267, 1128, 119, 102],\n",
" 'labels': [1],\n",
" 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" 'word_ids': [-100, 0, 1, 2, 3, 4, 5, 6, -100],\n",
" 'word_offsets': [1, 8]}\n",
" \"\"\"\n",
" print(tokenizer.tokenize(text))\n",
" inputs = tokenizer(text, padding=False, truncation=True)\n",
"\n",
" instance = {k: v for k, v in inputs.items()}\n",
"\n",
" word_ids = inputs.word_ids(0)\n",
" \n",
" instance[\"word_ids\"] = list(map(lambda v: v if v is not None else ignore_index, word_ids))\n",
" instance[\"word_offsets\"] = [1, np.where(np.array(word_ids) == None)[0].tolist()[1]]\n",
"\n",
" if label is not None:\n",
" instance[\"labels\"] = [label]\n",
"\n",
" return instance"
]
},
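{
"cell_type": "markdown",
"id": "word-offsets-sanity-note",
"metadata": {},
"source": [
"A quick sanity check (added here as a sketch, not part of the original gist): `word_offsets` is meant to mark the span of real tokens between `[CLS]` and `[SEP]`, so slicing `input_ids` with it should drop exactly those two special tokens. The cell below assumes the `bert-base-uncased` tokenizer loaded above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "word-offsets-sanity-check",
"metadata": {},
"outputs": [],
"source": [
"# Sanity-check sketch: the `word_offsets` span should exclude the special tokens,\n",
"# so the sliced ids must equal the full sequence minus [CLS] and [SEP].\n",
"example = build_input(tokenizer, \"It is good to see you.\", 1)\n",
"start, end = example[\"word_offsets\"]\n",
"inner_ids = example[\"input_ids\"][start:end]\n",
"assert inner_ids == example[\"input_ids\"][1:-1]\n",
"print(tokenizer.convert_ids_to_tokens(inner_ids))"
]
},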
{
"cell_type": "code",
"execution_count": 66,
"id": "b7ee69fa-89bb-43b4-8919-03ecd3a9f57f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['i', 'like', 'yu', '##or', 'dog', '.']\n",
"['the', 'lo', '##v', '##liest', '.']\n"
]
},
{
"data": {
"text/plain": [
"[{'input_ids': [101, 1045, 2066, 9805, 2953, 3899, 1012, 102],\n",
" 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0],\n",
" 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1],\n",
" 'word_ids': [-100, 0, 1, 2, 2, 3, 4, -100],\n",
" 'word_offsets': [1, 7]},\n",
" {'input_ids': [101, 1996, 8840, 2615, 21292, 1012, 102],\n",
" 'token_type_ids': [0, 0, 0, 0, 0, 0, 0],\n",
" 'attention_mask': [1, 1, 1, 1, 1, 1, 1],\n",
" 'word_ids': [-100, 0, 1, 1, 1, 2, -100],\n",
" 'word_offsets': [1, 6]}]"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"texts = [\"i like yuor dog.\", \"the lovliest.\"]\n",
"data = [] \n",
"for text in texts:\n",
" data.append(build_input(tokenizer, text))\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "8dced84c-09d1-4dfe-931c-8146f0759091",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"{'input_ids': tensor([[ 101, 1045, 2066, 9805, 2953, 3899, 1012, 102],\n",
" [ 101, 1996, 8840, 2615, 21292, 1012, 102, 0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1],\n",
" [1, 1, 1, 1, 1, 1, 1, 0]]), 'word_ids': tensor([[-100, 0, 1, 2, 2, 3, 4, -100],\n",
" [-100, 0, 1, 1, 1, 2, -100, -100]]), 'word_offsets': tensor([[1, 7],\n",
" [1, 6]])}\n"
]
}
],
"source": [
"def collator_with_padding(\n",
" padding_index=0,\n",
" ignore_index=-100,\n",
" ignore_keys=None\n",
"):\n",
" if ignore_keys is None:\n",
" ignore_keys = []\n",
"\n",
" def collate_fn(features):\n",
" batch = {\n",
" key: pad_sequence([torch.tensor(i[key]) for i in features], batch_first=True, padding_value=padding_index)\n",
" if key not in ignore_keys else\n",
" pad_sequence([torch.tensor(i[key]) for i in features], batch_first=True, padding_value=ignore_index)\n",
" for key in features[0]\n",
" }\n",
"\n",
" return batch\n",
"\n",
" return collate_fn\n",
"\n",
"\n",
"dl = DataLoader(\n",
" data,\n",
" batch_size=2,\n",
" shuffle=False,\n",
" num_workers=1,\n",
" collate_fn=collator_with_padding(padding_index=0, ignore_index=-100, ignore_keys=[\"word_ids\"])\n",
")\n",
"for d in dl:\n",
" print(d)"
]
},
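{
"cell_type": "markdown",
"id": "word-ids-padding-note",
"metadata": {},
"source": [
"Why `word_ids` is listed in `ignore_keys`: padding it with the default `0` would make padded positions look like extra subtokens of word 0, while `-100` is dropped by the `word_id > -1` mask used later in `word_level_pooling`. A minimal illustration (added sketch, not from the original gist):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "word-ids-padding-demo",
"metadata": {},
"outputs": [],
"source": [
"# Illustration sketch: padding word_ids with 0 would inflate the subtoken count of word 0,\n",
"# while -100 padding is removed by the `> -1` mask used in word_level_pooling.\n",
"wid_zero_pad = torch.tensor([-100, 0, 1, 1, 1, 2, -100, 0])       # padded with 0\n",
"wid_ignore_pad = torch.tensor([-100, 0, 1, 1, 1, 2, -100, -100])  # padded with -100\n",
"\n",
"print(wid_zero_pad[wid_zero_pad > -1].unique(return_counts=True))      # word 0 counted twice\n",
"print(wid_ignore_pad[wid_ignore_pad > -1].unique(return_counts=True))  # counts match the real words"
]
},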
{
"cell_type": "code",
"execution_count": 107,
"id": "2e6642a9-bd13-4895-a634-6b569b87a12d",
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(\n",
" self,\n",
" model_name_or_path,\n",
" ):\n",
" super().__init__()\n",
" self.transformer = AutoModel.from_pretrained(model_name_or_path)\n",
"\n",
" def word_level_pooling(self, hidden_state, word_id, offset, mode=\"mean\"):\n",
" \"\"\"\n",
" Instead of pooling over all tokens, we're going to accumalate the states of each subtokens respect to its word and then pool over those states\n",
" using `word_ids` and `offset` variables.\n",
" \n",
" Example:\n",
" We have a sentence with nine words and the word at the middle consists of two subtokens (T: [5, 6]) in total the output must be something like 1x10x758.\n",
" instea of pooling (mean) over this output we first segmentize subtokens in one group, then pool over them and\n",
" a final pooling (mean/sum/max) overal.\n",
" T: [5, 6]: 1x10x758 --> 1x9x758 --> 1x758 \n",
" \"\"\"\n",
" print(f\"word_id: {word_id}\")\n",
" idx = word_id[word_id > -1]\n",
" print(f\"idx: {idx}\")\n",
" _, idx_unique_count = idx.unique(dim=0, return_counts=True)\n",
" print(_)\n",
" print(f\"idx_unique_count: {idx_unique_count}\")\n",
" hidden_state = hidden_state[offset[0]:offset[1], :]\n",
" print(f\"hidden_state.shape: {hidden_state.shape}\")\n",
"\n",
" hidden_state_split = torch.split(hidden_state, idx_unique_count.tolist(), dim=0)\n",
" print(\"hidden_state_split:\", [hss.shape for hss in hidden_state_split])\n",
" hidden_state_padded = pad_sequence(hidden_state_split, batch_first=True, padding_value=0)\n",
" print(\"hidden_state_padded:\", [hsp.shape for hsp in hidden_state_padded])\n",
" \n",
" # Getting average over sub tokens\n",
" output = torch.div(\n",
" torch.sum(hidden_state_padded, dim=1),\n",
" idx_unique_count.reshape(-1, 1)\n",
" )\n",
" print(f\"output.shape: {output.shape}\")\n",
" print()\n",
" return output\n",
"\n",
" def forward(\n",
" self,\n",
" input_ids=None,\n",
" attention_mask=None,\n",
" token_type_ids=None,\n",
" position_ids=None,\n",
" head_mask=None,\n",
" inputs_embeds=None,\n",
" word_ids=None,\n",
" word_offsets=None,\n",
" labels=None,\n",
" output_attentions=None,\n",
" output_hidden_states=None,\n",
" return_dict=None,\n",
" ):\n",
" \"\"\"\n",
" Feed data into our model.\n",
" \"\"\"\n",
"\n",
" return_dict = True\n",
"\n",
" outputs = self.transformer(\n",
" input_ids,\n",
" attention_mask=attention_mask,\n",
" token_type_ids=token_type_ids,\n",
" position_ids=position_ids,\n",
" head_mask=head_mask,\n",
" inputs_embeds=inputs_embeds,\n",
" output_attentions=output_attentions,\n",
" output_hidden_states=output_hidden_states,\n",
" return_dict=return_dict,\n",
" )\n",
"\n",
" hidden_states = outputs[0]\n",
" pooled_output = pad_sequence(\n",
" [\n",
" self.word_level_pooling(hidden_state, word_id, offset)\n",
" for hidden_state, word_id, offset in zip(hidden_states, word_ids, word_offsets)\n",
" ],\n",
" batch_first=True,\n",
" padding_value=0\n",
" )\n",
"\n",
" return pooled_output"
]
},
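{
"cell_type": "markdown",
"id": "pooling-steps-note",
"metadata": {},
"source": [
"To see the split/pad/average steps of `word_level_pooling` in isolation, here is a tiny standalone sketch (added, not part of the original gist) with a fake hidden state: random numbers, hidden size 4 instead of 768. The shapes mirror what the debug prints below show."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pooling-steps-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Standalone sketch of the pooling steps on a tiny fake hidden state.\n",
"torch.manual_seed(0)\n",
"hidden = torch.randn(6, 4)                          # [CLS] + 4 subtokens + [SEP], hidden size 4\n",
"word_id = torch.tensor([-100, 0, 1, 1, 2, -100])    # the middle word has two subtokens\n",
"offset = (1, 5)                                     # span of the real tokens\n",
"\n",
"counts = word_id[word_id > -1].unique(return_counts=True)[1]      # tensor([1, 2, 1])\n",
"chunks = torch.split(hidden[offset[0]:offset[1]], counts.tolist(), dim=0)\n",
"padded = pad_sequence(chunks, batch_first=True, padding_value=0)  # 3 x 2 x 4\n",
"word_states = padded.sum(dim=1) / counts.reshape(-1, 1)           # one row per word: 3 x 4\n",
"print(word_states.shape)"
]
},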
{
"cell_type": "code",
"execution_count": 108,
"id": "350119cf-b51d-46cc-bb32-0b7d4db13bad",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"model = Model(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 109,
"id": "b6ad5113-39d4-426f-95cc-a6694d05ff5d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"word_id: tensor([-100, 0, 1, 2, 2, 3, 4, -100])\n",
"idx: tensor([0, 1, 2, 2, 3, 4])\n",
"tensor([0, 1, 2, 3, 4])\n",
"idx_unique_count: tensor([1, 1, 2, 1, 1])\n",
"hidden_state.shape: torch.Size([6, 768])\n",
"hidden_state_split: [torch.Size([1, 768]), torch.Size([1, 768]), torch.Size([2, 768]), torch.Size([1, 768]), torch.Size([1, 768])]\n",
"hidden_state_padded: [torch.Size([2, 768]), torch.Size([2, 768]), torch.Size([2, 768]), torch.Size([2, 768]), torch.Size([2, 768])]\n",
"output.shape: torch.Size([5, 768])\n",
"\n",
"word_id: tensor([-100, 0, 1, 1, 1, 2, -100, -100])\n",
"idx: tensor([0, 1, 1, 1, 2])\n",
"tensor([0, 1, 2])\n",
"idx_unique_count: tensor([1, 3, 1])\n",
"hidden_state.shape: torch.Size([5, 768])\n",
"hidden_state_split: [torch.Size([1, 768]), torch.Size([3, 768]), torch.Size([1, 768])]\n",
"hidden_state_padded: [torch.Size([3, 768]), torch.Size([3, 768]), torch.Size([3, 768])]\n",
"output.shape: torch.Size([3, 768])\n",
"\n"
]
},
{
"data": {
"text/plain": [
"torch.Tensor"
]
},
"execution_count": 109,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"for d in dl:\n",
" outputs = model(**d)\n",
" \n",
"type(outputs)"
]
},
{
"cell_type": "code",
"execution_count": 111,
"id": "04ba70e1-1d10-4c2b-a463-d8f50c11bcbe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 5, 768])"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"outputs.shape"
]
},
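{
"cell_type": "markdown",
"id": "sentence-pooling-note",
"metadata": {},
"source": [
"The `word_level_pooling` docstring mentions a final mean/sum/max pooling over the word vectors, which `forward` leaves out: it returns the padded `batch x words x 768` tensor. Below is a hedged sketch of that last step (an added assumption, not part of the original gist), reusing the batch `d` and `outputs` from the loop above; mean pooling is chosen here, and the zero rows added by `pad_sequence` make the plain sum safe."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sentence-pooling-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Sketch (assumption): collapse the per-word vectors into one vector per sentence by mean pooling.\n",
"# Padded word slots are all-zero, so summing over dim=1 and dividing by the real word count is enough.\n",
"num_words = d[\"word_ids\"].max(dim=1).values + 1                 # real words per sentence, e.g. tensor([5, 3])\n",
"sentence_vectors = outputs.sum(dim=1) / num_words.unsqueeze(1)\n",
"print(sentence_vectors.shape)                                   # expected: torch.Size([2, 768])"
]
},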
{
"cell_type": "code",
"execution_count": null,
"id": "dc1d9368-4065-4821-b721-e78ab6690d84",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:ml]",
"language": "python",
"name": "conda-env-ml-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}