[generate] skip compilation on cpu offload (#37709)

* skip compilation on cpu offload

* add test

* better logic

* docstring

* boolean logic

* add disk offload check

* warn users if compilation options are set but compilation doesn happen

* fix test

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Joao Gante
2025-04-24 14:08:17 +01:00
committed by GitHub
parent 7c62e69326
commit 8bdd4f2acd
4 changed files with 91 additions and 21 deletions

View File

@@ -5262,7 +5262,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
def loss_function(self, value):
self._loss_function = value
def get_compiled_call(self, compile_config: CompileConfig):
def get_compiled_call(self, compile_config: Optional[CompileConfig]) -> Callable:
"""Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
@@ -5270,7 +5270,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Only reset it if not present or different from previous config
if "llama4" in self.config.model_type: # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
return self.__call__
default_config = getattr(self.generation_config, "compile_config", CompileConfig())
default_config = getattr(self.generation_config, "compile_config", None) or CompileConfig()
if (
not hasattr(self, "_compiled_call")
or getattr(self, "_last_compile_config", default_config) != compile_config