fixing CTRL tests and OpenAI GPT tests

This commit is contained in:
thomwolf
2019-10-09 13:51:05 +02:00
parent 6dce6dda1b
commit c19b8e4ae0
4 changed files with 31 additions and 25 deletions

View File

@@ -81,8 +81,9 @@ class TFCommonTestCases:
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beggining
pt_model_class = getattr(transformers, pt_model_class_name)
tf_model = model_class(config, output_hidden_states=True)
pt_model = pt_model_class(config, output_hidden_states=True)
config.output_hidden_states = True
tf_model = model_class(config)
pt_model = pt_model_class(config)
# Check we can load pt model in tf and vice-versa (architecture similar)
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
@@ -96,7 +97,7 @@ class TFCommonTestCases:
pto = pt_model(**pt_inputs_dict)
tfo = tf_model(inputs_dict)
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
self.assertLessEqual(max_diff, 2e-2)
self.assertLessEqual(max_diff, 2e-5)
def test_keyword_and_dict_args(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()