ALBERT can load pre-trained models. Doesn't inherit from BERT anymore.

This commit is contained in:
Lysandre
2019-10-31 16:37:34 +00:00
committed by Lysandre Debut
parent c4403006b8
commit 4f3a54bfc8
4 changed files with 68 additions and 12 deletions

View File

@@ -21,6 +21,7 @@ import logging
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_albert import AlbertConfig
from transformers.modeling_bert import BertEmbeddings, BertPreTrainedModel, BertModel, BertSelfAttention, prune_linear_layer, ACT2FN
from .file_utils import add_start_docstrings
@@ -274,6 +275,29 @@ class AlbertTransformer(nn.Module):
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class AlbertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class = AlbertConfig
pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "albert"
def _init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, (nn.Linear)) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
ALBERT_START_DOCSTRING = r""" The ALBERT model was proposed in
`ALBERT: A Lite BERT for Self-supervised Learning of Language Representations`_
by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut. It presents
@@ -338,7 +362,7 @@ ALBERT_INPUTS_DOCSTRING = r"""
@add_start_docstrings("The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
class AlbertModel(BertModel):
class AlbertModel(AlbertPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
@@ -358,6 +382,12 @@ class AlbertModel(BertModel):
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
"""
config_class = AlbertConfig
pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_albert
base_model_prefix = "albert"
def __init__(self, config):
super(AlbertModel, self).__init__(config)
@@ -369,6 +399,11 @@ class AlbertModel(BertModel):
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.embeddings.word_embeddings
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.embeddings.word_embeddings = new_embeddings
return self.embeddings.word_embeddings
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
if attention_mask is None:
@@ -423,7 +458,7 @@ class AlbertMLMHead(nn.Module):
@add_start_docstrings("Bert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
class AlbertForMaskedLM(BertPreTrainedModel):
class AlbertForMaskedLM(AlbertPreTrainedModel):
r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the masked language modeling loss.
@@ -445,11 +480,6 @@ class AlbertForMaskedLM(BertPreTrainedModel):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
"""
config_class = AlbertConfig
pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_albert
base_model_prefix = "albert"
def __init__(self, config):
super(AlbertForMaskedLM, self).__init__(config)