Remove head mask in generative models (#35786)

* just squash into one commit

* delete print
This commit is contained in:
Raushan Turganbay
2025-05-15 10:44:19 +02:00
committed by GitHub
parent 0173a99e73
commit 955e61b0da
47 changed files with 103 additions and 294 deletions

View File

@@ -128,13 +128,10 @@ class GPTBigCodeModelTester:
reorder_and_upcast_attn=reorder_and_upcast_attn,
)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
return (
config,
input_ids,
input_mask,
head_mask,
token_type_ids,
mc_token_ids,
sequence_labels,
@@ -174,19 +171,19 @@ class GPTBigCodeModelTester:
config.vocab_size = 300
return config
def create_and_check_gpt_bigcode_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
def create_and_check_gpt_bigcode_model(self, config, input_ids, input_mask, token_type_ids, *args):
model = GPTBigCodeModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(len(result.past_key_values), config.n_layer)
def create_and_check_gpt_bigcode_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
def create_and_check_gpt_bigcode_model_past(self, config, input_ids, input_mask, token_type_ids, *args):
model = GPTBigCodeModel(config=config)
model.to(torch_device)
model.eval()
@@ -223,7 +220,7 @@ class GPTBigCodeModelTester:
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_gpt_bigcode_model_attention_mask_past(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
self, config, input_ids, input_mask, token_type_ids, *args
):
model = GPTBigCodeModel(config=config)
model.to(torch_device)
@@ -265,7 +262,7 @@ class GPTBigCodeModelTester:
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_gpt_bigcode_model_past_large_inputs(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
self, config, input_ids, input_mask, token_type_ids, *args
):
model = GPTBigCodeModel(config=config)
model.to(torch_device)
@@ -302,7 +299,7 @@ class GPTBigCodeModelTester:
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
def create_and_check_lm_head_model(self, config, input_ids, input_mask, token_type_ids, *args):
model = GPTBigCodeForCausalLM(config)
model.to(torch_device)
model.eval()
@@ -312,7 +309,7 @@ class GPTBigCodeModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
self, config, input_ids, input_mask, token_type_ids, *args, gradient_checkpointing=False
):
model = GPTBigCodeForCausalLM(config)
model.to(torch_device)
@@ -325,7 +322,7 @@ class GPTBigCodeModelTester:
result.loss.backward()
def create_and_check_gpt_bigcode_for_sequence_classification(
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
self, config, input_ids, input_mask, token_type_ids, mc_token_ids, sequence_labels, *args
):
config.num_labels = self.num_labels
model = GPTBigCodeForSequenceClassification(config)
@@ -335,7 +332,7 @@ class GPTBigCodeModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_gpt_bigcode_for_token_classification(
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
self, config, input_ids, input_mask, token_type_ids, mc_token_ids, sequence_labels, *args
):
config.num_labels = self.num_labels
model = GPTBigCodeForTokenClassification(config)
@@ -359,7 +356,6 @@ class GPTBigCodeModelTester:
config,
input_ids,
input_mask,
head_mask,
token_type_ids,
mc_token_ids,
sequence_labels,
@@ -370,7 +366,6 @@ class GPTBigCodeModelTester:
inputs_dict = {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"head_mask": head_mask,
}
return config, inputs_dict