Fix mask creations of GPTNeoX and GPT2 (#31944)

* fix mask creation of gpt2 and gpt_neox caused by me

* forgot the reshape of masks when shape > 2

* add tests for gpt neox and gpt2

* nit on a comment
This commit is contained in:
Anton Vlasjuk
2024-07-23 10:11:12 +02:00
committed by GitHub
parent 2782aadae2
commit 605f3245dc
4 changed files with 97 additions and 31 deletions

View File

@@ -219,6 +219,36 @@ class GPTNeoXModelTester:
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_cached_forward_with_and_without_attention_mask(self, config, input_ids, *args):
# Relevant issue: https://github.com/huggingface/transformers/issues/31943
model = GPTNeoXModel(config)
model.to(torch_device)
model.eval()
# We want this for SDPA, eager works with a `None` attention mask
assert (
model.config._attn_implementation == "sdpa"
), "This test assumes the model to have the SDPA implementation for its attention calculations."
# Prepare cache and non_cache input, needs a full attention mask
cached_len = input_ids.shape[-1] // 2
input_mask = torch.ones(size=input_ids.size()).to(torch_device)
cache_inputs = {"input_ids": input_ids[:, :cached_len], "attention_mask": input_mask[:, :cached_len]}
non_cache_inputs = {"input_ids": input_ids[:, cached_len:], "attention_mask": input_mask}
# Cached forward once with the attention mask provided and the other time without it (which should assume full attention)
cache_outputs = model(**cache_inputs)
full_outputs_with_attention_mask = model(
**non_cache_inputs, past_key_values=cache_outputs.past_key_values
).last_hidden_state
full_outputs_without_attention_mask = model(
non_cache_inputs["input_ids"], past_key_values=cache_outputs.past_key_values
).last_hidden_state
self.parent.assertTrue(
torch.allclose(full_outputs_with_attention_mask, full_outputs_without_attention_mask, atol=1e-5)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask, token_labels = config_and_inputs
@@ -300,6 +330,10 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_cached_forward_with_and_without_attention_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_cached_forward_with_and_without_attention_mask(*config_and_inputs)
@unittest.skip(reason="Feed forward chunking is not implemented")
def test_feed_forward_chunking(self):
pass