Tf model outputs (#6247)
* TF outputs and test on BERT * Albert to DistilBert * All remaining TF models except T5 * Documentation * One file forgotten * TF outputs and test on BERT * Albert to DistilBert * All remaining TF models except T5 * Documentation * One file forgotten * Add new models and fix issues * Quality improvements * Add T5 * A bit of cleanup * Fix for slow tests * Style
This commit is contained in:
@@ -685,9 +685,9 @@ class MobileBertPreTrainedModel(PreTrainedModel):
|
||||
|
||||
|
||||
@dataclass
|
||||
class MobileBertForPretrainingOutput(ModelOutput):
|
||||
class MobileBertForPreTrainingOutput(ModelOutput):
|
||||
"""
|
||||
Output type of :class:`~transformers.MobileBertForPretrainingModel`.
|
||||
Output type of :class:`~transformers.MobileBertForPreTrainingModel`.
|
||||
|
||||
Args:
|
||||
loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
||||
@@ -948,7 +948,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||
|
||||
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=MobileBertForPretrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -1018,7 +1018,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
output = (prediction_scores, seq_relationship_score) + outputs[2:]
|
||||
return ((total_loss,) + output) if total_loss is not None else output
|
||||
|
||||
return MobileBertForPretrainingOutput(
|
||||
return MobileBertForPreTrainingOutput(
|
||||
loss=total_loss,
|
||||
prediction_logits=prediction_scores,
|
||||
seq_relationship_logits=seq_relationship_score,
|
||||
|
||||
Reference in New Issue
Block a user