Fix mask slicing for models with HybridCache (#35681)

* correctly slice

* check mask

* Update modular_gemma2.py

* fix

* add tests

* fix typo

* finally fix mask slicing

* Finally correctly slice in all cases!!

* add test for all attention functions

* small fix in tests

* trick around dynamo tracing issue

* last update

* more robust

* kwargs propagation

* make it explicit for checkpointing

* apply modular
This commit is contained in:
Cyril Vallez
2025-01-28 14:35:00 +01:00
committed by GitHub
parent b764c20b09
commit 3f860dba55
6 changed files with 232 additions and 22 deletions

View File

@@ -324,3 +324,36 @@ class Cohere2IntegrationTest(unittest.TestCase):
)
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)])
@require_read_token
def test_generation_beyond_sliding_window(self, attn_implementation: str):
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
Outputs for every attention functions should be coherent and identical.
"""
model_id = "CohereForAI/c4ai-command-r7b-12-2024"
EXPECTED_COMPLETIONS = [
" the mountains, the lakes, the rivers, the waterfalls, the waterfalls, the waterfalls, the waterfalls",
", green, yellow, orange, purple, pink, brown, black, white, grey, silver",
]
input_text = [
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
"A list of colors: red, blue", # This will almost all be padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
).to(torch_device)
# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)
out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
output_text = tokenizer.batch_decode(out)
self.assertEqual(output_text, EXPECTED_COMPLETIONS)

View File

@@ -394,3 +394,36 @@ class Gemma2IntegrationTest(unittest.TestCase):
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
self.assertEqual(output_text, EXPECTED_TEXTS)
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)])
@require_read_token
def test_generation_beyond_sliding_window(self, attn_implementation: str):
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
Outputs for every attention functions should be coherent and identical.
"""
model_id = "google/gemma-2-2b"
EXPECTED_COMPLETIONS = [
" the people, the food, the culture, the history, the music, the art, the architecture",
", green, yellow, orange, purple, pink, brown, black, white, gray, silver",
]
input_text = [
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
"A list of colors: red, blue", # This will almost all be padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
).to(torch_device)
# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)
out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
output_text = tokenizer.batch_decode(out)
self.assertEqual(output_text, EXPECTED_COMPLETIONS)