Add next sentence prediction loss computation (#8462)
* Add next sentence prediction loss computation * Apply style * Fix tests * Add forgotten import * Add forgotten import * Use a new parameter * Remove kwargs and use positional arguments
This commit is contained in:
@@ -46,6 +46,7 @@ from .modeling_tf_utils import (
|
|||||||
TFCausalLanguageModelingLoss,
|
TFCausalLanguageModelingLoss,
|
||||||
TFMaskedLanguageModelingLoss,
|
TFMaskedLanguageModelingLoss,
|
||||||
TFMultipleChoiceLoss,
|
TFMultipleChoiceLoss,
|
||||||
|
TFNextSentencePredictionLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFQuestionAnsweringLoss,
|
TFQuestionAnsweringLoss,
|
||||||
TFSequenceClassificationLoss,
|
TFSequenceClassificationLoss,
|
||||||
@@ -1036,7 +1037,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
|
class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
|
||||||
@@ -1045,7 +1046,20 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
inputs=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
next_sentence_label=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
Return:
|
Return:
|
||||||
|
|
||||||
@@ -1064,17 +1078,43 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
|
|||||||
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
|
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
|
||||||
>>> assert logits[0][0] < logits[0][1] # the next sentence was random
|
>>> assert logits[0][0] < logits[0][1] # the next sentence was random
|
||||||
"""
|
"""
|
||||||
return_dict = kwargs.get("return_dict")
|
|
||||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||||
outputs = self.bert(inputs, **kwargs)
|
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
|
||||||
|
if len(inputs) > 9:
|
||||||
|
inputs = inputs[:9]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
|
||||||
|
|
||||||
|
outputs = self.bert(
|
||||||
|
inputs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
seq_relationship_score = self.nsp(pooled_output)
|
seq_relationship_scores = self.nsp(pooled_output)
|
||||||
|
|
||||||
|
next_sentence_loss = (
|
||||||
|
None
|
||||||
|
if next_sentence_label is None
|
||||||
|
else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
|
||||||
|
)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (seq_relationship_score,) + outputs[2:]
|
output = (seq_relationship_scores,) + outputs[2:]
|
||||||
|
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
|
||||||
|
|
||||||
return TFNextSentencePredictorOutput(
|
return TFNextSentencePredictorOutput(
|
||||||
logits=seq_relationship_score,
|
loss=next_sentence_loss,
|
||||||
|
logits=seq_relationship_scores,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ from .modeling_tf_outputs import (
|
|||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
TFMaskedLanguageModelingLoss,
|
TFMaskedLanguageModelingLoss,
|
||||||
TFMultipleChoiceLoss,
|
TFMultipleChoiceLoss,
|
||||||
|
TFNextSentencePredictionLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFQuestionAnsweringLoss,
|
TFQuestionAnsweringLoss,
|
||||||
TFSequenceClassificationLoss,
|
TFSequenceClassificationLoss,
|
||||||
@@ -1119,7 +1120,7 @@ class TFMobileBertOnlyNSPHead(tf.keras.layers.Layer):
|
|||||||
"""MobileBert Model with a `next sentence prediction (classification)` head on top. """,
|
"""MobileBert Model with a `next sentence prediction (classification)` head on top. """,
|
||||||
MOBILEBERT_START_DOCSTRING,
|
MOBILEBERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
|
class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextSentencePredictionLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
|
||||||
@@ -1128,7 +1129,20 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
inputs=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
next_sentence_label=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
Return:
|
Return:
|
||||||
|
|
||||||
@@ -1146,18 +1160,44 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
|
|||||||
|
|
||||||
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
|
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
|
||||||
"""
|
"""
|
||||||
return_dict = kwargs.get("return_dict")
|
|
||||||
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
|
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
|
||||||
outputs = self.mobilebert(inputs, **kwargs)
|
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
|
||||||
|
if len(inputs) > 9:
|
||||||
|
inputs = inputs[:9]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
|
||||||
|
|
||||||
|
outputs = self.mobilebert(
|
||||||
|
inputs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
seq_relationship_score = self.cls(pooled_output)
|
seq_relationship_scores = self.cls(pooled_output)
|
||||||
|
|
||||||
|
next_sentence_loss = (
|
||||||
|
None
|
||||||
|
if next_sentence_label is None
|
||||||
|
else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
|
||||||
|
)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (seq_relationship_score,) + outputs[2:]
|
output = (seq_relationship_scores,) + outputs[2:]
|
||||||
|
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
|
||||||
|
|
||||||
return TFNextSentencePredictorOutput(
|
return TFNextSentencePredictorOutput(
|
||||||
logits=seq_relationship_score,
|
loss=next_sentence_loss,
|
||||||
|
logits=seq_relationship_scores,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -307,6 +307,8 @@ class TFNextSentencePredictorOutput(ModelOutput):
|
|||||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
|
||||||
|
Next sentence prediction loss.
|
||||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
|
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
|
||||||
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
||||||
before SoftMax).
|
before SoftMax).
|
||||||
@@ -323,6 +325,7 @@ class TFNextSentencePredictorOutput(ModelOutput):
|
|||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
loss: tf.Tensor = None
|
||||||
logits: tf.Tensor = None
|
logits: tf.Tensor = None
|
||||||
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|||||||
@@ -215,6 +215,27 @@ class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TFNextSentencePredictionLoss:
|
||||||
|
"""
|
||||||
|
Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compute_loss(self, labels, logits):
|
||||||
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
|
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||||
|
)
|
||||||
|
# make sure only labels that are not equal to -100
|
||||||
|
# are taken into account as loss
|
||||||
|
next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
|
||||||
|
next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
|
||||||
|
next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
|
||||||
|
|
||||||
|
return loss_fn(next_sentence_label, next_sentence_reduced_logits)
|
||||||
|
|
||||||
|
|
||||||
def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
|
def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
|
||||||
"""
|
"""
|
||||||
Detect missing and unexpected layers.
|
Detect missing and unexpected layers.
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ if is_tf_available():
|
|||||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||||
|
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
@@ -95,6 +96,8 @@ class TFModelTesterMixin:
|
|||||||
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||||
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
|
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
|
||||||
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||||
|
elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values():
|
||||||
|
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||||
elif model_class in [
|
elif model_class in [
|
||||||
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
||||||
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
||||||
|
|||||||
Reference in New Issue
Block a user