[GPTNeoX] Flex Attention + Refactor (#34896)

* gpt neox flex attention + refactor

* some formatting

* small fix on dropout

* add assertion on flex attn test

* flaky ci :(

* add head mask support

* style

* handle dtype, replace torch where

* fixup flex with output attns

* code review and several other fixes

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* style

* remove unnecessary comment

* remove incorrect comment

* make flex attn check more agnostic to versions and centralized

* change peft input dtype check to value since q and k could be affected by other stuff like RoPE

* i forgor

* flaky

* code review and small fixes

* Update src/transformers/models/gpt_neox/modeling_gpt_neox.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Anton Vlasjuk
2024-12-04 14:48:28 +01:00
committed by GitHub
parent accb7204f9
commit 46df859975
6 changed files with 372 additions and 250 deletions

View File

@@ -459,6 +459,31 @@ class GPTNeoXLanguageGenerationTest(unittest.TestCase):
self.assertEqual(output_str, expected_output)
@slow
def test_lm_generate_flex_attn_gptneox(self):
    """Check greedy generation of pythia-410m-deduped when loaded with flex attention.

    The same prompt is generated once with gradient checkpointing enabled and
    once with it disabled; both runs must decode to the same expected string.
    """
    checkpoint = "EleutherAI/pythia-410m-deduped"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Neither the prompt nor the reference text depends on the checkpointing
    # mode, so build them once instead of on every iteration.
    inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)
    # The hub repo was updated on 2023-04-04, resulting in poor outputs.
    # See: https://github.com/huggingface/transformers/pull/24193
    expected_output = "My favorite food is a good old-fashioned, old-fashioned, old-fashioned.\n\nI'm not sure"

    for checkpointing in [True, False]:
        # subTest labels a failure with the checkpointing mode that produced it.
        with self.subTest(checkpointing=checkpointing):
            # A fresh model is loaded per mode, as in the original test.
            model = GPTNeoXForCausalLM.from_pretrained(checkpoint, attn_implementation="flex_attention")
            # assertEqual reports the actual implementation on mismatch,
            # unlike assertTrue(a == b) which only reports "False is not true".
            self.assertEqual(model.config._attn_implementation, "flex_attention")

            if checkpointing:
                model.gradient_checkpointing_enable()
            else:
                model.gradient_checkpointing_disable()
            model.to(torch_device)

            output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
            output_str = tokenizer.batch_decode(output_ids)[0]
            self.assertEqual(output_str, expected_output)
def pythia_integration_test(self):
model_name_or_path = "EleutherAI/pythia-70m"
model = GPTNeoXForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16).to(torch_device)