[GPTNeoX] Flex Attention + Refactor (#34896)
* gpt neox flex attention + refactor * some formatting * small fix on dropout * add assertion on flex attn test * flaky ci :( * add head mask support * style * handle dtype, replace torch where * fixup flex with output attns * code review and several other fixes * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * style * remove unnecessary comment * remove incorrect comment * make flex attn check more agnostic tor versions and centralized * change peft input dtype check to value since q and k could be affected by other stuff like RoPE * i forgor * flaky * code review and small fixes * Update src/transformers/models/gpt_neox/modeling_gpt_neox.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -459,6 +459,31 @@ class GPTNeoXLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(output_str, expected_output)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_flex_attn_gptneox(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
|
||||
for checkpointing in [True, False]:
|
||||
model = GPTNeoXForCausalLM.from_pretrained(
|
||||
"EleutherAI/pythia-410m-deduped", attn_implementation="flex_attention"
|
||||
)
|
||||
self.assertTrue(model.config._attn_implementation == "flex_attention")
|
||||
|
||||
if checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
else:
|
||||
model.gradient_checkpointing_disable()
|
||||
model.to(torch_device)
|
||||
|
||||
inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)
|
||||
# The hub repo. is updated on 2023-04-04, resulting in poor outputs.
|
||||
# See: https://github.com/huggingface/transformers/pull/24193
|
||||
expected_output = "My favorite food is a good old-fashioned, old-fashioned, old-fashioned.\n\nI'm not sure"
|
||||
|
||||
output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
|
||||
output_str = tokenizer.batch_decode(output_ids)[0]
|
||||
|
||||
self.assertEqual(output_str, expected_output)
|
||||
|
||||
def pythia_integration_test(self):
|
||||
model_name_or_path = "EleutherAI/pythia-70m"
|
||||
model = GPTNeoXForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16).to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user