Add next sentence prediction loss computation (#8462)
* Add next sentence prediction loss computation * Apply style * Fix tests * Add forgotten import * Add forgotten import * Use a new parameter * Remove kwargs and use positional arguments
This commit is contained in:
@@ -46,6 +46,7 @@ from .modeling_tf_utils import (
|
||||
TFCausalLanguageModelingLoss,
|
||||
TFMaskedLanguageModelingLoss,
|
||||
TFMultipleChoiceLoss,
|
||||
TFNextSentencePredictionLoss,
|
||||
TFPreTrainedModel,
|
||||
TFQuestionAnsweringLoss,
|
||||
TFSequenceClassificationLoss,
|
||||
@@ -1036,7 +1037,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
|
||||
class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
@@ -1045,7 +1046,20 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(self, inputs, **kwargs):
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
next_sentence_label=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
|
||||
@@ -1064,17 +1078,43 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
|
||||
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
|
||||
>>> assert logits[0][0] < logits[0][1] # the next sentence was random
|
||||
"""
|
||||
return_dict = kwargs.get("return_dict")
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
outputs = self.bert(inputs, **kwargs)
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
|
||||
|
||||
outputs = self.bert(
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
pooled_output = outputs[1]
|
||||
seq_relationship_score = self.nsp(pooled_output)
|
||||
seq_relationship_scores = self.nsp(pooled_output)
|
||||
|
||||
next_sentence_loss = (
|
||||
None
|
||||
if next_sentence_label is None
|
||||
else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (seq_relationship_score,) + outputs[2:]
|
||||
output = (seq_relationship_scores,) + outputs[2:]
|
||||
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
|
||||
|
||||
return TFNextSentencePredictorOutput(
|
||||
logits=seq_relationship_score,
|
||||
loss=next_sentence_loss,
|
||||
logits=seq_relationship_scores,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
@@ -44,6 +44,7 @@ from .modeling_tf_outputs import (
|
||||
from .modeling_tf_utils import (
|
||||
TFMaskedLanguageModelingLoss,
|
||||
TFMultipleChoiceLoss,
|
||||
TFNextSentencePredictionLoss,
|
||||
TFPreTrainedModel,
|
||||
TFQuestionAnsweringLoss,
|
||||
TFSequenceClassificationLoss,
|
||||
@@ -1119,7 +1120,7 @@ class TFMobileBertOnlyNSPHead(tf.keras.layers.Layer):
|
||||
"""MobileBert Model with a `next sentence prediction (classification)` head on top. """,
|
||||
MOBILEBERT_START_DOCSTRING,
|
||||
)
|
||||
class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
|
||||
class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextSentencePredictionLoss):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
@@ -1128,7 +1129,20 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(self, inputs, **kwargs):
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
next_sentence_label=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
|
||||
@@ -1146,18 +1160,44 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
|
||||
|
||||
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
|
||||
"""
|
||||
return_dict = kwargs.get("return_dict")
|
||||
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
|
||||
outputs = self.mobilebert(inputs, **kwargs)
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
|
||||
|
||||
outputs = self.mobilebert(
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
seq_relationship_score = self.cls(pooled_output)
|
||||
seq_relationship_scores = self.cls(pooled_output)
|
||||
|
||||
next_sentence_loss = (
|
||||
None
|
||||
if next_sentence_label is None
|
||||
else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (seq_relationship_score,) + outputs[2:]
|
||||
output = (seq_relationship_scores,) + outputs[2:]
|
||||
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
|
||||
|
||||
return TFNextSentencePredictorOutput(
|
||||
logits=seq_relationship_score,
|
||||
loss=next_sentence_loss,
|
||||
logits=seq_relationship_scores,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
@@ -307,6 +307,8 @@ class TFNextSentencePredictorOutput(ModelOutput):
|
||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||
|
||||
Args:
|
||||
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
|
||||
Next sentence prediction loss.
|
||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
|
||||
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
||||
before SoftMax).
|
||||
@@ -323,6 +325,7 @@ class TFNextSentencePredictorOutput(ModelOutput):
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: tf.Tensor = None
|
||||
logits: tf.Tensor = None
|
||||
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
||||
@@ -215,6 +215,27 @@ class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
|
||||
"""
|
||||
|
||||
|
||||
class TFNextSentencePredictionLoss:
|
||||
"""
|
||||
Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence.
|
||||
|
||||
.. note::
|
||||
Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
|
||||
"""
|
||||
|
||||
def compute_loss(self, labels, logits):
|
||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||
)
|
||||
# make sure only labels that are not equal to -100
|
||||
# are taken into account as loss
|
||||
next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
|
||||
next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
|
||||
next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
|
||||
|
||||
return loss_fn(next_sentence_label, next_sentence_reduced_logits)
|
||||
|
||||
|
||||
def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
|
||||
"""
|
||||
Detect missing and unexpected layers.
|
||||
|
||||
@@ -35,6 +35,7 @@ if is_tf_available():
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
@@ -95,6 +96,8 @@ class TFModelTesterMixin:
|
||||
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
|
||||
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values():
|
||||
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in [
|
||||
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
||||
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
||||
|
||||
Reference in New Issue
Block a user