DistilBERT modeling code with a language-modeling (LM) head. Download this file and import modeling_distilbert to use DistilBERT as a decoder, e.g. for TSDAE: https://github.com/UKPLab/sentence-transformers/blob/master/examples/unsupervised_learning/TSDAE/train_tsdae_from_file.py. A short usage sketch follows the module source below.
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch DistilBERT model,
adapted in part from the Facebook, Inc. XLM model (https://github.com/facebookresearch/XLM)
and in part from the HuggingFace PyTorch version of Google AI's BERT model (https://github.com/google-research/bert)
"""
# Forked from Hugging Face Transformers. The class DistilBertLMHeadModel is added to support TSDAE training.
import copy
import math
import warnings

import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers.activations import gelu
from transformers import DistilBertConfig
from transformers.modeling_outputs import (
    BaseModelOutput,
    CausalLMOutput,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.utils import logging
from transformers import MODEL_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DistilBertConfig"
_TOKENIZER_FOR_DOC = "DistilBertTokenizer"

DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "distilbert-base-uncased",
    "distilbert-base-uncased-distilled-squad",
    "distilbert-base-cased",
    "distilbert-base-cased-distilled-squad",
    "distilbert-base-german-cased",
    "distilbert-base-multilingual-cased",
    "distilbert-base-uncased-finetuned-sst-2-english",
    # See all DistilBERT models at https://huggingface.co/models?filter=distilbert
]
# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE # | |
def create_sinusoidal_embeddings(n_pos, dim, out): | |
position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) | |
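    # Fixed sinusoidal scheme from "Attention Is All You Need": sin on even dimensions, cos on odd dimensions.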
out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) | |
out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) | |
out.detach_() | |
out.requires_grad = False | |
class Embeddings(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id) | |
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim) | |
if config.sinusoidal_pos_embds: | |
create_sinusoidal_embeddings( | |
n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight | |
) | |
self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12) | |
self.dropout = nn.Dropout(config.dropout) | |
def forward(self, input_ids): | |
""" | |
Parameters | |
---------- | |
input_ids: torch.tensor(bs, max_seq_length) | |
The token ids to embed. | |
Outputs | |
------- | |
embeddings: torch.tensor(bs, max_seq_length, dim) | |
The embedded tokens (plus position embeddings, no token_type embeddings) | |
""" | |
seq_length = input_ids.size(1) | |
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length) | |
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length) | |
word_embeddings = self.word_embeddings(input_ids) # (bs, max_seq_length, dim) | |
position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim) | |
embeddings = word_embeddings + position_embeddings # (bs, max_seq_length, dim) | |
embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim) | |
embeddings = self.dropout(embeddings) # (bs, max_seq_length, dim) | |
return embeddings | |
class MultiHeadSelfAttention(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.n_heads = config.n_heads | |
self.dim = config.dim | |
self.dropout = nn.Dropout(p=config.attention_dropout) | |
assert self.dim % self.n_heads == 0 | |
self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim) | |
self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim) | |
self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim) | |
self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim) | |
self.pruned_heads = set() | |
def prune_heads(self, heads): | |
attention_head_size = self.dim // self.n_heads | |
if len(heads) == 0: | |
return | |
heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads) | |
# Prune linear layers | |
self.q_lin = prune_linear_layer(self.q_lin, index) | |
self.k_lin = prune_linear_layer(self.k_lin, index) | |
self.v_lin = prune_linear_layer(self.v_lin, index) | |
self.out_lin = prune_linear_layer(self.out_lin, index, dim=1) | |
# Update hyper params | |
self.n_heads = self.n_heads - len(heads) | |
self.dim = attention_head_size * self.n_heads | |
self.pruned_heads = self.pruned_heads.union(heads) | |
def forward( | |
self, | |
query, | |
key, | |
value, | |
mask, | |
head_mask=None, | |
encoder_hidden_states=None, | |
encoder_attention_mask=None, | |
output_attentions=False): | |
""" | |
Parameters | |
---------- | |
query: torch.tensor(bs, seq_length, dim) | |
key: torch.tensor(bs, seq_length, dim) | |
value: torch.tensor(bs, seq_length, dim) | |
mask: torch.tensor(bs, seq_length) | |
Outputs | |
------- | |
        weights: torch.tensor(bs, n_heads, seq_length, seq_length)
            Attention weights. Optional: only returned if `output_attentions=True`
        context: torch.tensor(bs, seq_length, dim)
            Contextualized layer
        """
bs, q_length, dim = query.size() | |
k_length = key.size(1) | |
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim) | |
# assert key.size() == value.size() | |
dim_per_head = self.dim // self.n_heads | |
mask_reshp = (bs, 1, 1, k_length) | |
def shape(x): | |
""" separate heads """ | |
return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) | |
def unshape(x): | |
""" group heads """ | |
return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) | |
q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) | |
if encoder_hidden_states is not None: | |
k = shape(self.k_lin(encoder_hidden_states)) # (bs, n_heads, k_length, dim_per_head) | |
v = shape(self.v_lin(encoder_hidden_states)) # (bs, n_heads, k_length, dim_per_head) | |
mask = encoder_attention_mask | |
else: | |
k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) | |
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) | |
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) | |
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) | |
# mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length) | |
# scores.masked_fill_(mask, -float("inf")) # (bs, n_heads, q_length, k_length) | |
if mask is not None: | |
scores = scores + mask | |
weights = nn.Softmax(dim=-1)(scores) # (bs, n_heads, q_length, k_length) | |
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length) | |
# Mask heads if we want to | |
if head_mask is not None: | |
weights = weights * head_mask | |
context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head) | |
context = unshape(context) # (bs, q_length, dim) | |
context = self.out_lin(context) # (bs, q_length, dim) | |
if output_attentions: | |
return (context, weights) | |
else: | |
return (context,) | |
class FFN(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.dropout = nn.Dropout(p=config.dropout) | |
self.chunk_size_feed_forward = config.chunk_size_feed_forward | |
self.seq_len_dim = 1 | |
self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim) | |
self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim) | |
assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format( | |
config.activation | |
) | |
self.activation = gelu if config.activation == "gelu" else nn.ReLU() | |
def forward(self, input): | |
return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input) | |
def ff_chunk(self, input): | |
x = self.lin1(input) | |
x = self.activation(x) | |
x = self.lin2(x) | |
x = self.dropout(x) | |
return x | |
class TransformerBlock(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
assert config.dim % config.n_heads == 0 | |
self.attention = MultiHeadSelfAttention(config) | |
self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) | |
self.is_decoder = config.is_decoder | |
self.add_cross_attention = config.add_cross_attention | |
if self.add_cross_attention: | |
assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added" | |
self.crossattention = MultiHeadSelfAttention(config) | |
self.crossattention_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) | |
self.ffn = FFN(config) | |
self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) | |
def forward( | |
self, | |
x, | |
attn_mask=None, | |
head_mask=None, | |
encoder_hidden_states=None, | |
encoder_attention_mask=None, | |
output_attentions=False | |
): | |
""" | |
Parameters | |
---------- | |
x: torch.tensor(bs, seq_length, dim) | |
attn_mask: torch.tensor(bs, seq_length) | |
Outputs | |
------- | |
sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) | |
The attention weights | |
ffn_output: torch.tensor(bs, seq_length, dim) | |
The output of the transformer block contextualization. | |
""" | |
# Self-Attention | |
sa_output = self.attention( | |
query=x, | |
key=x, | |
value=x, | |
mask=attn_mask, | |
head_mask=head_mask, | |
output_attentions=output_attentions, | |
) | |
attention_output = sa_output[0] | |
attention_output = self.sa_layer_norm(attention_output + x) # (bs, seq_length, dim) | |
outputs = sa_output[1:] | |
# if output_attentions: | |
# sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) | |
# else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples | |
# assert type(sa_output) == tuple | |
# sa_output = sa_output[0] | |
# sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim) | |
if self.is_decoder and encoder_hidden_states is not None: | |
assert hasattr( | |
self, "crossattention" | |
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" | |
cross_attention_outputs = self.crossattention( | |
attention_output, | |
attention_output, | |
attention_output, | |
attn_mask, | |
head_mask, | |
encoder_hidden_states, | |
encoder_attention_mask, | |
output_attentions, | |
) | |
            attention_output = self.crossattention_layer_norm(cross_attention_outputs[0] + attention_output)  # (bs, seq_length, dim)
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights | |
# Feed Forward Network | |
ffn_output = self.ffn(attention_output) # (bs, seq_length, dim) | |
ffn_output = self.output_layer_norm(ffn_output + attention_output) # (bs, seq_length, dim) | |
outputs = (ffn_output,) + outputs | |
# output = (ffn_output,) | |
# if output_attentions: | |
# output = (sa_weights,) + output | |
return outputs | |
class Transformer(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.n_layers = config.n_layers | |
layer = TransformerBlock(config) | |
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)]) | |
def forward( | |
self, | |
x, | |
attn_mask=None, | |
head_mask=None, | |
encoder_hidden_states=None, | |
encoder_attention_mask=None, | |
output_attentions=False, | |
output_hidden_states=False, | |
return_dict=None | |
): | |
""" | |
Parameters | |
---------- | |
x: torch.tensor(bs, seq_length, dim) | |
Input sequence embedded. | |
attn_mask: torch.tensor(bs, seq_length) | |
Attention mask on the sequence. | |
Outputs | |
------- | |
hidden_state: torch.tensor(bs, seq_length, dim) | |
Sequence of hiddens states in the last (top) layer | |
all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)] | |
Tuple of length n_layers with the hidden states from each layer. | |
Optional: only if output_hidden_states=True | |
all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)] | |
Tuple of length n_layers with the attention weights from each layer | |
Optional: only if output_attentions=True | |
""" | |
all_hidden_states = () if output_hidden_states else None | |
all_attentions = () if output_attentions else None | |
hidden_state = x | |
for i, layer_module in enumerate(self.layer): | |
if output_hidden_states: | |
all_hidden_states = all_hidden_states + (hidden_state,) | |
layer_outputs = layer_module( | |
x=hidden_state, | |
attn_mask=attn_mask, | |
head_mask=head_mask[i], | |
encoder_hidden_states=encoder_hidden_states, | |
encoder_attention_mask=encoder_attention_mask, | |
output_attentions=output_attentions | |
) | |
            hidden_state = layer_outputs[0]  # TransformerBlock returns the hidden state first: (ffn_output,) + attention outputs
            if output_attentions:
                # self-attention weights, plus cross-attention weights when cross-attention is used
                all_attentions = all_attentions + layer_outputs[1:]
# Add last layer | |
if output_hidden_states: | |
all_hidden_states = all_hidden_states + (hidden_state,) | |
if not return_dict: | |
return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None) | |
return BaseModelOutput( | |
last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions | |
) | |
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL # | |
class DistilBertPreTrainedModel(PreTrainedModel): | |
"""An abstract class to handle weights initialization and | |
a simple interface for downloading and loading pretrained models. | |
""" | |
config_class = DistilBertConfig | |
load_tf_weights = None | |
base_model_prefix = "distilbert" | |
def _init_weights(self, module): | |
"""Initialize the weights.""" | |
if isinstance(module, nn.Embedding): | |
if module.weight.requires_grad: | |
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) | |
if isinstance(module, nn.Linear): | |
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) | |
elif isinstance(module, nn.LayerNorm): | |
module.bias.data.zero_() | |
module.weight.data.fill_(1.0) | |
if isinstance(module, nn.Linear) and module.bias is not None: | |
module.bias.data.zero_() | |
DISTILBERT_START_DOCSTRING = r""" | |
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. | |
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general | |
usage and behavior. | |
Parameters: | |
config (:class:`~transformers.DistilBertConfig`): Model configuration class with all the parameters of the model. | |
Initializing with a config file does not load the weights associated with the model, only the configuration. | |
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. | |
""" | |
DISTILBERT_INPUTS_DOCSTRING = r""" | |
Args: | |
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | |
Indices of input sequence tokens in the vocabulary. | |
Indices can be obtained using :class:`transformers.DistilBertTokenizer`. | |
See :func:`transformers.PreTrainedTokenizer.encode` and | |
:func:`transformers.PreTrainedTokenizer.__call__` for details. | |
`What are input IDs? <../glossary.html#input-ids>`__ | |
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |
Mask to avoid performing attention on padding token indices. | |
Mask values selected in ``[0, 1]``: | |
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. | |
`What are attention masks? <../glossary.html#attention-mask>`__ | |
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): | |
Mask to nullify selected heads of the self-attention modules. | |
Mask values selected in ``[0, 1]``: | |
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**. | |
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): | |
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. | |
This is useful if you want more control over how to convert `input_ids` indices into associated vectors | |
than the model's internal embedding lookup matrix. | |
output_attentions (:obj:`bool`, `optional`): | |
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. | |
output_hidden_states (:obj:`bool`, `optional`): | |
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail. | |
return_dict (:obj:`bool`, `optional`): | |
If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a | |
plain tuple. | |
""" | |
class DistilBertModel(DistilBertPreTrainedModel): | |
def __init__(self, config): | |
super().__init__(config) | |
self.embeddings = Embeddings(config) # Embeddings | |
self.transformer = Transformer(config) # Encoder | |
self.init_weights() | |
def get_input_embeddings(self): | |
return self.embeddings.word_embeddings | |
def set_input_embeddings(self, new_embeddings): | |
self.embeddings.word_embeddings = new_embeddings | |
def _prune_heads(self, heads_to_prune): | |
"""Prunes heads of the model. | |
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} | |
See base class PreTrainedModel | |
""" | |
for layer, heads in heads_to_prune.items(): | |
self.transformer.layer[layer].attention.prune_heads(heads) | |
def forward( | |
self, | |
input_ids=None, | |
attention_mask=None, | |
head_mask=None, | |
inputs_embeds=None, | |
encoder_hidden_states=None, | |
encoder_attention_mask=None, | |
output_attentions=None, | |
output_hidden_states=None, | |
return_dict=None, | |
): | |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |
output_hidden_states = ( | |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | |
) | |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
if input_ids is not None and inputs_embeds is not None: | |
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") | |
elif input_ids is not None: | |
input_shape = input_ids.size() | |
elif inputs_embeds is not None: | |
input_shape = inputs_embeds.size()[:-1] | |
else: | |
raise ValueError("You have to specify either input_ids or inputs_embeds") | |
device = input_ids.device if input_ids is not None else inputs_embeds.device | |
if attention_mask is None: | |
attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length) | |
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves, in which case we just need to make it broadcastable to all heads.
        # The returned mask is additive (0.0 for positions to attend to, a large negative value otherwise);
        # when `config.is_decoder` is True it also includes the causal mask that prevents the model
        # from attending to future tokens.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) | |
# If a 2D or 3D attention mask is provided for the cross-attention | |
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] | |
if self.config.is_decoder and encoder_hidden_states is not None: | |
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() | |
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) | |
if encoder_attention_mask is None: | |
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) | |
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) | |
else: | |
encoder_extended_attention_mask = None | |
# Prepare head mask if needed | |
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) | |
if inputs_embeds is None: | |
inputs_embeds = self.embeddings(input_ids) # (bs, seq_length, dim) | |
return self.transformer( | |
x=inputs_embeds, | |
attn_mask=extended_attention_mask, | |
head_mask=head_mask, | |
encoder_hidden_states=encoder_hidden_states, | |
encoder_attention_mask=encoder_extended_attention_mask, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
return_dict=return_dict, | |
) | |
MODEL_MAPPING[DistilBertConfig] = DistilBertModel | |
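# Overriding the AutoModel entry so that `AutoModel` resolves DistilBertConfig to the DistilBertModel
# defined in this file (which accepts `encoder_hidden_states`) rather than the stock implementation.
# Note: plain item assignment assumes the dictionary-style mappings of the transformers release this
# fork was written against (mid-2021); later releases expose a `register` method on the auto classes
# and mappings instead.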
class DistilBertForMaskedLM(DistilBertPreTrainedModel): | |
def __init__(self, config): | |
super().__init__(config) | |
self.distilbert = DistilBertModel(config) | |
self.vocab_transform = nn.Linear(config.dim, config.dim) | |
self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12) | |
self.vocab_projector = nn.Linear(config.dim, config.vocab_size) | |
self.init_weights() | |
self.mlm_loss_fct = nn.CrossEntropyLoss() | |
def get_output_embeddings(self): | |
return self.vocab_projector | |
def forward( | |
self, | |
input_ids=None, | |
attention_mask=None, | |
head_mask=None, | |
inputs_embeds=None, | |
labels=None, | |
output_attentions=None, | |
output_hidden_states=None, | |
return_dict=None, | |
**kwargs | |
): | |
r""" | |
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |
Labels for computing the masked language modeling loss. | |
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) | |
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels | |
in ``[0, ..., config.vocab_size]`` | |
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): | |
Used to hide legacy arguments that have been deprecated. | |
""" | |
if "masked_lm_labels" in kwargs: | |
warnings.warn( | |
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", | |
FutureWarning, | |
) | |
labels = kwargs.pop("masked_lm_labels") | |
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." | |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
dlbrt_output = self.distilbert( | |
input_ids=input_ids, | |
attention_mask=attention_mask, | |
head_mask=head_mask, | |
inputs_embeds=inputs_embeds, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
return_dict=return_dict, | |
) | |
hidden_states = dlbrt_output[0] # (bs, seq_length, dim) | |
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim) | |
prediction_logits = gelu(prediction_logits) # (bs, seq_length, dim) | |
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim) | |
prediction_logits = self.vocab_projector(prediction_logits) # (bs, seq_length, vocab_size) | |
mlm_loss = None | |
if labels is not None: | |
mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1)) | |
if not return_dict: | |
output = (prediction_logits,) + dlbrt_output[1:] | |
return ((mlm_loss,) + output) if mlm_loss is not None else output | |
return MaskedLMOutput( | |
loss=mlm_loss, | |
logits=prediction_logits, | |
hidden_states=dlbrt_output.hidden_states, | |
attentions=dlbrt_output.attentions, | |
) | |
class DistilBertLMHeadModel(DistilBertPreTrainedModel): | |
def __init__(self, config): | |
super().__init__(config) | |
self.distilbert = DistilBertModel(config) | |
self.vocab_transform = nn.Linear(config.dim, config.dim) | |
self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12) | |
self.vocab_projector = nn.Linear(config.dim, config.vocab_size) | |
self.init_weights() | |
self.mlm_loss_fct = nn.CrossEntropyLoss() | |
def get_output_embeddings(self): | |
return self.vocab_projector | |
def forward( | |
self, | |
input_ids=None, | |
attention_mask=None, | |
head_mask=None, | |
inputs_embeds=None, | |
encoder_hidden_states=None, | |
encoder_attention_mask=None, | |
labels=None, | |
output_attentions=None, | |
output_hidden_states=None, | |
return_dict=None, | |
**kwargs | |
): | |
r""" | |
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |
Labels for computing the masked language modeling loss. | |
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) | |
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels | |
in ``[0, ..., config.vocab_size]`` | |
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): | |
Used to hide legacy arguments that have been deprecated. | |
""" | |
if "masked_lm_labels" in kwargs: | |
warnings.warn( | |
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", | |
FutureWarning, | |
) | |
labels = kwargs.pop("masked_lm_labels") | |
# assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." | |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
dlbrt_output = self.distilbert( | |
input_ids=input_ids, | |
attention_mask=attention_mask, | |
head_mask=head_mask, | |
inputs_embeds=inputs_embeds, | |
encoder_hidden_states=encoder_hidden_states, | |
encoder_attention_mask=encoder_attention_mask, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
return_dict=return_dict, | |
) | |
hidden_states = dlbrt_output[0] # (bs, seq_length, dim) | |
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim) | |
prediction_logits = gelu(prediction_logits) # (bs, seq_length, dim) | |
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim) | |
prediction_logits = self.vocab_projector(prediction_logits) # (bs, seq_length, vocab_size) | |
lm_loss = None | |
if labels is not None: | |
# we are doing next-token prediction; shift prediction scores and input ids by one | |
shifted_prediction_scores = prediction_logits[:, :-1, :].contiguous() | |
labels = labels[:, 1:].contiguous() | |
loss_fct = CrossEntropyLoss() | |
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) | |
if not return_dict: | |
output = (prediction_logits,) + dlbrt_output[1:] | |
return ((lm_loss,) + output) if lm_loss is not None else output | |
return CausalLMOutput( | |
loss=lm_loss, | |
logits=prediction_logits, | |
hidden_states=dlbrt_output.hidden_states, | |
attentions=dlbrt_output.attentions, | |
) | |
MODEL_FOR_CAUSAL_LM_MAPPING[DistilBertConfig] = DistilBertLMHeadModel | |
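# With the mapping above, `AutoModelForCausalLM` resolves DistilBertConfig to DistilBertLMHeadModel,
# which is what the TSDAE example relies on. A minimal sketch of building a stand-alone decoder
# (the settings below are illustrative assumptions, not part of this file):
#
#     from transformers import AutoConfig, AutoModelForCausalLM
#
#     config = AutoConfig.from_pretrained("distilbert-base-uncased")
#     config.is_decoder = True            # enables the causal mask in get_extended_attention_mask
#     config.add_cross_attention = True   # adds the cross-attention sub-layer in each TransformerBlock
#     decoder = AutoModelForCausalLM.from_config(config)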
class DistilBertForSequenceClassification(DistilBertPreTrainedModel): | |
def __init__(self, config): | |
super().__init__(config) | |
self.num_labels = config.num_labels | |
self.distilbert = DistilBertModel(config) | |
self.pre_classifier = nn.Linear(config.dim, config.dim) | |
self.classifier = nn.Linear(config.dim, config.num_labels) | |
self.dropout = nn.Dropout(config.seq_classif_dropout) | |
self.init_weights() | |
def forward( | |
self, | |
input_ids=None, | |
attention_mask=None, | |
head_mask=None, | |
inputs_embeds=None, | |
labels=None, | |
output_attentions=None, | |
output_hidden_states=None, | |
return_dict=None, | |
): | |
r""" | |
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | |
Labels for computing the sequence classification/regression loss. | |
Indices should be in :obj:`[0, ..., config.num_labels - 1]`. | |
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), | |
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). | |
""" | |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
distilbert_output = self.distilbert( | |
input_ids=input_ids, | |
attention_mask=attention_mask, | |
head_mask=head_mask, | |
inputs_embeds=inputs_embeds, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
return_dict=return_dict, | |
) | |
hidden_state = distilbert_output[0] # (bs, seq_len, dim) | |
pooled_output = hidden_state[:, 0] # (bs, dim) | |
pooled_output = self.pre_classifier(pooled_output) # (bs, dim) | |
pooled_output = nn.ReLU()(pooled_output) # (bs, dim) | |
pooled_output = self.dropout(pooled_output) # (bs, dim) | |
        logits = self.classifier(pooled_output)  # (bs, num_labels)
loss = None | |
if labels is not None: | |
if self.num_labels == 1: | |
loss_fct = nn.MSELoss() | |
loss = loss_fct(logits.view(-1), labels.view(-1)) | |
else: | |
loss_fct = nn.CrossEntropyLoss() | |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) | |
if not return_dict: | |
output = (logits,) + distilbert_output[1:] | |
return ((loss,) + output) if loss is not None else output | |
return SequenceClassifierOutput( | |
loss=loss, | |
logits=logits, | |
hidden_states=distilbert_output.hidden_states, | |
attentions=distilbert_output.attentions, | |
) | |
class DistilBertForQuestionAnswering(DistilBertPreTrainedModel): | |
def __init__(self, config): | |
super().__init__(config) | |
self.distilbert = DistilBertModel(config) | |
self.qa_outputs = nn.Linear(config.dim, config.num_labels) | |
assert config.num_labels == 2 | |
self.dropout = nn.Dropout(config.qa_dropout) | |
self.init_weights() | |
def forward( | |
self, | |
input_ids=None, | |
attention_mask=None, | |
head_mask=None, | |
inputs_embeds=None, | |
start_positions=None, | |
end_positions=None, | |
output_attentions=None, | |
output_hidden_states=None, | |
return_dict=None, | |
): | |
r""" | |
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | |
Labels for position (index) of the start of the labelled span for computing the token classification loss. | |
Positions are clamped to the length of the sequence (`sequence_length`). | |
Position outside of the sequence are not taken into account for computing the loss. | |
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | |
Labels for position (index) of the end of the labelled span for computing the token classification loss. | |
Positions are clamped to the length of the sequence (`sequence_length`). | |
Position outside of the sequence are not taken into account for computing the loss. | |
""" | |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
distilbert_output = self.distilbert( | |
input_ids=input_ids, | |
attention_mask=attention_mask, | |
head_mask=head_mask, | |
inputs_embeds=inputs_embeds, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
return_dict=return_dict, | |
) | |
hidden_states = distilbert_output[0] # (bs, max_query_len, dim) | |
hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim) | |
logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2) | |
start_logits, end_logits = logits.split(1, dim=-1) | |
start_logits = start_logits.squeeze(-1) # (bs, max_query_len) | |
end_logits = end_logits.squeeze(-1) # (bs, max_query_len) | |
total_loss = None | |
if start_positions is not None and end_positions is not None: | |
# If we are on multi-GPU, split add a dimension | |
if len(start_positions.size()) > 1: | |
start_positions = start_positions.squeeze(-1) | |
if len(end_positions.size()) > 1: | |
end_positions = end_positions.squeeze(-1) | |
# sometimes the start/end positions are outside our model inputs, we ignore these terms | |
ignored_index = start_logits.size(1) | |
start_positions.clamp_(0, ignored_index) | |
end_positions.clamp_(0, ignored_index) | |
loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) | |
start_loss = loss_fct(start_logits, start_positions) | |
end_loss = loss_fct(end_logits, end_positions) | |
total_loss = (start_loss + end_loss) / 2 | |
if not return_dict: | |
output = (start_logits, end_logits) + distilbert_output[1:] | |
return ((total_loss,) + output) if total_loss is not None else output | |
return QuestionAnsweringModelOutput( | |
loss=total_loss, | |
start_logits=start_logits, | |
end_logits=end_logits, | |
hidden_states=distilbert_output.hidden_states, | |
attentions=distilbert_output.attentions, | |
) | |
class DistilBertForTokenClassification(DistilBertPreTrainedModel): | |
def __init__(self, config): | |
super().__init__(config) | |
self.num_labels = config.num_labels | |
self.distilbert = DistilBertModel(config) | |
self.dropout = nn.Dropout(config.dropout) | |
self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |
self.init_weights() | |
def forward( | |
self, | |
input_ids=None, | |
attention_mask=None, | |
head_mask=None, | |
inputs_embeds=None, | |
labels=None, | |
output_attentions=None, | |
output_hidden_states=None, | |
return_dict=None, | |
): | |
r""" | |
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | |
Labels for computing the token classification loss. | |
Indices should be in ``[0, ..., config.num_labels - 1]``. | |
""" | |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
outputs = self.distilbert( | |
input_ids, | |
attention_mask=attention_mask, | |
head_mask=head_mask, | |
inputs_embeds=inputs_embeds, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
return_dict=return_dict, | |
) | |
sequence_output = outputs[0] | |
sequence_output = self.dropout(sequence_output) | |
logits = self.classifier(sequence_output) | |
loss = None | |
if labels is not None: | |
loss_fct = CrossEntropyLoss() | |
# Only keep active parts of the loss | |
if attention_mask is not None: | |
active_loss = attention_mask.view(-1) == 1 | |
active_logits = logits.view(-1, self.num_labels) | |
active_labels = torch.where( | |
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) | |
) | |
loss = loss_fct(active_logits, active_labels) | |
else: | |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) | |
if not return_dict: | |
output = (logits,) + outputs[1:] | |
return ((loss,) + output) if loss is not None else output | |
return TokenClassifierOutput( | |
loss=loss, | |
logits=logits, | |
hidden_states=outputs.hidden_states, | |
attentions=outputs.attentions, | |
) | |
class DistilBertForMultipleChoice(DistilBertPreTrainedModel): | |
def __init__(self, config): | |
super().__init__(config) | |
self.distilbert = DistilBertModel(config) | |
self.pre_classifier = nn.Linear(config.dim, config.dim) | |
self.classifier = nn.Linear(config.dim, 1) | |
self.dropout = nn.Dropout(config.seq_classif_dropout) | |
self.init_weights() | |
def forward( | |
self, | |
input_ids=None, | |
attention_mask=None, | |
head_mask=None, | |
inputs_embeds=None, | |
labels=None, | |
output_attentions=None, | |
output_hidden_states=None, | |
return_dict=None, | |
): | |
r""" | |
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | |
Labels for computing the multiple choice classification loss. | |
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension | |
of the input tensors. (see `input_ids` above) | |
Returns: | |
Examples:: | |
>>> from transformers import DistilBertTokenizer, DistilBertForMultipleChoice | |
>>> import torch | |
>>> tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased') | |
>>> model = DistilBertForMultipleChoice.from_pretrained('distilbert-base-cased', return_dict=True) | |
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." | |
>>> choice0 = "It is eaten with a fork and a knife." | |
>>> choice1 = "It is eaten while held in the hand." | |
>>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1 | |
>>> encoding = tokenizer([[prompt, choice0], [prompt, choice1]], return_tensors='pt', padding=True) | |
>>> outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1 | |
>>> # the linear classifier still needs to be trained | |
>>> loss = outputs.loss | |
>>> logits = outputs.logits | |
""" | |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] | |
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None | |
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None | |
inputs_embeds = ( | |
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) | |
if inputs_embeds is not None | |
else None | |
) | |
outputs = self.distilbert( | |
input_ids, | |
attention_mask=attention_mask, | |
head_mask=head_mask, | |
inputs_embeds=inputs_embeds, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
return_dict=return_dict, | |
) | |
hidden_state = outputs[0] # (bs * num_choices, seq_len, dim) | |
pooled_output = hidden_state[:, 0] # (bs * num_choices, dim) | |
pooled_output = self.pre_classifier(pooled_output) # (bs * num_choices, dim) | |
pooled_output = nn.ReLU()(pooled_output) # (bs * num_choices, dim) | |
pooled_output = self.dropout(pooled_output) # (bs * num_choices, dim) | |
logits = self.classifier(pooled_output) # (bs * num_choices, 1) | |
reshaped_logits = logits.view(-1, num_choices) # (bs, num_choices) | |
loss = None | |
if labels is not None: | |
loss_fct = CrossEntropyLoss() | |
loss = loss_fct(reshaped_logits, labels) | |
if not return_dict: | |
output = (reshaped_logits,) + outputs[1:] | |
return ((loss,) + output) if loss is not None else output | |
return MultipleChoiceModelOutput( | |
loss=loss, | |
logits=reshaped_logits, | |
hidden_states=outputs.hidden_states, | |
attentions=outputs.attentions, | |
) |
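A minimal end-to-end sketch of TSDAE training with this module, under a few assumptions: the file is saved as modeling_distilbert.py next to the training script, a sentence-transformers version matching the linked TSDAE example is installed, and train_sentences is placeholder data.

# Importing this module (for its side effect) registers DistilBertLMHeadModel as the
# causal-LM class for DistilBertConfig, so the TSDAE decoder can be a DistilBERT checkpoint.
import modeling_distilbert  # noqa: F401

from sentence_transformers import SentenceTransformer, models, losses, datasets
from torch.utils.data import DataLoader

model_name = "distilbert-base-uncased"
train_sentences = ["a list of raw sentences goes here", "..."]  # placeholder data

word_embedding_model = models.Transformer(model_name)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), "cls")
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# DenoisingAutoEncoderDataset adds the deletion noise; DenoisingAutoEncoderLoss builds the
# decoder via AutoModelForCausalLM, which now resolves to DistilBertLMHeadModel.
train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    weight_decay=0,
    scheduler="constantlr",
    optimizer_params={"lr": 3e-5},
    show_progress_bar=True,
)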