Word Pooling
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e7ce4d79-6fe7-4044-b9f2-efbabb524d27",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import transformers\n",
    "from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "837cc158-fa5c-402f-aa7c-22e9e130a1ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"bert-base-uncased\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "9e2b99e8-8988-4723-87f4-0528d561292d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_input(\n",
    "    tokenizer,\n",
    "    text,\n",
    "    label=None,\n",
    "    max_len=512,\n",
    "    ignore_index=-100\n",
    "):\n",
    "    \"\"\"\n",
    "    Convert a text into a dictionary of `input_ids`, `word_ids`, `word_offsets`, and `labels` if a label is given.\n",
    "\n",
    "    Example:\n",
    "        >>> print(datamodule.build_input(datamodule.tokenizer, \"It is good to see you.\", 1))\n",
    "        {'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
    "         'input_ids': [101, 1135, 1110, 1363, 1106, 1267, 1128, 119, 102],\n",
    "         'labels': [1],\n",
    "         'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
    "         'word_ids': [-100, 0, 1, 2, 3, 4, 5, 6, -100],\n",
    "         'word_offsets': [1, 8]}\n",
    "    \"\"\"\n",
    "    print(tokenizer.tokenize(text))\n",
    "    inputs = tokenizer(text, padding=False, truncation=True, max_length=max_len)\n",
    "\n",
    "    instance = {k: v for k, v in inputs.items()}\n",
    "\n",
    "    # Map each token to the index of the word it belongs to; special tokens get `ignore_index`.\n",
    "    word_ids = inputs.word_ids(0)\n",
    "\n",
    "    instance[\"word_ids\"] = list(map(lambda v: v if v is not None else ignore_index, word_ids))\n",
    "    # [start, end) positions of the real (non-special) tokens: from just after [CLS] up to [SEP].\n",
    "    instance[\"word_offsets\"] = [1, np.where(np.array(word_ids) == None)[0].tolist()[1]]\n",
    "\n",
    "    if label is not None:\n",
    "        instance[\"labels\"] = [label]\n",
    "\n",
    "    return instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "b7ee69fa-89bb-43b4-8919-03ecd3a9f57f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['i', 'like', 'yu', '##or', 'dog', '.']\n",
      "['the', 'lo', '##v', '##liest', '.']\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'input_ids': [101, 1045, 2066, 9805, 2953, 3899, 1012, 102],\n",
       "  'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0],\n",
       "  'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1],\n",
       "  'word_ids': [-100, 0, 1, 2, 2, 3, 4, -100],\n",
       "  'word_offsets': [1, 7]},\n",
       " {'input_ids': [101, 1996, 8840, 2615, 21292, 1012, 102],\n",
       "  'token_type_ids': [0, 0, 0, 0, 0, 0, 0],\n",
       "  'attention_mask': [1, 1, 1, 1, 1, 1, 1],\n",
       "  'word_ids': [-100, 0, 1, 1, 1, 2, -100],\n",
       "  'word_offsets': [1, 6]}]"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "texts = [\"i like yuor dog.\", \"the lovliest.\"]\n",
    "data = []\n",
    "for text in texts:\n",
    "    data.append(build_input(tokenizer, text))\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "8dced84c-09d1-4dfe-931c-8146f0759091",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "{'input_ids': tensor([[ 101, 1045, 2066, 9805, 2953, 3899, 1012, 102],\n",
      "        [ 101, 1996, 8840, 2615, 21292, 1012, 102, 0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 0]]), 'word_ids': tensor([[-100, 0, 1, 2, 2, 3, 4, -100],\n",
      "        [-100, 0, 1, 1, 1, 2, -100, -100]]), 'word_offsets': tensor([[1, 7],\n",
      "        [1, 6]])}\n"
     ]
    }
   ],
   "source": [
    "def collator_with_padding(\n",
    "    padding_index=0,\n",
    "    ignore_index=-100,\n",
    "    ignore_keys=None\n",
    "):\n",
    "    if ignore_keys is None:\n",
    "        ignore_keys = []\n",
    "\n",
    "    def collate_fn(features):\n",
    "        # Pad each field to the longest example in the batch; fields listed in `ignore_keys`\n",
    "        # (e.g. word_ids) are padded with `ignore_index` instead of `padding_index`.\n",
    "        batch = {\n",
    "            key: pad_sequence([torch.tensor(i[key]) for i in features], batch_first=True, padding_value=padding_index)\n",
    "            if key not in ignore_keys else\n",
    "            pad_sequence([torch.tensor(i[key]) for i in features], batch_first=True, padding_value=ignore_index)\n",
    "            for key in features[0]\n",
    "        }\n",
    "\n",
    "        return batch\n",
    "\n",
    "    return collate_fn\n",
    "\n",
    "\n",
    "dl = DataLoader(\n",
    "    data,\n",
    "    batch_size=2,\n",
    "    shuffle=False,\n",
    "    num_workers=1,\n",
    "    collate_fn=collator_with_padding(padding_index=0, ignore_index=-100, ignore_keys=[\"word_ids\"])\n",
    ")\n",
    "for d in dl:\n",
    "    print(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "2e6642a9-bd13-4895-a634-6b569b87a12d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        model_name_or_path,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.transformer = AutoModel.from_pretrained(model_name_or_path)\n",
    "\n",
    "    def word_level_pooling(self, hidden_state, word_id, offset, mode=\"mean\"):\n",
    "        \"\"\"\n",
    "        Instead of pooling over all tokens, accumulate the hidden states of each word's subtokens\n",
    "        (using `word_id` and `offset`) and pool within each word, so the output has one vector per word.\n",
    "\n",
    "        Example:\n",
    "            A sentence of nine words whose middle word consists of two subtokens (T: [5, 6]) gives\n",
    "            hidden states of shape 1x10x768 (ignoring special tokens). Instead of pooling (mean) over\n",
    "            this output directly, we first group the subtokens of each word and pool within each group,\n",
    "            giving 1x9x768; a final pooling (mean/sum/max) over the words would then give 1x768.\n",
    "            T: [5, 6]: 1x10x768 --> 1x9x768 --> 1x768\n",
    "        \"\"\"\n",
    "        print(f\"word_id: {word_id}\")\n",
    "        # Keep only the positions that belong to real words (special tokens are marked with -100).\n",
    "        idx = word_id[word_id > -1]\n",
    "        print(f\"idx: {idx}\")\n",
    "        # Count how many subtokens each word has.\n",
    "        idx_unique, idx_unique_count = idx.unique(dim=0, return_counts=True)\n",
    "        print(idx_unique)\n",
    "        print(f\"idx_unique_count: {idx_unique_count}\")\n",
    "        # Drop the special tokens ([CLS]/[SEP]) using the word offsets.\n",
    "        hidden_state = hidden_state[offset[0]:offset[1], :]\n",
    "        print(f\"hidden_state.shape: {hidden_state.shape}\")\n",
    "\n",
    "        # Split the subtoken states by word and pad each word's group to the same length.\n",
    "        hidden_state_split = torch.split(hidden_state, idx_unique_count.tolist(), dim=0)\n",
    "        print(\"hidden_state_split:\", [hss.shape for hss in hidden_state_split])\n",
    "        hidden_state_padded = pad_sequence(hidden_state_split, batch_first=True, padding_value=0)\n",
    "        print(\"hidden_state_padded:\", [hsp.shape for hsp in hidden_state_padded])\n",
    "\n",
    "        # Average over subtokens: padded positions are zero, so summing and dividing\n",
    "        # by the subtoken count gives the per-word mean.\n",
    "        output = torch.div(\n",
    "            torch.sum(hidden_state_padded, dim=1),\n",
    "            idx_unique_count.reshape(-1, 1)\n",
    "        )\n",
    "        print(f\"output.shape: {output.shape}\")\n",
    "        print()\n",
    "        return output\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids=None,\n",
    "        attention_mask=None,\n",
    "        token_type_ids=None,\n",
    "        position_ids=None,\n",
    "        head_mask=None,\n",
    "        inputs_embeds=None,\n",
    "        word_ids=None,\n",
    "        word_offsets=None,\n",
    "        labels=None,\n",
    "        output_attentions=None,\n",
    "        output_hidden_states=None,\n",
    "        return_dict=None,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Feed the batch into the transformer and return word-level pooled hidden states.\n",
    "        \"\"\"\n",
    "\n",
    "        return_dict = True\n",
    "\n",
    "        outputs = self.transformer(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            token_type_ids=token_type_ids,\n",
    "            position_ids=position_ids,\n",
    "            head_mask=head_mask,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            output_attentions=output_attentions,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "\n",
    "        # Last-layer hidden states: pool per word for each example, then pad across the batch.\n",
    "        hidden_states = outputs[0]\n",
    "        pooled_output = pad_sequence(\n",
    "            [\n",
    "                self.word_level_pooling(hidden_state, word_id, offset)\n",
    "                for hidden_state, word_id, offset in zip(hidden_states, word_ids, word_offsets)\n",
    "            ],\n",
    "            batch_first=True,\n",
    "            padding_value=0\n",
    "        )\n",
    "\n",
    "        return pooled_output"
   ]
  },
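  {
   "cell_type": "markdown",
   "id": "word-pooling-note-1",
   "metadata": {},
   "source": [
    "The cell below is a small standalone sanity check that is **not part of the original gist**: it\n",
    "re-creates the split / pad / sum / divide trick from `word_level_pooling` on a tiny random tensor\n",
    "(names such as `toy_hidden`, `toy_word_id`, and `toy_offset` are made up for this sketch) and\n",
    "compares it against a plain per-word mean computed directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "word-pooling-check-1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only (not from the original gist): check that split/pad/sum/divide equals a direct per-word mean.\n",
    "torch.manual_seed(0)\n",
    "toy_hidden = torch.randn(6, 4)                        # 6 subtokens, toy hidden size of 4\n",
    "toy_word_id = torch.tensor([-100, 0, 1, 1, 2, -100])  # special tokens marked with -100\n",
    "toy_offset = (1, 5)                                   # positions of the real subtokens\n",
    "\n",
    "ids = toy_word_id[toy_word_id > -1]\n",
    "_, counts = ids.unique(return_counts=True)\n",
    "states = toy_hidden[toy_offset[0]:toy_offset[1]]\n",
    "\n",
    "# The trick used in `word_level_pooling`.\n",
    "split = torch.split(states, counts.tolist(), dim=0)\n",
    "padded = pad_sequence(split, batch_first=True, padding_value=0)\n",
    "pooled = padded.sum(dim=1) / counts.reshape(-1, 1)\n",
    "\n",
    "# Reference: mean of each word's subtoken states computed directly.\n",
    "reference = torch.stack([states[ids == w].mean(dim=0) for w in ids.unique()])\n",
    "torch.allclose(pooled, reference)  # expected: True"
   ]
  },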
  {
   "cell_type": "code",
   "execution_count": 108,
   "id": "350119cf-b51d-46cc-bb32-0b7d4db13bad",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "model = Model(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "b6ad5113-39d4-426f-95cc-a6694d05ff5d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "word_id: tensor([-100, 0, 1, 2, 2, 3, 4, -100])\n",
      "idx: tensor([0, 1, 2, 2, 3, 4])\n",
      "tensor([0, 1, 2, 3, 4])\n",
      "idx_unique_count: tensor([1, 1, 2, 1, 1])\n",
      "hidden_state.shape: torch.Size([6, 768])\n",
      "hidden_state_split: [torch.Size([1, 768]), torch.Size([1, 768]), torch.Size([2, 768]), torch.Size([1, 768]), torch.Size([1, 768])]\n",
      "hidden_state_padded: [torch.Size([2, 768]), torch.Size([2, 768]), torch.Size([2, 768]), torch.Size([2, 768]), torch.Size([2, 768])]\n",
      "output.shape: torch.Size([5, 768])\n",
      "\n",
      "word_id: tensor([-100, 0, 1, 1, 1, 2, -100, -100])\n",
      "idx: tensor([0, 1, 1, 1, 2])\n",
      "tensor([0, 1, 2])\n",
      "idx_unique_count: tensor([1, 3, 1])\n",
      "hidden_state.shape: torch.Size([5, 768])\n",
      "hidden_state_split: [torch.Size([1, 768]), torch.Size([3, 768]), torch.Size([1, 768])]\n",
      "hidden_state_padded: [torch.Size([3, 768]), torch.Size([3, 768]), torch.Size([3, 768])]\n",
      "output.shape: torch.Size([3, 768])\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Tensor"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for d in dl:\n",
    "    outputs = model(**d)\n",
    "\n",
    "type(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "04ba70e1-1d10-4c2b-a463-d8f50c11bcbe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 5, 768])"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs.shape"
   ]
  },
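  {
   "cell_type": "markdown",
   "id": "word-pooling-note-2",
   "metadata": {},
   "source": [
    "The `word_level_pooling` docstring mentions a final pooling (mean/sum/max) over the word-level\n",
    "states, which the gist itself does not implement. The cell below is a minimal sketch of that step,\n",
    "assuming a masked mean: `sentence_level_pooling` is a made-up helper name, and it reuses `outputs`\n",
    "and `d` from the cells above. Words added by batch padding are excluded from the mean."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "word-pooling-sketch-2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only (not from the original gist): masked mean over the word dimension.\n",
    "def sentence_level_pooling(pooled_output, word_ids):\n",
    "    # word_ids are padded with -100, so the per-row maximum is (number of words - 1).\n",
    "    word_counts = word_ids.max(dim=1).values + 1                            # (batch,)\n",
    "    max_words = pooled_output.size(1)\n",
    "    mask = torch.arange(max_words).unsqueeze(0) < word_counts.unsqueeze(1)  # (batch, max_words)\n",
    "    mask = mask.unsqueeze(-1).to(pooled_output.dtype)                       # (batch, max_words, 1)\n",
    "    summed = (pooled_output * mask).sum(dim=1)                              # (batch, hidden)\n",
    "    return summed / word_counts.unsqueeze(1).to(pooled_output.dtype)\n",
    "\n",
    "sentence_embeddings = sentence_level_pooling(outputs, d[\"word_ids\"])\n",
    "sentence_embeddings.shape  # expected: torch.Size([2, 768])"
   ]
  },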
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc1d9368-4065-4821-b721-e78ab6690d84",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:ml]",
   "language": "python",
   "name": "conda-env-ml-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}