[generate] fix default autocompile case on gpu (#37756)
This commit is contained in:
@@ -3430,7 +3430,8 @@ class GenerationMixin:
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
         model_forward = self.__call__
-        if self._valid_auto_compile_criteria(model_kwargs, generation_config):
+        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+        if compile_forward:
             os.environ["TOKENIZERS_PARALLELISM"] = "0"
             model_forward = self.get_compiled_call(generation_config.compile_config)
 
@@ -5270,6 +5270,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         # Only reset it if not present or different from previous config
         if "llama4" in self.config.model_type:  # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
             return self.__call__
+        compile_config = compile_config or CompileConfig()
         default_config = getattr(self.generation_config, "compile_config", None) or CompileConfig()
         if (
             not hasattr(self, "_compiled_call")
Reference in New Issue
Block a user