[generate] Run custom generation code from the Hub (#36405)

* mvp

* remove trust_remote_code

* generate_from_hub

* handle requirements; docs

* english

* doc PR suggestions

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* changed remote code path to generate/generate.py

* model repo has custom generate -> override base generate

* check for proper inheritance

* some doc updates (missing: tag-related docs)

* update docs to model repo

* nit

* nit

* nits

* Update src/transformers/dynamic_module_utils.py

* Apply suggestions from code review

* Update docs/source/en/generation_strategies.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* trust remote code is required

* use new import utils for requirements version parsing

* use  org examples

* add tests

* Apply suggestions from code review

Co-authored-by: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com>

* ascii file structure; tag instructions on readme.md

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com>
This commit is contained in:
Joao Gante
2025-05-15 10:35:54 +01:00
committed by GitHub
parent 955e61b0da
commit 0e0e5c1044
6 changed files with 522 additions and 97 deletions

View File

@@ -4104,6 +4104,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
if any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
@@ -4113,7 +4114,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("resume_download", None)
_ = kwargs.pop("trust_remote_code", None)
_ = kwargs.pop("mirror", None)
_ = kwargs.pop("_fast_init", True)
_ = kwargs.pop("low_cpu_mem_usage", None)
@@ -4591,30 +4591,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
# If it is a model with generation capabilities, attempt to load the generation config
# If it is a model with generation capabilities, attempt to load generation files (generation config,
# custom generate function)
if model.can_generate() and generation_config is not None:
logger.info("The user-defined `generation_config` will be used to override the default generation config.")
model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
elif model.can_generate() and pretrained_model_name_or_path is not None:
repo_loading_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"local_files_only": local_files_only,
"token": token,
"revision": revision,
"subfolder": subfolder,
**kwargs,
}
# Load generation config
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**kwargs,
**repo_loading_kwargs,
)
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
pass
# Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
if hasattr(model, "load_custom_generate"):
try:
custom_generate = model.load_custom_generate(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
)
model.generate = functools.partial(custom_generate, model=model)
except OSError: # there is no custom generate function
pass
# Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
# harm performances)