TF: Add missing cast to GPT-J (#18201)

* Fix TF GPT-J tests

* add try/finally block
This commit is contained in:
Joao Gante
2022-07-19 15:58:42 +01:00
committed by GitHub
parent 05ed569c79
commit ec6cd7633f
2 changed files with 11 additions and 10 deletions

View File

@@ -274,16 +274,17 @@ class TFCoreModelTesterMixin:
def test_mixed_precision(self):
tf.keras.mixed_precision.set_global_policy("mixed_float16")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# try/finally block to ensure subsequent tests run in float32
try:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
outputs = model(class_inputs_dict)
for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
outputs = model(class_inputs_dict)
self.assertIsNotNone(outputs)
tf.keras.mixed_precision.set_global_policy("float32")
self.assertIsNotNone(outputs)
finally:
tf.keras.mixed_precision.set_global_policy("float32")
@slow
def test_train_pipeline_custom_model(self):