Remove head mask in generative models (#35786)
* just squash into one commit * delete print
This commit is contained in:
committed by
GitHub
parent
0173a99e73
commit
955e61b0da
@@ -58,28 +58,16 @@ def prepare_bart_inputs_dict(
|
||||
decoder_input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -167,10 +155,9 @@ class BartModelTester:
|
||||
model = BartModel(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
|
||||
@@ -119,11 +119,10 @@ class TFBartModelTester:
|
||||
|
||||
input_ids = input_ids[:1, :]
|
||||
attention_mask = inputs_dict["attention_mask"][:1, :]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
self.batch_size = 1
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
@@ -158,9 +157,6 @@ def prepare_bart_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||
@@ -172,20 +168,11 @@ def prepare_bart_inputs_dict(
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -135,9 +135,7 @@ class BioGptModelTester:
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_biogpt_model_attention_mask_past(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
|
||||
):
|
||||
def create_and_check_biogpt_model_attention_mask_past(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
model = BioGptModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -177,9 +175,7 @@ class BioGptModelTester:
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_biogpt_model_past_large_inputs(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
|
||||
):
|
||||
def create_and_check_biogpt_model_past_large_inputs(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
model = BioGptModel(config=config).to(torch_device).eval()
|
||||
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
@@ -213,7 +209,7 @@ class BioGptModelTester:
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_forward_and_backwards(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
|
||||
self, config, input_ids, input_mask, token_type_ids, *args, gradient_checkpointing=False
|
||||
):
|
||||
model = BioGptForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
@@ -233,9 +229,7 @@ class BioGptModelTester:
|
||||
self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
|
||||
self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)
|
||||
|
||||
def create_and_check_biogpt_for_token_classification(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
|
||||
):
|
||||
def create_and_check_biogpt_for_token_classification(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
config.num_labels = self.num_labels
|
||||
model = BioGptForTokenClassification(config)
|
||||
model.to(torch_device)
|
||||
|
||||
@@ -128,13 +128,10 @@ class GPTBigCodeModelTester:
|
||||
reorder_and_upcast_attn=reorder_and_upcast_attn,
|
||||
)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
head_mask,
|
||||
token_type_ids,
|
||||
mc_token_ids,
|
||||
sequence_labels,
|
||||
@@ -174,19 +171,19 @@ class GPTBigCodeModelTester:
|
||||
config.vocab_size = 300
|
||||
return config
|
||||
|
||||
def create_and_check_gpt_bigcode_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
def create_and_check_gpt_bigcode_model(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
model = GPTBigCodeModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(len(result.past_key_values), config.n_layer)
|
||||
|
||||
def create_and_check_gpt_bigcode_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
def create_and_check_gpt_bigcode_model_past(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
model = GPTBigCodeModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -223,7 +220,7 @@ class GPTBigCodeModelTester:
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_gpt_bigcode_model_attention_mask_past(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
|
||||
self, config, input_ids, input_mask, token_type_ids, *args
|
||||
):
|
||||
model = GPTBigCodeModel(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -265,7 +262,7 @@ class GPTBigCodeModelTester:
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_gpt_bigcode_model_past_large_inputs(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
|
||||
self, config, input_ids, input_mask, token_type_ids, *args
|
||||
):
|
||||
model = GPTBigCodeModel(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -302,7 +299,7 @@ class GPTBigCodeModelTester:
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
def create_and_check_lm_head_model(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
model = GPTBigCodeForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -312,7 +309,7 @@ class GPTBigCodeModelTester:
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_forward_and_backwards(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
|
||||
self, config, input_ids, input_mask, token_type_ids, *args, gradient_checkpointing=False
|
||||
):
|
||||
model = GPTBigCodeForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
@@ -325,7 +322,7 @@ class GPTBigCodeModelTester:
|
||||
result.loss.backward()
|
||||
|
||||
def create_and_check_gpt_bigcode_for_sequence_classification(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
self, config, input_ids, input_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = GPTBigCodeForSequenceClassification(config)
|
||||
@@ -335,7 +332,7 @@ class GPTBigCodeModelTester:
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_gpt_bigcode_for_token_classification(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
self, config, input_ids, input_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = GPTBigCodeForTokenClassification(config)
|
||||
@@ -359,7 +356,6 @@ class GPTBigCodeModelTester:
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
head_mask,
|
||||
token_type_ids,
|
||||
mc_token_ids,
|
||||
sequence_labels,
|
||||
@@ -370,7 +366,6 @@ class GPTBigCodeModelTester:
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"head_mask": head_mask,
|
||||
}
|
||||
|
||||
return config, inputs_dict
|
||||
|
||||
@@ -51,28 +51,16 @@ def prepare_m2m_100_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -166,10 +154,9 @@ class M2M100ModelTester:
|
||||
model = M2M100Model(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
|
||||
@@ -55,28 +55,16 @@ def prepare_mbart_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -158,10 +146,9 @@ class MBartModelTester:
|
||||
model = MBartModel(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
|
||||
@@ -107,11 +107,10 @@ class TFMBartModelTester:
|
||||
|
||||
input_ids = input_ids[:1, :]
|
||||
attention_mask = inputs_dict["attention_mask"][:1, :]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
self.batch_size = 1
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
@@ -123,9 +122,6 @@ def prepare_mbart_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||
@@ -137,20 +133,11 @@ def prepare_mbart_inputs_dict(
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -76,27 +76,19 @@ def prepare_musicgen_decoder_inputs_dict(
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1])[:, 0, :]
|
||||
attention_mask = attention_mask.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device)
|
||||
if encoder_attention_mask is None and encoder_hidden_states is not None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -467,9 +459,6 @@ def prepare_musicgen_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
labels=None,
|
||||
):
|
||||
if decoder_attention_mask is None:
|
||||
@@ -477,26 +466,11 @@ def prepare_musicgen_inputs_dict(
|
||||
-1, config.decoder.num_codebooks, decoder_input_ids.shape[-1]
|
||||
)[:, 0, :]
|
||||
decoder_attention_mask = decoder_attention_mask.ne(config.decoder.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(
|
||||
config.text_encoder.num_hidden_layers, config.text_encoder.num_attention_heads, device=torch_device
|
||||
)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(
|
||||
config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device
|
||||
)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(
|
||||
config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device
|
||||
)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
@@ -80,15 +80,12 @@ def prepare_musicgen_melody_decoder_inputs_dict(
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1])[:, 0, :]
|
||||
attention_mask = attention_mask.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device)
|
||||
if encoder_attention_mask is None and encoder_hidden_states is not None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=torch_device)
|
||||
return {
|
||||
@@ -96,7 +93,6 @@ def prepare_musicgen_melody_decoder_inputs_dict(
|
||||
"attention_mask": attention_mask,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -475,8 +471,6 @@ def prepare_musicgen_melody_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
labels=None,
|
||||
):
|
||||
if decoder_attention_mask is None:
|
||||
@@ -484,21 +478,11 @@ def prepare_musicgen_melody_inputs_dict(
|
||||
-1, config.decoder.num_codebooks, decoder_input_ids.shape[-1]
|
||||
)[:, 0, :]
|
||||
decoder_attention_mask = decoder_attention_mask.ne(config.decoder.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(
|
||||
config.text_encoder.num_hidden_layers, config.text_encoder.num_attention_heads, device=torch_device
|
||||
)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(
|
||||
config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device
|
||||
)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
@@ -52,15 +52,12 @@ def prepare_opt_inputs_dict(
|
||||
decoder_input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -156,10 +153,9 @@ class OPTModelTester:
|
||||
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
@@ -187,7 +183,7 @@ class OPTModelTester:
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
# test no attention_mask works
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
_, past_key_values = outputs.to_tuple()
|
||||
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||
|
||||
|
||||
@@ -62,25 +62,13 @@ def prepare_whisper_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = tf.where(decoder_input_ids != config.pad_token_id, 1, 0)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_features": input_features,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -350,9 +338,6 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("fp16 is not yet supported for TF models")
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
@@ -159,26 +159,14 @@ def prepare_whisper_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
# "input_ids": input_features,
|
||||
"input_features": input_features,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -3235,12 +3223,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue((eager_generated_ids[permutation_idx, :] == static_generated_ids).all())
|
||||
|
||||
|
||||
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
return {"input_features": input_features, "head_mask": head_mask}
|
||||
|
||||
|
||||
@require_torch
|
||||
class WhisperEncoderModelTester:
|
||||
def __init__(
|
||||
@@ -3314,10 +3296,7 @@ class WhisperEncoderModelTester:
|
||||
input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length])
|
||||
|
||||
config = self.get_config()
|
||||
inputs_dict = prepare_whisper_encoder_inputs_dict(
|
||||
config,
|
||||
input_features=input_features,
|
||||
)
|
||||
inputs_dict = {"input_features": input_features}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
@@ -3427,8 +3406,6 @@ class WhisperEncoderModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
encoder_inputs = {"input_features": inputs["input_features"]}
|
||||
del inputs["input_features"]
|
||||
|
||||
if "head_mask" in inputs:
|
||||
encoder_inputs["head_mask"] = inputs["head_mask"]
|
||||
if "attention_mask" in inputs:
|
||||
encoder_inputs["attention_mask"] = inputs["attention_mask"]
|
||||
if "output_attentions" in inputs:
|
||||
@@ -3523,9 +3500,6 @@ class WhisperStandaloneDecoderModelTester:
|
||||
)
|
||||
|
||||
inputs_dict.pop("input_features")
|
||||
inputs_dict.pop("head_mask")
|
||||
inputs_dict.pop("decoder_head_mask")
|
||||
inputs_dict.pop("cross_attn_head_mask")
|
||||
|
||||
inputs_dict["attention_mask"] = inputs_dict.pop("decoder_attention_mask")
|
||||
inputs_dict["input_ids"] = inputs_dict.pop("decoder_input_ids")
|
||||
|
||||
Reference in New Issue
Block a user