From 858545047c05a35fde437b2ada3a901844cd1e60 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 10 Mar 2025 09:24:26 +0000 Subject: [PATCH] [`HybridCache`] disable automatic compilation (#36620) --- src/transformers/cache_utils.py | 4 +++- tests/generation/test_utils.py | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 6a98b1b2ff..94a68bf0df 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1602,7 +1602,9 @@ class HybridCache(Cache): ``` """ - is_compileable = True + # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert + # ALL changes from the PR that commented the line below when reactivating it. + # is_compileable = True # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. @deprecate_kwarg("layer_device_map", version="4.52.0") diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index f66dca2125..3a4171161f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2118,6 +2118,9 @@ class GenerationTesterMixin: Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️ """ + # Monkey-patching the HybridCache at test-time to continue testing compilation support + HybridCache.is_compileable = True + for model_class in self.all_generative_model_classes: if not model_class._supports_static_cache: self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") @@ -2214,6 +2217,9 @@ class GenerationTesterMixin: Tests that all optional outputs are behaving as expected when compilation is triggered. In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered. """ + # Monkey-patching the HybridCache at test-time to continue testing compilation support + HybridCache.is_compileable = True + for model_class in self.all_generative_model_classes: if not model_class._supports_static_cache: self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")