Created July 21, 2021 14:57
RuntimeError: output with shape [1] doesn't match the broadcast shape [1024, 1024]
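For context, this class of RuntimeError comes from an in-place (or out=) tensor op whose output cannot be broadcast to the result shape. Under ZeRO stage 3, DeepSpeed partitions parameters and leaves small placeholder tensors in the module until the full weights are gathered, which is presumably how a shape-[1] output meets a [1024, 1024] weight here. A minimal sketch of the error message itself, independent of DeepSpeed and Lightning (the tensor names and shapes are illustrative, chosen only to match the message above):

import torch

# Hypothetical stand-ins: `placeholder` mimics a partitioned parameter left
# at shape [1]; `weight` mimics the materialized [1024, 1024] matrix.
placeholder = torch.zeros(1)
weight = torch.zeros(1024, 1024)

# The in-place add broadcasts to [1024, 1024], but the output tensor is [1]:
# RuntimeError: output with shape [1] doesn't match the broadcast shape [1024, 1024]
placeholder.add_(weight)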
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer
from transformers.optimization import get_cosine_schedule_with_warmup
import pytorch_lightning as pl

# from deepspeed.ops.adam import FusedAdam
# (using FusedAdam gives a different error: expected tensor on cuda but got cpu)

class BoringDataset(torch.utils.data.Dataset):
    def __init__(self, len=1000):
        self.len = len

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        # Every sample is identical: a fixed sentence and a zero target.
        sample = {
            "text_input": "This is some fake and boring text.",
            "target": 0,
        }
        return sample

class BoringBertClassifierModel(pl.LightningModule):
    def __init__(self, bert_model="microsoft/deberta-v2-xlarge"):
        super(BoringBertClassifierModel, self).__init__()
        self.model_type = "BoringBertClassifierModel"
        # Load Text Model
        self.text_model = AutoModel.from_pretrained(bert_model)
        self.output_layer = nn.Linear(self.text_model.config.hidden_size, 1)
        self.criterion = nn.MSELoss(reduction="sum")

    def forward(self, text_input, **kwargs):
        outputs = self.text_model(**text_input)[0]
        x = torch.sum(outputs, dim=1)  # aggregate over sequence
        predictions = self.output_layer(x)
        return predictions.squeeze(1)

    def training_step(self, batch, batch_nb):
        predicted_targets = self(**batch)
        target_loss = self.criterion(predicted_targets, batch["target"])
        return target_loss

    def validation_step(self, val_batch, val_batch_idx, **kwargs):
        predicted_targets = self(**val_batch)
        target_loss = self.criterion(predicted_targets, val_batch["target"])
        return target_loss
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.00001, weight_decay=0.01)
        schedule = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=100,
            num_training_steps=1000,
        )
        scheduler = {
            "scheduler": schedule,
            "interval": "step",  # runs per batch rather than per epoch
            "frequency": 1,
            "name": "learning_rate",  # picked up by LearningRateMonitor, if one is attached
        }
        return [optimizer], [scheduler]

if __name__ == "__main__":
    bert_model = "microsoft/deberta-large"  # runs fine without deepspeed
    tokenizer = AutoTokenizer.from_pretrained(bert_model, model_max_length=100)

    def collate_fn(batch):
        items = {}
        items["text_input"] = tokenizer(
            [batch_item["text_input"] for batch_item in batch],
            padding=True,
            return_tensors="pt",
            truncation=True,
        )
        items["target"] = torch.tensor(
            [batch_item["target"] for batch_item in batch]
        ).float()
        return items

    train_dataset = BoringDataset(len=100)
    val_dataset = BoringDataset(len=20)
    train_loader = DataLoader(
        dataset=train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=2,
        num_workers=1,
    )
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=4,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=1,
    )

    model = BoringBertClassifierModel(bert_model=bert_model)
    trainer = pl.Trainer(
        gpus=1,
        max_epochs=10,
        progress_bar_refresh_rate=1,
        log_every_n_steps=1,
        plugins="deepspeed_stage_3_offload",
        precision=16,
    )

    # Train the model ⚡
    trainer.fit(
        model,
        train_dataloader=train_loader,
        val_dataloaders=[val_loader],
    )
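The script is self-contained: assuming transformers, deepspeed, and a pytorch-lightning release from this era are installed (the 1.3/1.4 series, where gpus=, progress_bar_refresh_rate=, and plugins="deepspeed_stage_3_offload" are accepted Trainer arguments), running it directly with python on a single-GPU machine should reproduce the RuntimeError above during trainer.fit. Dropping the plugins and precision arguments lets it train normally, consistent with the "runs fine without deepspeed" note in the code.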