fix SqueezeBertForMaskedLM (#8479)
This commit is contained in:
@@ -22,7 +22,7 @@ The authors found that SqueezeBERT is 4.3x faster than `bert-base-uncased` on a
|
||||
The model is pretrained using the Masked Language Model (MLM) and Sentence Order Prediction (SOP) tasks.
|
||||
(Author's note: If you decide to pretrain your own model, and you prefer to train with MLM only, that should work too.)
|
||||
|
||||
The SqueezeBERT paper presents 2 approaches to finetuning the model:
|
||||
From the SqueezeBERT paper:
|
||||
> We pretrain SqueezeBERT from scratch (without distillation) using the [LAMB](https://arxiv.org/abs/1904.00962) optimizer, and we employ the hyperparameters recommended by the LAMB authors: a global batch size of 8192, a learning rate of 2.5e-3, and a warmup proportion of 0.28. Following the LAMB paper's recommendations, we pretrain for 56k steps with a maximum sequence length of 128 and then for 6k steps with a maximum sequence length of 512.
|
||||
|
||||
## Finetuning
|
||||
|
||||
@@ -373,6 +373,53 @@ class SqueezeBertPooler(nn.Module):
|
||||
return pooled_output
|
||||
|
||||
|
||||
class SqueezeBertPredictionHeadTransform(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.transform_act_fn = config.hidden_act
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.transform_act_fn(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SqueezeBertLMPredictionHead(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.transform = SqueezeBertPredictionHeadTransform(config)
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SqueezeBertOnlyMLMHead(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.predictions = SqueezeBertLMPredictionHead(config)
|
||||
|
||||
def forward(self, sequence_output):
|
||||
prediction_scores = self.predictions(sequence_output)
|
||||
return prediction_scores
|
||||
|
||||
|
||||
class SqueezeBertPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
@@ -594,16 +641,19 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING)
|
||||
class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
|
||||
|
||||
authorized_missing_keys = [r"predictions.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.transformer = SqueezeBertModel(config)
|
||||
self.lm_head = nn.Linear(config.embedding_size, config.vocab_size)
|
||||
self.cls = SqueezeBertOnlyMLMHead(config)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
@add_code_sample_docstrings(
|
||||
@@ -646,7 +696,7 @@ class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.lm_head(sequence_output)
|
||||
prediction_scores = self.cls(sequence_output)
|
||||
|
||||
masked_lm_loss = None
|
||||
if labels is not None:
|
||||
|
||||
Reference in New Issue
Block a user