Reduce the time spent for the TF slow tests (#10152)
* rework savedmodel slow test * Improve savedmodel tests * Remove useless content
This commit is contained in:
@@ -273,13 +273,13 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_attentions_output(self):
|
||||
def test_saved_model_creation_extended(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
config.output_hidden_states = False
|
||||
|
||||
if hasattr(config, "use_cache"):
|
||||
config.use_cache = False
|
||||
config.use_cache = True
|
||||
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
@@ -291,14 +291,32 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, saved_model=True)
|
||||
model = tf.keras.models.load_model(os.path.join(tmpdirname, "saved_model", "1"))
|
||||
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
|
||||
model = tf.keras.models.load_model(saved_model_dir)
|
||||
outputs = model(class_inputs_dict)
|
||||
output = outputs["attentions"]
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
output_hidden_states = outputs["encoder_hidden_states"]
|
||||
output_attentions = outputs["encoder_attentions"]
|
||||
else:
|
||||
output_hidden_states = outputs["hidden_states"]
|
||||
output_attentions = outputs["attentions"]
|
||||
|
||||
self.assertEqual(len(outputs), num_out)
|
||||
self.assertEqual(len(output), self.model_tester.num_hidden_layers)
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
|
||||
self.assertEqual(len(output_hidden_states), expected_num_layers)
|
||||
self.assertListEqual(
|
||||
list(output[0].shape[-3:]),
|
||||
list(output_hidden_states[0].shape[-2:]),
|
||||
[self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(output_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads / 2, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user