Remove head mask in generative models (#35786)

* just squash into one commit

* delete print
This commit is contained in:
Raushan Turganbay
2025-05-15 10:44:19 +02:00
committed by GitHub
parent 0173a99e73
commit 955e61b0da
47 changed files with 103 additions and 294 deletions

View File

@@ -1414,6 +1414,7 @@ class GenerationTesterMixin:
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
config._attn_implementation = "eager" # head mask works only in eager mode and will be removed soon
text_config = config.get_text_config()
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise

View File

@@ -58,28 +58,16 @@ def prepare_bart_inputs_dict(
decoder_input_ids=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
@@ -167,10 +155,9 @@ class BartModelTester:
model = BartModel(config=config).get_decoder().to(torch_device).eval()
input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict["attention_mask"]
head_mask = inputs_dict["head_mask"]
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()

View File

@@ -119,11 +119,10 @@ class TFBartModelTester:
input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
@@ -158,9 +157,6 @@ def prepare_bart_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -172,20 +168,11 @@ def prepare_bart_inputs_dict(
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}

View File

@@ -135,9 +135,7 @@ class BioGptModelTester:
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_biogpt_model_attention_mask_past(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
):
def create_and_check_biogpt_model_attention_mask_past(self, config, input_ids, input_mask, token_type_ids, *args):
model = BioGptModel(config=config)
model.to(torch_device)
model.eval()
@@ -177,9 +175,7 @@ class BioGptModelTester:
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_biogpt_model_past_large_inputs(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
):
def create_and_check_biogpt_model_past_large_inputs(self, config, input_ids, input_mask, token_type_ids, *args):
model = BioGptModel(config=config).to(torch_device).eval()
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
@@ -213,7 +209,7 @@ class BioGptModelTester:
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
self, config, input_ids, input_mask, token_type_ids, *args, gradient_checkpointing=False
):
model = BioGptForCausalLM(config)
model.to(torch_device)
@@ -233,9 +229,7 @@ class BioGptModelTester:
self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)
def create_and_check_biogpt_for_token_classification(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
):
def create_and_check_biogpt_for_token_classification(self, config, input_ids, input_mask, token_type_ids, *args):
config.num_labels = self.num_labels
model = BioGptForTokenClassification(config)
model.to(torch_device)

View File

@@ -128,13 +128,10 @@ class GPTBigCodeModelTester:
reorder_and_upcast_attn=reorder_and_upcast_attn,
)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
return (
config,
input_ids,
input_mask,
head_mask,
token_type_ids,
mc_token_ids,
sequence_labels,
@@ -174,19 +171,19 @@ class GPTBigCodeModelTester:
config.vocab_size = 300
return config
def create_and_check_gpt_bigcode_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
def create_and_check_gpt_bigcode_model(self, config, input_ids, input_mask, token_type_ids, *args):
model = GPTBigCodeModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(len(result.past_key_values), config.n_layer)
def create_and_check_gpt_bigcode_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
def create_and_check_gpt_bigcode_model_past(self, config, input_ids, input_mask, token_type_ids, *args):
model = GPTBigCodeModel(config=config)
model.to(torch_device)
model.eval()
@@ -223,7 +220,7 @@ class GPTBigCodeModelTester:
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_gpt_bigcode_model_attention_mask_past(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
self, config, input_ids, input_mask, token_type_ids, *args
):
model = GPTBigCodeModel(config=config)
model.to(torch_device)
@@ -265,7 +262,7 @@ class GPTBigCodeModelTester:
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_gpt_bigcode_model_past_large_inputs(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
self, config, input_ids, input_mask, token_type_ids, *args
):
model = GPTBigCodeModel(config=config)
model.to(torch_device)
@@ -302,7 +299,7 @@ class GPTBigCodeModelTester:
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
def create_and_check_lm_head_model(self, config, input_ids, input_mask, token_type_ids, *args):
model = GPTBigCodeForCausalLM(config)
model.to(torch_device)
model.eval()
@@ -312,7 +309,7 @@ class GPTBigCodeModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
self, config, input_ids, input_mask, token_type_ids, *args, gradient_checkpointing=False
):
model = GPTBigCodeForCausalLM(config)
model.to(torch_device)
@@ -325,7 +322,7 @@ class GPTBigCodeModelTester:
result.loss.backward()
def create_and_check_gpt_bigcode_for_sequence_classification(
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
self, config, input_ids, input_mask, token_type_ids, mc_token_ids, sequence_labels, *args
):
config.num_labels = self.num_labels
model = GPTBigCodeForSequenceClassification(config)
@@ -335,7 +332,7 @@ class GPTBigCodeModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_gpt_bigcode_for_token_classification(
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
self, config, input_ids, input_mask, token_type_ids, mc_token_ids, sequence_labels, *args
):
config.num_labels = self.num_labels
model = GPTBigCodeForTokenClassification(config)
@@ -359,7 +356,6 @@ class GPTBigCodeModelTester:
config,
input_ids,
input_mask,
head_mask,
token_type_ids,
mc_token_ids,
sequence_labels,
@@ -370,7 +366,6 @@ class GPTBigCodeModelTester:
inputs_dict = {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"head_mask": head_mask,
}
return config, inputs_dict

View File

@@ -51,28 +51,16 @@ def prepare_m2m_100_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
@@ -166,10 +154,9 @@ class M2M100ModelTester:
model = M2M100Model(config=config).get_decoder().to(torch_device).eval()
input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict["attention_mask"]
head_mask = inputs_dict["head_mask"]
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()

