Reduce the time spent for the TF slow tests (#10152)

* rework savedmodel slow test

* Improve savedmodel tests

* Remove useless content
This commit is contained in:
Julien Plu
2021-02-18 15:52:57 +01:00
committed by GitHub
parent 14ed3b978e
commit 2acae50a0c
7 changed files with 91 additions and 166 deletions

View File

@@ -273,13 +273,13 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
@slow
def test_saved_model_with_attentions_output(self):
def test_saved_model_creation_extended(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True
config.output_hidden_states = False
if hasattr(config, "use_cache"):
config.use_cache = False
config.use_cache = True
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
@@ -291,14 +291,32 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True)
model = tf.keras.models.load_model(os.path.join(tmpdirname, "saved_model", "1"))
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
model = tf.keras.models.load_model(saved_model_dir)
outputs = model(class_inputs_dict)
output = outputs["attentions"]
if self.is_encoder_decoder:
output_hidden_states = outputs["encoder_hidden_states"]
output_attentions = outputs["encoder_attentions"]
else:
output_hidden_states = outputs["hidden_states"]
output_attentions = outputs["attentions"]
self.assertEqual(len(outputs), num_out)
self.assertEqual(len(output), self.model_tester.num_hidden_layers)
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(output_hidden_states), expected_num_layers)
self.assertListEqual(
list(output[0].shape[-3:]),
list(output_hidden_states[0].shape[-2:]),
[self.model_tester.seq_length, self.model_tester.hidden_size],
)
self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(output_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads / 2, encoder_seq_length, encoder_key_length],
)