From c6d5e56595c0c0134c744f9e38ab2dedb6707388 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Tue, 9 Feb 2021 12:10:31 +0100 Subject: [PATCH] Fix naming (#10095) --- .../models/mobilebert/modeling_tf_mobilebert.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py index 537bc20632..ea44b6225b 100644 --- a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py @@ -1105,7 +1105,6 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel _keys_to_ignore_on_load_unexpected = [ r"pooler", r"seq_relationship___cls", - r"predictions___cls", r"cls.seq_relationship", ] @@ -1113,10 +1112,10 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel super().__init__(config, *inputs, **kwargs) self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert") - self.mlm = TFMobileBertMLMHead(config, name="mlm___cls") + self.predictions = TFMobileBertMLMHead(config, name="predictions___cls") def get_lm_head(self): - return self.mlm.predictions + return self.predictions.predictions def get_prefix_bias_name(self): warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) @@ -1179,7 +1178,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel training=inputs["training"], ) sequence_output = outputs[0] - prediction_scores = self.mlm(sequence_output, training=inputs["training"]) + prediction_scores = self.predictions(sequence_output, training=inputs["training"]) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)