[Tests] Add attentions_option to ModelTesterMixin (#15909)
* Add attentions_option to common tester * Fix tests, apply suggestion * Apply suggestion from code review Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -17,7 +17,6 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Dict, List, Tuple
|
|
||||||
|
|
||||||
from transformers import ConvNextConfig
|
from transformers import ConvNextConfig
|
||||||
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
|
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
|
||||||
@@ -142,6 +141,7 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
has_attentions = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = ConvNextModelTester(self)
|
self.model_tester = ConvNextModelTester(self)
|
||||||
@@ -183,10 +183,6 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
@unittest.skip(reason="Model doesn't have attention layers")
|
|
||||||
def test_attention_outputs(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
@@ -219,81 +215,6 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
check_hidden_states_output(inputs_dict, config, model_class)
|
check_hidden_states_output(inputs_dict, config, model_class)
|
||||||
|
|
||||||
def test_retain_grad_hidden_states_attentions(self):
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.output_hidden_states = True
|
|
||||||
config.output_attentions = True
|
|
||||||
|
|
||||||
# no need to test all models as different heads yield the same functionality
|
|
||||||
model_class = self.all_model_classes[0]
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
|
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
|
||||||
outputs = model(**inputs)
|
|
||||||
output = outputs[0]
|
|
||||||
|
|
||||||
hidden_states = outputs.hidden_states[0]
|
|
||||||
hidden_states.retain_grad()
|
|
||||||
|
|
||||||
output.flatten()[0].backward(retain_graph=True)
|
|
||||||
|
|
||||||
self.assertIsNotNone(hidden_states.grad)
|
|
||||||
|
|
||||||
def test_model_outputs_equivalence(self):
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
|
|
||||||
def set_nan_tensor_to_zero(t):
|
|
||||||
t[t != t] = 0
|
|
||||||
return t
|
|
||||||
|
|
||||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
|
||||||
with torch.no_grad():
|
|
||||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
|
||||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
|
||||||
|
|
||||||
def recursive_check(tuple_object, dict_object):
|
|
||||||
if isinstance(tuple_object, (List, Tuple)):
|
|
||||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
|
||||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
|
||||||
elif isinstance(tuple_object, Dict):
|
|
||||||
for tuple_iterable_value, dict_iterable_value in zip(
|
|
||||||
tuple_object.values(), dict_object.values()
|
|
||||||
):
|
|
||||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
|
||||||
elif tuple_object is None:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
self.assertTrue(
|
|
||||||
torch.allclose(
|
|
||||||
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
|
|
||||||
),
|
|
||||||
msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
|
|
||||||
)
|
|
||||||
|
|
||||||
recursive_check(tuple_output, dict_output)
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
|
||||||
|
|
||||||
def test_for_image_classification(self):
|
def test_for_image_classification(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Dict, List, Tuple
|
|
||||||
|
|
||||||
from transformers import is_torch_available, is_vision_available
|
from transformers import is_torch_available, is_vision_available
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
@@ -130,6 +129,7 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
has_attentions = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = PoolFormerModelTester(self)
|
self.model_tester = PoolFormerModelTester(self)
|
||||||
@@ -150,100 +150,6 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_retain_grad_hidden_states_attentions(self):
|
|
||||||
# Since poolformer doesn't use Attention
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.output_hidden_states = True
|
|
||||||
|
|
||||||
# no need to test all models as different heads yield the same functionality
|
|
||||||
model_class = self.all_model_classes[0]
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
|
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
|
||||||
|
|
||||||
outputs = model(**inputs)
|
|
||||||
|
|
||||||
output = outputs[0]
|
|
||||||
|
|
||||||
hidden_states = outputs.hidden_states[0]
|
|
||||||
|
|
||||||
hidden_states.retain_grad()
|
|
||||||
|
|
||||||
output.flatten()[0].backward(retain_graph=True)
|
|
||||||
|
|
||||||
self.assertIsNotNone(hidden_states.grad)
|
|
||||||
|
|
||||||
def test_model_outputs_equivalence(self):
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
|
|
||||||
def set_nan_tensor_to_zero(t):
|
|
||||||
t[t != t] = 0
|
|
||||||
return t
|
|
||||||
|
|
||||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
|
||||||
with torch.no_grad():
|
|
||||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
|
||||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
|
||||||
|
|
||||||
def recursive_check(tuple_object, dict_object):
|
|
||||||
if isinstance(tuple_object, (List, Tuple)):
|
|
||||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
|
||||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
|
||||||
elif isinstance(tuple_object, Dict):
|
|
||||||
for tuple_iterable_value, dict_iterable_value in zip(
|
|
||||||
tuple_object.values(), dict_object.values()
|
|
||||||
):
|
|
||||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
|
||||||
elif tuple_object is None:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
self.assertTrue(
|
|
||||||
torch.allclose(
|
|
||||||
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
|
|
||||||
),
|
|
||||||
msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
|
|
||||||
)
|
|
||||||
|
|
||||||
recursive_check(tuple_output, dict_output)
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
|
||||||
|
|
||||||
def test_forward_signature(self):
|
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
model = model_class(config)
|
|
||||||
signature = inspect.signature(model.forward)
|
|
||||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
|
||||||
arg_names = [*signature.parameters.keys()]
|
|
||||||
|
|
||||||
expected_arg_names = ["pixel_values"]
|
|
||||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
|
||||||
|
|
||||||
@unittest.skip("PoolFormer does not have attention")
|
|
||||||
def test_attention_outputs(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
@@ -297,6 +203,18 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
loss = model(**inputs).loss
|
loss = model(**inputs).loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
def test_forward_signature(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
signature = inspect.signature(model.forward)
|
||||||
|
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||||
|
arg_names = [*signature.parameters.keys()]
|
||||||
|
|
||||||
|
expected_arg_names = ["pixel_values"]
|
||||||
|
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
@@ -128,6 +128,7 @@ class ModelTesterMixin:
|
|||||||
test_missing_keys = True
|
test_missing_keys = True
|
||||||
test_model_parallel = False
|
test_model_parallel = False
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
has_attentions = True
|
||||||
|
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
inputs_dict = copy.deepcopy(inputs_dict)
|
inputs_dict = copy.deepcopy(inputs_dict)
|
||||||
@@ -454,119 +455,123 @@ class ModelTesterMixin:
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
def test_attention_outputs(self):
|
def test_attention_outputs(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
if not self.has_attentions:
|
||||||
config.return_dict = True
|
pass
|
||||||
|
|
||||||
seq_len = getattr(self.model_tester, "seq_length", None)
|
else:
|
||||||
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
|
||||||
decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
|
|
||||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
|
||||||
chunk_length = getattr(self.model_tester, "chunk_length", None)
|
|
||||||
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
|
|
||||||
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
inputs_dict["output_attentions"] = True
|
|
||||||
inputs_dict["output_hidden_states"] = False
|
|
||||||
config.return_dict = True
|
config.return_dict = True
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
|
||||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
|
||||||
|
|
||||||
# check that output_attentions also work using config
|
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||||
del inputs_dict["output_attentions"]
|
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
|
||||||
config.output_attentions = True
|
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||||
model = model_class(config)
|
decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
|
||||||
model.to(torch_device)
|
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||||
model.eval()
|
chunk_length = getattr(self.model_tester, "chunk_length", None)
|
||||||
with torch.no_grad():
|
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
|
||||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
|
||||||
|
|
||||||
if chunk_length is not None:
|
for model_class in self.all_model_classes:
|
||||||
self.assertListEqual(
|
inputs_dict["output_attentions"] = True
|
||||||
list(attentions[0].shape[-4:]),
|
inputs_dict["output_hidden_states"] = False
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
config.return_dict = True
|
||||||
)
|
model = model_class(config)
|
||||||
else:
|
model.to(torch_device)
|
||||||
self.assertListEqual(
|
model.eval()
|
||||||
list(attentions[0].shape[-3:]),
|
with torch.no_grad():
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
)
|
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||||
out_len = len(outputs)
|
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||||
|
|
||||||
if self.is_encoder_decoder:
|
# check that output_attentions also work using config
|
||||||
correct_outlen = 5
|
del inputs_dict["output_attentions"]
|
||||||
|
config.output_attentions = True
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||||
|
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||||
|
|
||||||
# loss is at first position
|
if chunk_length is not None:
|
||||||
if "labels" in inputs_dict:
|
self.assertListEqual(
|
||||||
correct_outlen += 1 # loss is added to beginning
|
list(attentions[0].shape[-4:]),
|
||||||
# Question Answering model returns start_logits and end_logits
|
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
||||||
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
)
|
||||||
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
else:
|
||||||
if "past_key_values" in outputs:
|
self.assertListEqual(
|
||||||
correct_outlen += 1 # past_key_values have been returned
|
list(attentions[0].shape[-3:]),
|
||||||
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
|
)
|
||||||
|
out_len = len(outputs)
|
||||||
|
|
||||||
self.assertEqual(out_len, correct_outlen)
|
if self.is_encoder_decoder:
|
||||||
|
correct_outlen = 5
|
||||||
|
|
||||||
# decoder attentions
|
# loss is at first position
|
||||||
decoder_attentions = outputs.decoder_attentions
|
if "labels" in inputs_dict:
|
||||||
self.assertIsInstance(decoder_attentions, (list, tuple))
|
correct_outlen += 1 # loss is added to beginning
|
||||||
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
# Question Answering model returns start_logits and end_logits
|
||||||
self.assertListEqual(
|
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||||
list(decoder_attentions[0].shape[-3:]),
|
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||||
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
|
if "past_key_values" in outputs:
|
||||||
)
|
correct_outlen += 1 # past_key_values have been returned
|
||||||
|
|
||||||
# cross attentions
|
self.assertEqual(out_len, correct_outlen)
|
||||||
cross_attentions = outputs.cross_attentions
|
|
||||||
self.assertIsInstance(cross_attentions, (list, tuple))
|
|
||||||
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
|
|
||||||
self.assertListEqual(
|
|
||||||
list(cross_attentions[0].shape[-3:]),
|
|
||||||
[
|
|
||||||
self.model_tester.num_attention_heads,
|
|
||||||
decoder_seq_length,
|
|
||||||
encoder_key_length,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check attention is always last and order is fine
|
# decoder attentions
|
||||||
inputs_dict["output_attentions"] = True
|
decoder_attentions = outputs.decoder_attentions
|
||||||
inputs_dict["output_hidden_states"] = True
|
self.assertIsInstance(decoder_attentions, (list, tuple))
|
||||||
model = model_class(config)
|
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||||
model.to(torch_device)
|
self.assertListEqual(
|
||||||
model.eval()
|
list(decoder_attentions[0].shape[-3:]),
|
||||||
with torch.no_grad():
|
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
)
|
||||||
|
|
||||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
# cross attentions
|
||||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
cross_attentions = outputs.cross_attentions
|
||||||
elif self.is_encoder_decoder:
|
self.assertIsInstance(cross_attentions, (list, tuple))
|
||||||
added_hidden_states = 2
|
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
|
||||||
else:
|
self.assertListEqual(
|
||||||
added_hidden_states = 1
|
list(cross_attentions[0].shape[-3:]),
|
||||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
[
|
||||||
|
self.model_tester.num_attention_heads,
|
||||||
|
decoder_seq_length,
|
||||||
|
encoder_key_length,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
# Check attention is always last and order is fine
|
||||||
|
inputs_dict["output_attentions"] = True
|
||||||
|
inputs_dict["output_hidden_states"] = True
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||||
if chunk_length is not None:
|
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||||
self.assertListEqual(
|
elif self.is_encoder_decoder:
|
||||||
list(self_attentions[0].shape[-4:]),
|
added_hidden_states = 2
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
else:
|
||||||
)
|
added_hidden_states = 1
|
||||||
else:
|
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||||
self.assertListEqual(
|
|
||||||
list(self_attentions[0].shape[-3:]),
|
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
|
||||||
)
|
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||||
|
if chunk_length is not None:
|
||||||
|
self.assertListEqual(
|
||||||
|
list(self_attentions[0].shape[-4:]),
|
||||||
|
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.assertListEqual(
|
||||||
|
list(self_attentions[0].shape[-3:]),
|
||||||
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_torchscript(self):
|
def test_torchscript(self):
|
||||||
@@ -1040,7 +1045,7 @@ class ModelTesterMixin:
|
|||||||
def test_retain_grad_hidden_states_attentions(self):
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
config.output_hidden_states = True
|
config.output_hidden_states = True
|
||||||
config.output_attentions = True
|
config.output_attentions = self.has_attentions
|
||||||
|
|
||||||
# no need to test all models as different heads yield the same functionality
|
# no need to test all models as different heads yield the same functionality
|
||||||
model_class = self.all_model_classes[0]
|
model_class = self.all_model_classes[0]
|
||||||
@@ -1056,37 +1061,45 @@ class ModelTesterMixin:
|
|||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
# Seq2Seq models
|
# Seq2Seq models
|
||||||
encoder_hidden_states = outputs.encoder_hidden_states[0]
|
encoder_hidden_states = outputs.encoder_hidden_states[0]
|
||||||
encoder_attentions = outputs.encoder_attentions[0]
|
|
||||||
encoder_hidden_states.retain_grad()
|
encoder_hidden_states.retain_grad()
|
||||||
encoder_attentions.retain_grad()
|
|
||||||
|
|
||||||
decoder_hidden_states = outputs.decoder_hidden_states[0]
|
decoder_hidden_states = outputs.decoder_hidden_states[0]
|
||||||
decoder_attentions = outputs.decoder_attentions[0]
|
|
||||||
decoder_hidden_states.retain_grad()
|
decoder_hidden_states.retain_grad()
|
||||||
decoder_attentions.retain_grad()
|
|
||||||
|
|
||||||
cross_attentions = outputs.cross_attentions[0]
|
if self.has_attentions:
|
||||||
cross_attentions.retain_grad()
|
encoder_attentions = outputs.encoder_attentions[0]
|
||||||
|
encoder_attentions.retain_grad()
|
||||||
|
|
||||||
|
decoder_attentions = outputs.decoder_attentions[0]
|
||||||
|
decoder_attentions.retain_grad()
|
||||||
|
|
||||||
|
cross_attentions = outputs.cross_attentions[0]
|
||||||
|
cross_attentions.retain_grad()
|
||||||
|
|
||||||
output.flatten()[0].backward(retain_graph=True)
|
output.flatten()[0].backward(retain_graph=True)
|
||||||
|
|
||||||
self.assertIsNotNone(encoder_hidden_states.grad)
|
self.assertIsNotNone(encoder_hidden_states.grad)
|
||||||
self.assertIsNotNone(encoder_attentions.grad)
|
|
||||||
self.assertIsNotNone(decoder_hidden_states.grad)
|
self.assertIsNotNone(decoder_hidden_states.grad)
|
||||||
self.assertIsNotNone(decoder_attentions.grad)
|
|
||||||
self.assertIsNotNone(cross_attentions.grad)
|
if self.has_attentions:
|
||||||
|
self.assertIsNotNone(encoder_attentions.grad)
|
||||||
|
self.assertIsNotNone(decoder_attentions.grad)
|
||||||
|
self.assertIsNotNone(cross_attentions.grad)
|
||||||
else:
|
else:
|
||||||
# Encoder-/Decoder-only models
|
# Encoder-/Decoder-only models
|
||||||
hidden_states = outputs.hidden_states[0]
|
hidden_states = outputs.hidden_states[0]
|
||||||
attentions = outputs.attentions[0]
|
|
||||||
|
|
||||||
hidden_states.retain_grad()
|
hidden_states.retain_grad()
|
||||||
attentions.retain_grad()
|
|
||||||
|
if self.has_attentions:
|
||||||
|
attentions = outputs.attentions[0]
|
||||||
|
attentions.retain_grad()
|
||||||
|
|
||||||
output.flatten()[0].backward(retain_graph=True)
|
output.flatten()[0].backward(retain_graph=True)
|
||||||
|
|
||||||
self.assertIsNotNone(hidden_states.grad)
|
self.assertIsNotNone(hidden_states.grad)
|
||||||
self.assertIsNotNone(attentions.grad)
|
|
||||||
|
if self.has_attentions:
|
||||||
|
self.assertIsNotNone(attentions.grad)
|
||||||
|
|
||||||
def test_feed_forward_chunking(self):
|
def test_feed_forward_chunking(self):
|
||||||
(
|
(
|
||||||
@@ -1424,23 +1437,24 @@ class ModelTesterMixin:
|
|||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
if self.has_attentions:
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
check_equivalence(
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
||||||
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
|
|
||||||
)
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(
|
||||||
|
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
|
||||||
|
)
|
||||||
|
|
||||||
@is_pt_tf_cross_test
|
@is_pt_tf_cross_test
|
||||||
def test_pt_tf_model_equivalence(self):
|
def test_pt_tf_model_equivalence(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user