Add next sentence prediction loss computation (#8462)

* Add next sentence prediction loss computation

* Apply style

* Fix tests

* Add forgotten import

* Add forgotten import

* Use a new parameter

* Remove kwargs and use positional arguments
This commit is contained in:
Julien Plu
2020-11-11 15:02:06 +01:00
committed by GitHub
parent 23290836c3
commit da842e4e72
5 changed files with 121 additions and 14 deletions

View File

@@ -307,6 +307,8 @@ class TFNextSentencePredictorOutput(ModelOutput):
Base class for outputs of models predicting if two sentences are consecutive or not.
Args:
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
Next sentence prediction loss.
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax).
@@ -323,6 +325,7 @@ class TFNextSentencePredictorOutput(ModelOutput):
heads.
"""
loss: tf.Tensor = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None