has_attentions - consistent test skipping logic and tf tests (#17495)
This commit is contained in:
@@ -158,6 +158,10 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def create_and_test_config_common_properties(self):
|
def create_and_test_config_common_properties(self):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@unittest.skip(reason="ConvNext does not output attentions")
|
||||||
|
def test_attention_outputs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="ConvNext does not use inputs_embeds")
|
@unittest.skip(reason="ConvNext does not use inputs_embeds")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -173,6 +173,10 @@ class CvtModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def create_and_test_config_common_properties(self):
|
def create_and_test_config_common_properties(self):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@unittest.skip(reason="Cvt does not output attentions")
|
||||||
|
def test_attention_outputs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="Cvt does not use inputs_embeds")
|
@unittest.skip(reason="Cvt does not use inputs_embeds")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -695,6 +695,10 @@ class FlavaImageCodebookTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
expected_arg_names = ["pixel_values"]
|
expected_arg_names = ["pixel_values"]
|
||||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Flava does not output attentions")
|
||||||
|
def test_attention_outputs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
# No embedding in multimodal model
|
# No embedding in multimodal model
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -142,6 +142,10 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
@unittest.skip(reason="PoolFormer does not output attentions")
|
||||||
|
def test_attention_outputs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@unittest.skip("PoolFormer does not use inputs_embeds")
|
@unittest.skip("PoolFormer does not use inputs_embeds")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -147,6 +147,10 @@ class RegNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def create_and_test_config_common_properties(self):
|
def create_and_test_config_common_properties(self):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@unittest.skip(reason="RegNet does not output attentions")
|
||||||
|
def test_attention_outputs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="RegNet does not use inputs_embeds")
|
@unittest.skip(reason="RegNet does not use inputs_embeds")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -147,6 +147,10 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def create_and_test_config_common_properties(self):
|
def create_and_test_config_common_properties(self):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@unittest.skip(reason="ResNet does not output attentions")
|
||||||
|
def test_attention_outputs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="ResNet does not use inputs_embeds")
|
@unittest.skip(reason="ResNet does not use inputs_embeds")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -144,6 +144,10 @@ class VanModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def create_and_test_config_common_properties(self):
|
def create_and_test_config_common_properties(self):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@unittest.skip(reason="Van does not output attentions")
|
||||||
|
def test_attention_outputs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="Van does not use inputs_embeds")
|
@unittest.skip(reason="Van does not use inputs_embeds")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -485,123 +485,119 @@ class ModelTesterMixin:
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
def test_attention_outputs(self):
|
def test_attention_outputs(self):
|
||||||
if not self.has_attentions:
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
pass
|
config.return_dict = True
|
||||||
|
|
||||||
else:
|
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
|
||||||
|
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||||
|
decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
|
||||||
|
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||||
|
chunk_length = getattr(self.model_tester, "chunk_length", None)
|
||||||
|
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
|
||||||
|
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
inputs_dict["output_attentions"] = True
|
||||||
|
inputs_dict["output_hidden_states"] = False
|
||||||
config.return_dict = True
|
config.return_dict = True
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||||
|
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||||
|
|
||||||
seq_len = getattr(self.model_tester, "seq_length", None)
|
# check that output_attentions also work using config
|
||||||
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
|
del inputs_dict["output_attentions"]
|
||||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
config.output_attentions = True
|
||||||
decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
|
model = model_class(config)
|
||||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
model.to(torch_device)
|
||||||
chunk_length = getattr(self.model_tester, "chunk_length", None)
|
model.eval()
|
||||||
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
|
with torch.no_grad():
|
||||||
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||||
|
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
if chunk_length is not None:
|
||||||
inputs_dict["output_attentions"] = True
|
self.assertListEqual(
|
||||||
inputs_dict["output_hidden_states"] = False
|
list(attentions[0].shape[-4:]),
|
||||||
config.return_dict = True
|
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
||||||
model = model_class(config)
|
)
|
||||||
model.to(torch_device)
|
else:
|
||||||
model.eval()
|
self.assertListEqual(
|
||||||
with torch.no_grad():
|
list(attentions[0].shape[-3:]),
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
)
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
out_len = len(outputs)
|
||||||
|
|
||||||
# check that output_attentions also work using config
|
if self.is_encoder_decoder:
|
||||||
del inputs_dict["output_attentions"]
|
correct_outlen = 5
|
||||||
config.output_attentions = True
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
|
||||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
|
||||||
|
|
||||||
if chunk_length is not None:
|
# loss is at first position
|
||||||
self.assertListEqual(
|
if "labels" in inputs_dict:
|
||||||
list(attentions[0].shape[-4:]),
|
correct_outlen += 1 # loss is added to beginning
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
# Question Answering model returns start_logits and end_logits
|
||||||
)
|
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||||
else:
|
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||||
self.assertListEqual(
|
if "past_key_values" in outputs:
|
||||||
list(attentions[0].shape[-3:]),
|
correct_outlen += 1 # past_key_values have been returned
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
|
||||||
)
|
|
||||||
out_len = len(outputs)
|
|
||||||
|
|
||||||
if self.is_encoder_decoder:
|
self.assertEqual(out_len, correct_outlen)
|
||||||
correct_outlen = 5
|
|
||||||
|
|
||||||
# loss is at first position
|
# decoder attentions
|
||||||
if "labels" in inputs_dict:
|
decoder_attentions = outputs.decoder_attentions
|
||||||
correct_outlen += 1 # loss is added to beginning
|
self.assertIsInstance(decoder_attentions, (list, tuple))
|
||||||
# Question Answering model returns start_logits and end_logits
|
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||||
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
self.assertListEqual(
|
||||||
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
list(decoder_attentions[0].shape[-3:]),
|
||||||
if "past_key_values" in outputs:
|
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
|
||||||
correct_outlen += 1 # past_key_values have been returned
|
)
|
||||||
|
|
||||||
self.assertEqual(out_len, correct_outlen)
|
# cross attentions
|
||||||
|
cross_attentions = outputs.cross_attentions
|
||||||
|
self.assertIsInstance(cross_attentions, (list, tuple))
|
||||||
|
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
|
||||||
|
self.assertListEqual(
|
||||||
|
list(cross_attentions[0].shape[-3:]),
|
||||||
|
[
|
||||||
|
self.model_tester.num_attention_heads,
|
||||||
|
decoder_seq_length,
|
||||||
|
encoder_key_length,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# decoder attentions
|
# Check attention is always last and order is fine
|
||||||
decoder_attentions = outputs.decoder_attentions
|
inputs_dict["output_attentions"] = True
|
||||||
self.assertIsInstance(decoder_attentions, (list, tuple))
|
inputs_dict["output_hidden_states"] = True
|
||||||
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
model = model_class(config)
|
||||||
self.assertListEqual(
|
model.to(torch_device)
|
||||||
list(decoder_attentions[0].shape[-3:]),
|
model.eval()
|
||||||
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
|
with torch.no_grad():
|
||||||
)
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
# cross attentions
|
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||||
cross_attentions = outputs.cross_attentions
|
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||||
self.assertIsInstance(cross_attentions, (list, tuple))
|
elif self.is_encoder_decoder:
|
||||||
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
|
added_hidden_states = 2
|
||||||
self.assertListEqual(
|
else:
|
||||||
list(cross_attentions[0].shape[-3:]),
|
added_hidden_states = 1
|
||||||
[
|
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||||
self.model_tester.num_attention_heads,
|
|
||||||
decoder_seq_length,
|
|
||||||
encoder_key_length,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check attention is always last and order is fine
|
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||||
inputs_dict["output_attentions"] = True
|
|
||||||
inputs_dict["output_hidden_states"] = True
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
|
||||||
|
|
||||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
if chunk_length is not None:
|
||||||
elif self.is_encoder_decoder:
|
self.assertListEqual(
|
||||||
added_hidden_states = 2
|
list(self_attentions[0].shape[-4:]),
|
||||||
else:
|
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
||||||
added_hidden_states = 1
|
)
|
||||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
else:
|
||||||
|
self.assertListEqual(
|
||||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
list(self_attentions[0].shape[-3:]),
|
||||||
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
)
|
||||||
if chunk_length is not None:
|
|
||||||
self.assertListEqual(
|
|
||||||
list(self_attentions[0].shape[-4:]),
|
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.assertListEqual(
|
|
||||||
list(self_attentions[0].shape[-3:]),
|
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
|
||||||
)
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_torchscript_simple(self):
|
def test_torchscript_simple(self):
|
||||||
|
|||||||
@@ -978,9 +978,10 @@ class TFModelTesterMixin:
|
|||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
if self.has_attentions:
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
||||||
|
|
||||||
# Not all models accept "labels" in the forward pass (yet :) )
|
# Not all models accept "labels" in the forward pass (yet :) )
|
||||||
if "labels" in inspect.signature(model.call).parameters.keys():
|
if "labels" in inspect.signature(model.call).parameters.keys():
|
||||||
@@ -992,15 +993,16 @@ class TFModelTesterMixin:
|
|||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
if self.has_attentions:
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
check_equivalence(
|
check_equivalence(
|
||||||
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
|
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user