View File

@@ -55,28 +55,16 @@ def prepare_mbart_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
@@ -158,10 +146,9 @@ class MBartModelTester:
model = MBartModel(config=config).get_decoder().to(torch_device).eval()
input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict["attention_mask"]
head_mask = inputs_dict["head_mask"]
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()

View File

@@ -107,11 +107,10 @@ class TFMBartModelTester:
input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
head_mask = inputs_dict["head_mask"]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
@@ -123,9 +122,6 @@ def prepare_mbart_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -137,20 +133,11 @@ def prepare_mbart_inputs_dict(
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}

View File

@@ -76,27 +76,19 @@ def prepare_musicgen_decoder_inputs_dict(
config,
input_ids,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1])[:, 0, :]
attention_mask = attention_mask.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device)
if encoder_attention_mask is None and encoder_hidden_states is not None:
encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
"head_mask": head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
@@ -467,9 +459,6 @@ def prepare_musicgen_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
labels=None,
):
if decoder_attention_mask is None:
@@ -477,26 +466,11 @@ def prepare_musicgen_inputs_dict(
-1, config.decoder.num_codebooks, decoder_input_ids.shape[-1]
)[:, 0, :]
decoder_attention_mask = decoder_attention_mask.ne(config.decoder.pad_token_id)
if head_mask is None:
head_mask = torch.ones(
config.text_encoder.num_hidden_layers, config.text_encoder.num_attention_heads, device=torch_device
)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(
config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device
)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(
config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device
)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"labels": labels,
}

View File

@@ -80,15 +80,12 @@ def prepare_musicgen_melody_decoder_inputs_dict(
config,
input_ids,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1])[:, 0, :]
attention_mask = attention_mask.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device)
if encoder_attention_mask is None and encoder_hidden_states is not None:
encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=torch_device)
return {
@@ -96,7 +93,6 @@ def prepare_musicgen_melody_decoder_inputs_dict(
"attention_mask": attention_mask,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
"head_mask": head_mask,
}
@@ -475,8 +471,6 @@ def prepare_musicgen_melody_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
labels=None,
):
if decoder_attention_mask is None:
@@ -484,21 +478,11 @@ def prepare_musicgen_melody_inputs_dict(
-1, config.decoder.num_codebooks, decoder_input_ids.shape[-1]
)[:, 0, :]
decoder_attention_mask = decoder_attention_mask.ne(config.decoder.pad_token_id)
if head_mask is None:
head_mask = torch.ones(
config.text_encoder.num_hidden_layers, config.text_encoder.num_attention_heads, device=torch_device
)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(
config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device
)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"labels": labels,
}

View File

@@ -52,15 +52,12 @@ def prepare_opt_inputs_dict(
decoder_input_ids=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
}
@@ -156,10 +153,9 @@ class OPTModelTester:
input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict["attention_mask"]
head_mask = inputs_dict["head_mask"]
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
@@ -187,7 +183,7 @@ class OPTModelTester:
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
# test no attention_mask works
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
_, past_key_values = outputs.to_tuple()
output_from_no_past = model(next_input_ids)["last_hidden_state"]

View File

@@ -62,25 +62,13 @@ def prepare_whisper_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if decoder_attention_mask is None:
decoder_attention_mask = tf.where(decoder_input_ids != config.pad_token_id, 1, 0)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_features": input_features,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
@@ -350,9 +338,6 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
def test_training(self):
pass
def test_generate_with_head_masking(self):
pass
@unittest.skip("fp16 is not yet supported for TF models")
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()

View File

@@ -159,26 +159,14 @@ def prepare_whisper_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
# "input_ids": input_features,
"input_features": input_features,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
@@ -3235,12 +3223,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
self.assertTrue((eager_generated_ids[permutation_idx, :] == static_generated_ids).all())
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
return {"input_features": input_features, "head_mask": head_mask}
@require_torch
class WhisperEncoderModelTester:
def __init__(
@@ -3314,10 +3296,7 @@ class WhisperEncoderModelTester:
input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length])
config = self.get_config()
inputs_dict = prepare_whisper_encoder_inputs_dict(
config,
input_features=input_features,
)
inputs_dict = {"input_features": input_features}
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
@@ -3427,8 +3406,6 @@ class WhisperEncoderModelTest(ModelTesterMixin, unittest.TestCase):
encoder_inputs = {"input_features": inputs["input_features"]}
del inputs["input_features"]
if "head_mask" in inputs:
encoder_inputs["head_mask"] = inputs["head_mask"]
if "attention_mask" in inputs:
encoder_inputs["attention_mask"] = inputs["attention_mask"]
if "output_attentions" in inputs:
@@ -3523,9 +3500,6 @@ class WhisperStandaloneDecoderModelTester:
)
inputs_dict.pop("input_features")
inputs_dict.pop("head_mask")
inputs_dict.pop("decoder_head_mask")
inputs_dict.pop("cross_attn_head_mask")
inputs_dict["attention_mask"] = inputs_dict.pop("decoder_attention_mask")
inputs_dict["input_ids"] = inputs_dict.pop("decoder_input_ids")

View File

@@ -1444,6 +1444,7 @@ class ModelTesterMixin:
inputs_dict["output_attentions"] = True
config.output_hidden_states = True
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init._attn_implementation = "eager" # head mask works only in eager mode and will be removed soon
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)