Add pretraining loss computation for TF Bert pretraining (#8470)
* Add pretraining loss computation for TF Bert pretraining * Fix labels creation * Fix T5 model * restore T5 kwargs * try a generic fix for pretraining models * Apply style * Overide the prepare method for the BERT tests
This commit is contained in:
@@ -26,6 +26,7 @@ from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TF_MODEL_FOR_PRETRAINING_MAPPING
|
||||
from transformers.modeling_tf_bert import (
|
||||
TFBertForMaskedLM,
|
||||
TFBertForMultipleChoice,
|
||||
@@ -274,6 +275,16 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
|
||||
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBertModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
|
||||
|
||||
Reference in New Issue
Block a user