From 4d64c3859308a79dc7d9b8cafa2be039a77b2267 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 24 Apr 2025 15:08:38 +0100 Subject: [PATCH] [generate] fix default autocompile case on gpu (#37756) --- src/transformers/generation/utils.py | 3 ++- src/transformers/modeling_utils.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6dc9acf8a0..2ae16408b8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3430,7 +3430,8 @@ class GenerationMixin: model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) model_forward = self.__call__ - if self._valid_auto_compile_criteria(model_kwargs, generation_config): + compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config) + if compile_forward: os.environ["TOKENIZERS_PARALLELISM"] = "0" model_forward = self.get_compiled_call(generation_config.compile_config) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index adac0890e6..e73862b54e 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5270,6 +5270,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # Only reset it if not present or different from previous config if "llama4" in self.config.model_type: # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT return self.__call__ + compile_config = compile_config or CompileConfig() default_config = getattr(self.generation_config, "compile_config", None) or CompileConfig() if ( not hasattr(self, "_compiled_call")