Fix mask creations of GPTNeoX and GPT2 (#31944)
* fix mask creation of gpt2 and gpt_neox caused by me * forgot the reshape of masks when shape > 2 * add tests for gpt neox and gpt2 * nit on a comment
This commit is contained in:
@@ -426,6 +426,36 @@ class GPT2ModelTester:
|
||||
self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
|
||||
self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)
|
||||
|
||||
def create_and_check_cached_forward_with_and_without_attention_mask(self, config, input_ids, *args):
|
||||
# Relevant issue: https://github.com/huggingface/transformers/issues/31943
|
||||
model = GPT2Model(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# We want this for SDPA, eager works with a `None` attention mask
|
||||
assert (
|
||||
model.config._attn_implementation == "sdpa"
|
||||
), "This test assumes the model to have the SDPA implementation for its attention calculations."
|
||||
|
||||
# Prepare cache and non_cache input, needs a full attention mask
|
||||
cached_len = input_ids.shape[-1] // 2
|
||||
input_mask = torch.ones(size=input_ids.size()).to(torch_device)
|
||||
cache_inputs = {"input_ids": input_ids[:, :cached_len], "attention_mask": input_mask[:, :cached_len]}
|
||||
non_cache_inputs = {"input_ids": input_ids[:, cached_len:], "attention_mask": input_mask}
|
||||
|
||||
# Cached forward once with the attention mask provided and the other time without it (which should assume full attention)
|
||||
cache_outputs = model(**cache_inputs)
|
||||
full_outputs_with_attention_mask = model(
|
||||
**non_cache_inputs, past_key_values=cache_outputs.past_key_values
|
||||
).last_hidden_state
|
||||
full_outputs_without_attention_mask = model(
|
||||
non_cache_inputs["input_ids"], past_key_values=cache_outputs.past_key_values
|
||||
).last_hidden_state
|
||||
|
||||
self.parent.assertTrue(
|
||||
torch.allclose(full_outputs_with_attention_mask, full_outputs_without_attention_mask, atol=1e-5)
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
||||
@@ -570,6 +600,10 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt2_weight_initialization(*config_and_inputs)
|
||||
|
||||
def test_cached_forward_with_and_without_attention_mask(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_cached_forward_with_and_without_attention_mask(*config_and_inputs)
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user