Training and evaluation code for a simple model that predicts a token removed from a sentence
import json
from math import ceil
from random import shuffle

import torch
import torch.nn as nn
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel
from transformers import get_scheduler
from tqdm.auto import tqdm
class ClassificationHead(nn.Module):
    '''
    A two-layer classification head with a tanh non-linearity
    (an alternative to ClassificationHeadSimple; not used in the
    main script below).
    '''

    def __init__(self, n_classes, input_size=768):
        super().__init__()
        self.linear1 = nn.Linear(input_size, 768)
        self.linear2 = nn.Linear(768, n_classes)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.tanh(x)
        return self.linear2(x)
class ClassificationHeadSimple(nn.Module):
    '''
    A simple linear classifier that converts token embeddings
    to class scores.
    '''

    def __init__(self, n_classes, input_size=768):
        super().__init__()
        self.linear = nn.Linear(input_size, n_classes)

    def forward(self, x):
        return self.linear(x)
def throw_away_token(sentence_batch, tokeniser):
    '''
    Tokenise a batch of sentences and remove one randomly chosen
    subword token (any non-padding position except [CLS] at index 0)
    from each sentence. Returns the ids of the removed tokens and the
    modified tokeniser output; a zero (padding) element is appended to
    each sequence so that tensor shapes stay unchanged.
    '''
    throwaway_token_ids = [None for _ in sentence_batch]
    subword_dict = tokeniser(sentence_batch, return_tensors='pt',
                             padding=True, truncation=True)
    for i in range(len(sentence_batch)):
        # Number of non-padding tokens, including [CLS] and [SEP]
        n_tokens = subword_dict['attention_mask'][i].sum().item()
        # Pick a random position after [CLS]
        throwaway_idx = torch.randint(low=1, high=n_tokens, size=(1,)).item()
        throwaway_token_ids[i] = subword_dict['input_ids'][i][throwaway_idx].item()
        zero_tensor = torch.tensor([0])
        subword_dict['input_ids'][i] = torch.cat((
            subword_dict['input_ids'][i][:throwaway_idx],
            subword_dict['input_ids'][i][throwaway_idx+1:],
            zero_tensor))
        subword_dict['token_type_ids'][i] = torch.cat((
            subword_dict['token_type_ids'][i][:throwaway_idx],
            subword_dict['token_type_ids'][i][throwaway_idx+1:],
            zero_tensor))
        subword_dict['attention_mask'][i] = torch.cat((
            subword_dict['attention_mask'][i][:throwaway_idx],
            subword_dict['attention_mask'][i][throwaway_idx+1:],
            zero_tensor))
    return throwaway_token_ids, subword_dict
if __name__ == '__main__':
    model_name = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    bert_model = AutoModel.from_pretrained(model_name)
    bert_model.cuda()
    bert_model = nn.DataParallel(bert_model)

    # The classification head predicts which vocabulary item was removed
    n_classes = tokenizer.vocab_size
    classification_head = ClassificationHeadSimple(n_classes)
    classification_head.cuda()
    classification_head = nn.DataParallel(classification_head)

    # Only the classification head is trained; BERT stays frozen
    optimizer = AdamW(list(classification_head.parameters()), lr=1e-5)

    with open('../data/hansard_short_sentences.json', 'r', encoding='utf-8') as inp:
        data_all = json.load(inp)
    indices_permuted = torch.randperm(len(data_all))
    data_dev = [data_all[i].lower() for i in indices_permuted[:1000]]
    data_test = [data_all[i].lower() for i in indices_permuted[1000:2000]]
    data_train = [data_all[i].lower() for i in indices_permuted[2000:22000]]

    n_epochs = 5
    batch_size = 128
    # The scheduler is stepped once per batch, not once per example
    n_training_steps = n_epochs * ceil(len(data_train) / batch_size)
    lr_scheduler = get_scheduler(
        'linear',
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=n_training_steps
    )
    loss_function = nn.CrossEntropyLoss()
    min_dev_loss = float('inf')
    for epoch in range(n_epochs):
        # Train: BERT is kept under no_grad, so only the head is updated
        bert_model.train()
        classification_head.train()
        epoch_train_losses = []
        n_steps_train = ceil(len(data_train) / batch_size)
        for batch_n in tqdm(range(n_steps_train), desc=f'Epoch {epoch+1}, train', leave=False):
            batch_sentences = data_train[batch_size * batch_n:
                                         batch_size * (batch_n + 1)]
            gold_labels, inputs = throw_away_token(batch_sentences, tokenizer)
            bert_inputs = {k: v.cuda() for k, v in inputs.items()}
            with torch.no_grad():
                bert_outputs = bert_model(**bert_inputs).last_hidden_state
            # Use the [CLS] embedding of each sentence as its representation
            cls_embeddings = bert_outputs[:, 0, :]
            logits = classification_head(cls_embeddings)
            loss = loss_function(logits, torch.tensor(gold_labels).cuda())
            loss.backward()
            epoch_train_losses.append(loss.item())
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        print(
            f'Epoch {epoch+1} train loss: {torch.tensor(epoch_train_losses).mean().item()}')
        # Evaluate on the dev set
        bert_model.eval()
        classification_head.eval()
        epoch_dev_losses = []
        n_steps_dev = ceil(len(data_dev) / batch_size)
        hits = 0
        misses = 0
        for batch_n in tqdm(range(n_steps_dev), desc=f'Epoch {epoch+1}, dev', leave=False):
            batch_sentences = data_dev[batch_size * batch_n:
                                       batch_size * (batch_n + 1)]
            gold_labels, inputs = throw_away_token(batch_sentences, tokenizer)
            bert_inputs = {k: v.cuda() for k, v in inputs.items()}
            with torch.no_grad():
                bert_outputs = bert_model(**bert_inputs).last_hidden_state
                cls_embeddings = bert_outputs[:, 0, :]
                logits = classification_head(cls_embeddings)
                loss = loss_function(logits, torch.tensor(gold_labels).cuda())
            epoch_dev_losses.append(loss.item())
            classes = torch.argmax(logits, dim=1)
            for guessed, gold in zip(classes, gold_labels):
                if guessed == gold:
                    hits += 1
                else:
                    misses += 1
        print(
            f'Epoch {epoch+1} dev loss: {torch.tensor(epoch_dev_losses).mean().item()}')
        print(
            f'Epoch {epoch+1} dev accuracy: {round(hits / (hits + misses), 2)}')
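As a quick sanity check of the data preparation step, one can run throw_away_token on a toy batch and look at which subword gets dropped from each sentence. This is a minimal sketch, not part of the training script: it reuses the throw_away_token function and the bert-base-uncased tokenizer from the file above, and the example sentences are made up.

# Illustrative sanity check for throw_away_token (assumes the function
# above is in scope and bert-base-uncased can be downloaded/loaded).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
sentences = ['the cat sat on the mat', 'parliament rose at noon']
removed_ids, inputs = throw_away_token(sentences, tokenizer)
for sentence, token_id in zip(sentences, removed_ids):
    # Show which subword was removed from each sentence
    print(sentence, '->', tokenizer.convert_ids_to_tokens(token_id))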