Fix naming (#10095)
This commit is contained in:
@@ -1105,7 +1105,6 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
|
|||||||
_keys_to_ignore_on_load_unexpected = [
|
_keys_to_ignore_on_load_unexpected = [
|
||||||
r"pooler",
|
r"pooler",
|
||||||
r"seq_relationship___cls",
|
r"seq_relationship___cls",
|
||||||
r"predictions___cls",
|
|
||||||
r"cls.seq_relationship",
|
r"cls.seq_relationship",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -1113,10 +1112,10 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
|
|||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
|
||||||
self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
|
self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
|
||||||
self.mlm = TFMobileBertMLMHead(config, name="mlm___cls")
|
self.predictions = TFMobileBertMLMHead(config, name="predictions___cls")
|
||||||
|
|
||||||
def get_lm_head(self):
|
def get_lm_head(self):
|
||||||
return self.mlm.predictions
|
return self.predictions.predictions
|
||||||
|
|
||||||
def get_prefix_bias_name(self):
|
def get_prefix_bias_name(self):
|
||||||
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
|
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
|
||||||
@@ -1179,7 +1178,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
|
|||||||
training=inputs["training"],
|
training=inputs["training"],
|
||||||
)
|
)
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
|
prediction_scores = self.predictions(sequence_output, training=inputs["training"])
|
||||||
|
|
||||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
|
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user