fix SqueezeBertForMaskedLM (#8479)

This commit is contained in:
Forrest Iandola
2020-11-12 09:19:37 -08:00
committed by GitHub
parent 7933054638
commit 0fa0349883
2 changed files with 54 additions and 4 deletions

View File

@@ -22,7 +22,7 @@ The authors found that SqueezeBERT is 4.3x faster than `bert-base-uncased` on a
The model is pretrained using the Masked Language Model (MLM) and Sentence Order Prediction (SOP) tasks. The model is pretrained using the Masked Language Model (MLM) and Sentence Order Prediction (SOP) tasks.
(Author's note: If you decide to pretrain your own model, and you prefer to train with MLM only, that should work too.) (Author's note: If you decide to pretrain your own model, and you prefer to train with MLM only, that should work too.)
The SqueezeBERT paper presents 2 approaches to finetuning the model: From the SqueezeBERT paper:
> We pretrain SqueezeBERT from scratch (without distillation) using the [LAMB](https://arxiv.org/abs/1904.00962) optimizer, and we employ the hyperparameters recommended by the LAMB authors: a global batch size of 8192, a learning rate of 2.5e-3, and a warmup proportion of 0.28. Following the LAMB paper's recommendations, we pretrain for 56k steps with a maximum sequence length of 128 and then for 6k steps with a maximum sequence length of 512. > We pretrain SqueezeBERT from scratch (without distillation) using the [LAMB](https://arxiv.org/abs/1904.00962) optimizer, and we employ the hyperparameters recommended by the LAMB authors: a global batch size of 8192, a learning rate of 2.5e-3, and a warmup proportion of 0.28. Following the LAMB paper's recommendations, we pretrain for 56k steps with a maximum sequence length of 128 and then for 6k steps with a maximum sequence length of 512.
## Finetuning ## Finetuning

View File

@@ -373,6 +373,53 @@ class SqueezeBertPooler(nn.Module):
return pooled_output return pooled_output
class SqueezeBertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class SqueezeBertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = SqueezeBertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class SqueezeBertOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = SqueezeBertLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class SqueezeBertPreTrainedModel(PreTrainedModel): class SqueezeBertPreTrainedModel(PreTrainedModel):
""" """
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -594,16 +641,19 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel):
@add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING) @add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING)
class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel): class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
authorized_missing_keys = [r"predictions.decoder.bias"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.transformer = SqueezeBertModel(config) self.transformer = SqueezeBertModel(config)
self.lm_head = nn.Linear(config.embedding_size, config.vocab_size) self.cls = SqueezeBertOnlyMLMHead(config)
self.init_weights() self.init_weights()
def get_output_embeddings(self): def get_output_embeddings(self):
return self.lm_head return self.cls.predictions.decoder
@add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
@@ -646,7 +696,7 @@ class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output) prediction_scores = self.cls(sequence_output)
masked_lm_loss = None masked_lm_loss = None
if labels is not None: if labels is not None: