Add next sentence prediction loss computation (#8462)
* Add next sentence prediction loss computation * Apply style * Fix tests * Add forgotten import * Add forgotten import * Use a new parameter * Remove kwargs and use positional arguments
This commit is contained in:
@@ -35,6 +35,7 @@ if is_tf_available():
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
@@ -95,6 +96,8 @@ class TFModelTesterMixin:
|
||||
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
|
||||
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values():
|
||||
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in [
|
||||
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
||||
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
||||
|
||||
Reference in New Issue
Block a user