[generate] skip compilation on cpu offload (#37709)
* skip compilation on cpu offload * add test * better logic * docstring * boolean logic * add disk offload check * warn users if compilation options are set but compilation doesn happen * fix test --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -5262,7 +5262,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
def loss_function(self, value):
|
||||
self._loss_function = value
|
||||
|
||||
def get_compiled_call(self, compile_config: CompileConfig):
|
||||
def get_compiled_call(self, compile_config: Optional[CompileConfig]) -> Callable:
|
||||
"""Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
|
||||
non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
|
||||
want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
|
||||
@@ -5270,7 +5270,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# Only reset it if not present or different from previous config
|
||||
if "llama4" in self.config.model_type: # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
|
||||
return self.__call__
|
||||
default_config = getattr(self.generation_config, "compile_config", CompileConfig())
|
||||
default_config = getattr(self.generation_config, "compile_config", None) or CompileConfig()
|
||||
if (
|
||||
not hasattr(self, "_compiled_call")
|
||||
or getattr(self, "_last_compile_config", default_config) != compile_config
|
||||
|
||||
Reference in New Issue
Block a user