[generation] bring back tests on vision models (#38603)
* bring back generation tests on VLMs * remove head mask tests overwritten
This commit is contained in:
committed by
GitHub
parent
90c4b90a10
commit
dbfc79c17c
@@ -499,7 +499,7 @@ class GenerationTesterMixin:
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(model=model, inputs_dict=inputs_dict)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
@@ -523,7 +523,7 @@ class GenerationTesterMixin:
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
@@ -563,7 +563,7 @@ class GenerationTesterMixin:
|
||||
use_cache=True, # Enable cache
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(
|
||||
@@ -580,7 +580,7 @@ class GenerationTesterMixin:
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._sample_generate(model=model, inputs_dict=inputs_dict, num_return_sequences=1)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
@@ -605,7 +605,7 @@ class GenerationTesterMixin:
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
@@ -630,7 +630,7 @@ class GenerationTesterMixin:
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
output_generate = self._beam_search_generate(model=model, inputs_dict=inputs_dict, beam_kwargs=beam_kwargs)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
@@ -655,7 +655,7 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
@@ -704,7 +704,7 @@ class GenerationTesterMixin:
|
||||
use_cache=True, # Enable cache
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(
|
||||
@@ -757,7 +757,7 @@ class GenerationTesterMixin:
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
@@ -784,7 +784,7 @@ class GenerationTesterMixin:
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
@@ -838,7 +838,7 @@ class GenerationTesterMixin:
|
||||
inputs_dict=inputs_dict,
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
@@ -851,7 +851,7 @@ class GenerationTesterMixin:
|
||||
inputs_dict=inputs_dict,
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
@@ -876,7 +876,7 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
@@ -921,7 +921,7 @@ class GenerationTesterMixin:
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
@@ -945,7 +945,7 @@ class GenerationTesterMixin:
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
@@ -985,7 +985,7 @@ class GenerationTesterMixin:
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
@@ -1029,7 +1029,7 @@ class GenerationTesterMixin:
|
||||
inputs_dict=inputs_dict,
|
||||
use_cache=True, # Enable cache
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
@@ -1065,7 +1065,7 @@ class GenerationTesterMixin:
|
||||
use_cache=True, # Enable cache
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(
|
||||
@@ -1297,7 +1297,7 @@ class GenerationTesterMixin:
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# Encoder-decoder models are not supported
|
||||
if config.is_encoder_decoder:
|
||||
if config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.skipTest("DoLa is not supported for encoder-decoder models")
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
@@ -1427,52 +1427,6 @@ class GenerationTesterMixin:
|
||||
# PLD shouldn't propose any new tokens based on eos-match
|
||||
self.assertTrue(output_prompt_lookup.shape[-1] == 10)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_with_head_masking(self):
|
||||
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
config._attn_implementation = "eager" # head mask works only in eager mode and will be removed soon
|
||||
text_config = config.get_text_config()
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
|
||||
# We want to test only encoder-decoder models
|
||||
if not text_config.is_encoder_decoder:
|
||||
continue
|
||||
model = model_class(config).to(torch_device)
|
||||
|
||||
head_masking = {
|
||||
"head_mask": torch.zeros(
|
||||
text_config.encoder_layers, text_config.encoder_attention_heads, device=torch_device
|
||||
),
|
||||
"decoder_head_mask": torch.zeros(
|
||||
text_config.decoder_layers, text_config.decoder_attention_heads, device=torch_device
|
||||
),
|
||||
"cross_attn_head_mask": torch.zeros(
|
||||
text_config.decoder_layers, text_config.decoder_attention_heads, device=torch_device
|
||||
),
|
||||
}
|
||||
|
||||
signature = inspect.signature(model.forward)
|
||||
# We want to test only models where encoder/decoder head masking is implemented
|
||||
if not set(head_masking.keys()) < {*signature.parameters.keys()}:
|
||||
continue
|
||||
|
||||
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
|
||||
out = model.generate(
|
||||
num_beams=1,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
remove_invalid_values=True,
|
||||
**{name: mask},
|
||||
**inputs_dict,
|
||||
)
|
||||
# We check the state of decoder_attentions and cross_attentions just from the last step
|
||||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_left_padding_compatibility(self):
|
||||
# NOTE: left-padding results in small numerical differences. This is expected.
|
||||
@@ -1491,7 +1445,7 @@ class GenerationTesterMixin:
|
||||
decoder_only_classes = []
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, _ = self.prepare_config_and_inputs_for_generate()
|
||||
if config.is_encoder_decoder:
|
||||
if config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
continue
|
||||
else:
|
||||
decoder_only_classes.append(model_class)
|
||||
@@ -1696,7 +1650,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# This test is for decoder-only models (encoder-decoder models have native input embeddings support in the
|
||||
# decoder)
|
||||
if config.is_encoder_decoder:
|
||||
if config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
continue
|
||||
config.is_decoder = True
|
||||
|
||||
@@ -1790,7 +1744,7 @@ class GenerationTesterMixin:
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
if config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
@@ -1952,7 +1906,7 @@ class GenerationTesterMixin:
|
||||
if "token_type_ids" in inputs_dict:
|
||||
del inputs_dict["token_type_ids"]
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
if config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.skipTest(reason="This model is encoder-decoder")
|
||||
# TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
|
||||
# but it breaks a few models. Fix and then apply `_check_similar_generate_outputs` pattern
|
||||
@@ -2031,7 +1985,7 @@ class GenerationTesterMixin:
|
||||
set_config_for_less_flaky_test(config)
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
if config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
|
||||
|
||||
config.is_decoder = True
|
||||
@@ -2183,7 +2137,7 @@ class GenerationTesterMixin:
|
||||
if not has_defined_cache_implementation:
|
||||
decoder_cache = (
|
||||
gen_out.past_key_values.self_attention_cache
|
||||
if config.is_encoder_decoder
|
||||
if config.get_text_config(decoder=True).is_encoder_decoder
|
||||
else gen_out.past_key_values
|
||||
)
|
||||
self.assertTrue(isinstance(decoder_cache, DynamicCache))
|
||||
@@ -2209,7 +2163,7 @@ class GenerationTesterMixin:
|
||||
# sanity checks
|
||||
decoder_cache = (
|
||||
gen_out.past_key_values.self_attention_cache
|
||||
if config.is_encoder_decoder
|
||||
if config.get_text_config(decoder=True).is_encoder_decoder
|
||||
else gen_out.past_key_values
|
||||
)
|
||||
self.assertFalse(isinstance(decoder_cache, DynamicCache))
|
||||
@@ -2283,7 +2237,7 @@ class GenerationTesterMixin:
|
||||
else:
|
||||
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
else:
|
||||
@@ -5154,7 +5108,6 @@ class TestAssistedCandidateGeneratorUpdateStrategy(unittest.TestCase):
|
||||
|
||||
@parameterized.expand([(is_sklearn_available(),), (False,)])
|
||||
def test_update_candidate_strategy_no_matches_short(self, sklearn_available):
|
||||
print("test_update_candidate_strategy_no_matches_short")
|
||||
self.original_matches = []
|
||||
self.candidate_generator.matches = self.original_matches
|
||||
self.num_matches = 0
|
||||
|
||||
Reference in New Issue
Block a user