Fix mask slicing for models with HybridCache (#35681)
* correctly slice * check mask * Update modular_gemma2.py * fix * add tests * fix typo * finally fix mask slicing * Finally correctly slice in all cases!! * add test for all attention functions * small fix in tests * trick around dynamo tracing issue * last update * more robust * kwargs propagation * make it explicit for checkpointing * apply modular
This commit is contained in:
@@ -324,3 +324,36 @@ class Cohere2IntegrationTest(unittest.TestCase):
|
||||
)
|
||||
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
|
||||
|
||||
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)])
@require_read_token
def test_generation_beyond_sliding_window(self, attn_implementation: str):
    """Test that we can correctly generate beyond the sliding window.

    This is non-trivial, as we need to correctly slice the attention mask in all cases
    (because we use a HybridCache). Outputs for every attention function should be
    coherent and identical.

    Args:
        attn_implementation: Name of the attention implementation to load the model with;
            supplied by `parameterized.expand`.
    """
    model_id = "CohereForAI/c4ai-command-r7b-12-2024"
    EXPECTED_COMPLETIONS = [
        " the mountains, the lakes, the rivers, the waterfalls, the waterfalls, the waterfalls, the waterfalls",
        ", green, yellow, orange, purple, pink, brown, black, white, grey, silver",
    ]

    input_text = [
        "This is a nice place. " * 800 + "I really enjoy the scenery,",  # This is larger than 4096 tokens
        "A list of colors: red, blue",  # Once batched, this row is almost entirely padding tokens
    ]
    # `padding_side` is the tokenizer argument that selects left-padding; a bare `padding="left"`
    # here would not configure the padding side and would silently rely on the checkpoint default.
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)

    model = AutoModelForCausalLM.from_pretrained(
        model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
    ).to(torch_device)

    # Make sure prefill is larger than sliding window
    input_size = inputs.input_ids.shape[-1]
    self.assertGreater(input_size, model.config.sliding_window)

    # Drop the prompt columns so that only the newly generated tokens are decoded and compared
    out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
    output_text = tokenizer.batch_decode(out)

    self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
||||
|
||||
@@ -394,3 +394,36 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)])
@require_read_token
def test_generation_beyond_sliding_window(self, attn_implementation: str):
    """Test that we can correctly generate beyond the sliding window.

    This is non-trivial, as we need to correctly slice the attention mask in all cases
    (because we use a HybridCache). Outputs for every attention function should be
    coherent and identical.

    Args:
        attn_implementation: Name of the attention implementation to load the model with;
            supplied by `parameterized.expand`.
    """
    model_id = "google/gemma-2-2b"
    EXPECTED_COMPLETIONS = [
        " the people, the food, the culture, the history, the music, the art, the architecture",
        ", green, yellow, orange, purple, pink, brown, black, white, gray, silver",
    ]

    input_text = [
        "This is a nice place. " * 800 + "I really enjoy the scenery,",  # This is larger than 4096 tokens
        "A list of colors: red, blue",  # Once batched, this row is almost entirely padding tokens
    ]
    # `padding_side` is the tokenizer argument that selects left-padding; a bare `padding="left"`
    # here would not configure the padding side and would silently rely on the checkpoint default.
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)

    model = AutoModelForCausalLM.from_pretrained(
        model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
    ).to(torch_device)

    # Make sure prefill is larger than sliding window
    input_size = inputs.input_ids.shape[-1]
    self.assertGreater(input_size, model.config.sliding_window)

    # Drop the prompt columns so that only the newly generated tokens are decoded and compared
    out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
    output_text = tokenizer.batch_decode(out)

    self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
||||
|
||||
Reference in New Issue
Block a user