Fix mask creations of GPTNeoX and GPT2 (#31944)
* fix mask creation of gpt2 and gpt_neox caused by me * forgot the reshape of masks when shape > 2 * add tests for gpt neox and gpt2 * nit on a comment
This commit is contained in:
@@ -219,6 +219,36 @@ class GPTNeoXModelTester:
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_cached_forward_with_and_without_attention_mask(self, config, input_ids, *args):
|
||||
# Relevant issue: https://github.com/huggingface/transformers/issues/31943
|
||||
model = GPTNeoXModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# We want this for SDPA, eager works with a `None` attention mask
|
||||
assert (
|
||||
model.config._attn_implementation == "sdpa"
|
||||
), "This test assumes the model to have the SDPA implementation for its attention calculations."
|
||||
|
||||
# Prepare cache and non_cache input, needs a full attention mask
|
||||
cached_len = input_ids.shape[-1] // 2
|
||||
input_mask = torch.ones(size=input_ids.size()).to(torch_device)
|
||||
cache_inputs = {"input_ids": input_ids[:, :cached_len], "attention_mask": input_mask[:, :cached_len]}
|
||||
non_cache_inputs = {"input_ids": input_ids[:, cached_len:], "attention_mask": input_mask}
|
||||
|
||||
# Cached forward once with the attention mask provided and the other time without it (which should assume full attention)
|
||||
cache_outputs = model(**cache_inputs)
|
||||
full_outputs_with_attention_mask = model(
|
||||
**non_cache_inputs, past_key_values=cache_outputs.past_key_values
|
||||
).last_hidden_state
|
||||
full_outputs_without_attention_mask = model(
|
||||
non_cache_inputs["input_ids"], past_key_values=cache_outputs.past_key_values
|
||||
).last_hidden_state
|
||||
|
||||
self.parent.assertTrue(
|
||||
torch.allclose(full_outputs_with_attention_mask, full_outputs_without_attention_mask, atol=1e-5)
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, input_mask, token_labels = config_and_inputs
|
||||
@@ -300,6 +330,10 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_cached_forward_with_and_without_attention_mask(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_cached_forward_with_and_without_attention_mask(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Feed forward chunking is not implemented")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user