Remove head mask in generative models (#35786)

* just squash into one commit
* delete print
This commit is contained in:
committed by
GitHub
parent
0173a99e73
commit
955e61b0da
@@ -128,13 +128,10 @@ class GPTBigCodeModelTester:
|
||||
reorder_and_upcast_attn=reorder_and_upcast_attn,
|
||||
)
|
||||
|
||||
-head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
-head_mask,
|
||||
token_type_ids,
|
||||
mc_token_ids,
|
||||
sequence_labels,
|
||||
@@ -174,19 +171,19 @@ class GPTBigCodeModelTester:
|
||||
config.vocab_size = 300
|
||||
return config
|
||||
|
||||
-def create_and_check_gpt_bigcode_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
+def create_and_check_gpt_bigcode_model(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
model = GPTBigCodeModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
-result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
|
||||
+result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(len(result.past_key_values), config.n_layer)
|
||||
|
||||
-def create_and_check_gpt_bigcode_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
+def create_and_check_gpt_bigcode_model_past(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
model = GPTBigCodeModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -223,7 +220,7 @@ class GPTBigCodeModelTester:
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_gpt_bigcode_model_attention_mask_past(
|
||||
-self, config, input_ids, input_mask, head_mask, token_type_ids, *args
|
||||
+self, config, input_ids, input_mask, token_type_ids, *args
|
||||
):
|
||||
model = GPTBigCodeModel(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -265,7 +262,7 @@ class GPTBigCodeModelTester:
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_gpt_bigcode_model_past_large_inputs(
|
||||
-self, config, input_ids, input_mask, head_mask, token_type_ids, *args
|
||||
+self, config, input_ids, input_mask, token_type_ids, *args
|
||||
):
|
||||
model = GPTBigCodeModel(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -302,7 +299,7 @@ class GPTBigCodeModelTester:
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
-def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
+def create_and_check_lm_head_model(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
model = GPTBigCodeForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -312,7 +309,7 @@ class GPTBigCodeModelTester:
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_forward_and_backwards(
|
||||
-self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
|
||||
+self, config, input_ids, input_mask, token_type_ids, *args, gradient_checkpointing=False
|
||||
):
|
||||
model = GPTBigCodeForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
@@ -325,7 +322,7 @@ class GPTBigCodeModelTester:
|
||||
result.loss.backward()
|
||||
|
||||
def create_and_check_gpt_bigcode_for_sequence_classification(
|
||||
-self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
+self, config, input_ids, input_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = GPTBigCodeForSequenceClassification(config)
|
||||
@@ -335,7 +332,7 @@ class GPTBigCodeModelTester:
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_gpt_bigcode_for_token_classification(
|
||||
-self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
+self, config, input_ids, input_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = GPTBigCodeForTokenClassification(config)
|
||||
@@ -359,7 +356,6 @@ class GPTBigCodeModelTester:
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
-head_mask,
|
||||
token_type_ids,
|
||||
mc_token_ids,
|
||||
sequence_labels,
|
||||
@@ -370,7 +366,6 @@ class GPTBigCodeModelTester:
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
-"head_mask": head_mask,
|
||||
}
|
||||
|
||||
return config, inputs_dict
|
||||
|
||||
Reference in New Issue
Block a user