Add next sentence prediction loss computation (#8462)
* Add next sentence prediction loss computation * Apply style * Fix tests * Add forgotten import * Add forgotten import * Use a new parameter * Remove kwargs and use positional arguments
This commit is contained in:
@@ -307,6 +307,8 @@ class TFNextSentencePredictorOutput(ModelOutput):
|
||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||
|
||||
Args:
|
||||
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
|
||||
Next sentence prediction loss.
|
||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
|
||||
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
||||
before SoftMax).
|
||||
@@ -323,6 +325,7 @@ class TFNextSentencePredictorOutput(ModelOutput):
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: tf.Tensor = None
|
||||
logits: tf.Tensor = None
|
||||
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
||||
Reference in New Issue
Block a user