[generation] bring back tests on vision models (#38603)
* bring back geenration tests on VLMs * remove head mask tests overwritten
This commit is contained in:
committed by
GitHub
parent
90c4b90a10
commit
dbfc79c17c
@@ -489,39 +489,6 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
config = config_and_inputs[0]
|
||||
model = UMT5ForConditionalGeneration(config).eval()
|
||||
model.to(torch_device)
|
||||
|
||||
head_masking = {
|
||||
"head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device),
|
||||
"decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
|
||||
"cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
|
||||
}
|
||||
|
||||
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
|
||||
head_masks = {name: mask}
|
||||
# Explicitly pass decoder_head_mask as it is required from T5 model when head_mask specified
|
||||
if name == "head_mask":
|
||||
head_masks["decoder_head_mask"] = torch.ones(
|
||||
config.num_decoder_layers, config.num_heads, device=torch_device
|
||||
)
|
||||
|
||||
out = model.generate(
|
||||
config_and_inputs[1]["input_ids"],
|
||||
num_beams=1,
|
||||
max_length=3,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
**head_masks,
|
||||
)
|
||||
# We check the state of decoder_attentions and cross_attentions just from the last step
|
||||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user