>3-5x faster torch.compile forward compilation for autoregressive decoder models (#32227)
* draft * apply changes to all relevant archs * rerun ci - check_docstrings.py failing? * fix docstring * move 2D->4D mask creation to modeling file * repo consistency * fix the batch size = 1 case - calling contiguous is not enough * nit * style * propagate to gemma/gemma-2 * prepare inputs for gemma generation * implement test and tiny fix in gemma2 * Update src/transformers/models/bloom/modeling_bloom.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix copies * ci pass * fix gemma's test_compile_static_cache tests * flacky * retrigger ci --------- Co-authored-by: sanchit-gandhi <sanchit@huggingface.co> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -816,7 +816,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
# Dynamic Cache
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
|
||||
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text) # Both GPU architectures have the same output
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) # Both GPU architectures have the same output
|
||||
|
||||
# Static Cache
|
||||
generated_ids = model.generate(
|
||||
|
||||
Reference in New Issue
Block a user