fix SqueezeBertForMaskedLM (#8479)
This commit is contained in:
@@ -22,7 +22,7 @@ The authors found that SqueezeBERT is 4.3x faster than `bert-base-uncased` on a
|
|||||||
The model is pretrained using the Masked Language Model (MLM) and Sentence Order Prediction (SOP) tasks.
|
The model is pretrained using the Masked Language Model (MLM) and Sentence Order Prediction (SOP) tasks.
|
||||||
(Author's note: If you decide to pretrain your own model, and you prefer to train with MLM only, that should work too.)
|
(Author's note: If you decide to pretrain your own model, and you prefer to train with MLM only, that should work too.)
|
||||||
|
|
||||||
The SqueezeBERT paper presents 2 approaches to finetuning the model:
|
From the SqueezeBERT paper:
|
||||||
> We pretrain SqueezeBERT from scratch (without distillation) using the [LAMB](https://arxiv.org/abs/1904.00962) optimizer, and we employ the hyperparameters recommended by the LAMB authors: a global batch size of 8192, a learning rate of 2.5e-3, and a warmup proportion of 0.28. Following the LAMB paper's recommendations, we pretrain for 56k steps with a maximum sequence length of 128 and then for 6k steps with a maximum sequence length of 512.
|
> We pretrain SqueezeBERT from scratch (without distillation) using the [LAMB](https://arxiv.org/abs/1904.00962) optimizer, and we employ the hyperparameters recommended by the LAMB authors: a global batch size of 8192, a learning rate of 2.5e-3, and a warmup proportion of 0.28. Following the LAMB paper's recommendations, we pretrain for 56k steps with a maximum sequence length of 128 and then for 6k steps with a maximum sequence length of 512.
|
||||||
|
|
||||||
## Finetuning
|
## Finetuning
|
||||||
|
|||||||
@@ -373,6 +373,53 @@ class SqueezeBertPooler(nn.Module):
|
|||||||
return pooled_output
|
return pooled_output
|
||||||
|
|
||||||
|
|
||||||
|
class SqueezeBertPredictionHeadTransform(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
if isinstance(config.hidden_act, str):
|
||||||
|
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||||
|
else:
|
||||||
|
self.transform_act_fn = config.hidden_act
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class SqueezeBertLMPredictionHead(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.transform = SqueezeBertPredictionHeadTransform(config)
|
||||||
|
|
||||||
|
# The output weights are the same as the input embeddings, but there is
|
||||||
|
# an output-only bias for each token.
|
||||||
|
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||||
|
|
||||||
|
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||||
|
self.decoder.bias = self.bias
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.transform(hidden_states)
|
||||||
|
hidden_states = self.decoder(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class SqueezeBertOnlyMLMHead(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.predictions = SqueezeBertLMPredictionHead(config)
|
||||||
|
|
||||||
|
def forward(self, sequence_output):
|
||||||
|
prediction_scores = self.predictions(sequence_output)
|
||||||
|
return prediction_scores
|
||||||
|
|
||||||
|
|
||||||
class SqueezeBertPreTrainedModel(PreTrainedModel):
|
class SqueezeBertPreTrainedModel(PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
@@ -594,16 +641,19 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING)
|
@add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING)
|
||||||
class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
|
class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
|
||||||
|
|
||||||
|
authorized_missing_keys = [r"predictions.decoder.bias"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.transformer = SqueezeBertModel(config)
|
self.transformer = SqueezeBertModel(config)
|
||||||
self.lm_head = nn.Linear(config.embedding_size, config.vocab_size)
|
self.cls = SqueezeBertOnlyMLMHead(config)
|
||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
@add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
@@ -646,7 +696,7 @@ class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
prediction_scores = self.lm_head(sequence_output)
|
prediction_scores = self.cls(sequence_output)
|
||||||
|
|
||||||
masked_lm_loss = None
|
masked_lm_loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user