PyTorch BERT Layer-wise Learning Rate Decay
import torch
from torch.optim import AdamW
from transformers import AutoModel
def get_bert_layerwise_lr_groups(bert_model, learning_rate=1e-5, layer_decay=0.9):
    """
    Gets parameter groups whose learning rates decay with depth in the network.
    Layers closer to the output get a higher learning rate.

    Args:
        bert_model: A huggingface BERT-like model (must expose an embeddings module and an encoder)
        learning_rate: The learning rate for the topmost encoder layer
        layer_decay: Multiplicative decay applied per layer of depth (0.9-0.95 is a common choice)

    Returns:
        grouped_parameters (list): parameter groups with their decayed learning rates
    """
    n_layers = len(bert_model.encoder.layer) + 1  # +1 for the embedding layer
    # The embeddings sit deepest in the network, so they get the smallest learning rate
    embedding_decayed_lr = learning_rate * (layer_decay ** (n_layers - 1))
    grouped_parameters = [
        {"params": bert_model.embeddings.parameters(), "lr": embedding_decayed_lr}
    ]
    for depth in range(1, n_layers):
        # depth == n_layers - 1 is the last encoder layer, which gets the full learning_rate
        decayed_lr = learning_rate * (layer_decay ** (n_layers - 1 - depth))
        grouped_parameters.append(
            {"params": bert_model.encoder.layer[depth - 1].parameters(), "lr": decayed_lr}
        )
    return grouped_parameters
# Example:
model = AutoModel.from_pretrained("roberta-base")
lr_groups = get_bert_layerwise_lr_groups(model, learning_rate=1e-5)
# The per-group "lr" entries override the default lr passed to the optimizer
optimizer = AdamW(lr_groups, lr=1e-5, weight_decay=0)
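When fine-tuning with a task head on top of the backbone, the head is usually trained at the full, undecayed learning rate. A minimal sketch of that setup, assuming a `roberta-base` sequence-classification model (the checkpoint, label count, and weight decay here are illustrative, not part of the original gist):

from transformers import AutoModelForSequenceClassification

clf = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
# Decayed learning rates for the backbone (clf.roberta), full learning rate for the head
lr_groups = get_bert_layerwise_lr_groups(clf.roberta, learning_rate=1e-5, layer_decay=0.9)
lr_groups.append({"params": clf.classifier.parameters(), "lr": 1e-5})
optimizer = AdamW(lr_groups, lr=1e-5, weight_decay=0.01)

With 12 encoder layers and layer_decay=0.9, the embeddings end up at roughly 0.9^12 ≈ 0.28 of the top-layer learning rate.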