TF model outputs (#6247)

* TF outputs and test on BERT

* Albert to DistilBert

* All remaining TF models except T5

* Documentation

* One file forgotten

* TF outputs and test on BERT

* Albert to DistilBert

* All remaining TF models except T5

* Documentation

* One file forgotten

* Add new models and fix issues

* Quality improvements

* Add T5

* A bit of cleanup

* Fix for slow tests

* Style
This commit is contained in:
Sylvain Gugger
2020-08-05 11:34:39 -04:00
committed by GitHub
parent bd0eab351a
commit c67d1a0259
51 changed files with 3253 additions and 2430 deletions

View File

@@ -685,9 +685,9 @@ class MobileBertPreTrainedModel(PreTrainedModel):
@dataclass
class MobileBertForPretrainingOutput(ModelOutput):
class MobileBertForPreTrainingOutput(ModelOutput):
"""
Output type of :class:`~transformers.MobileBertForPretrainingModel`.
Output type of :class:`~transformers.MobileBertForPreTrainingModel`.
Args:
loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
@@ -948,7 +948,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MobileBertForPretrainingOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@@ -1018,7 +1018,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
output = (prediction_scores, seq_relationship_score) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return MobileBertForPretrainingOutput(
return MobileBertForPreTrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,