[generate] Run custom generation code from the Hub (#36405)
* mvp
* remove trust_remote_code
* generate_from_hub
* handle requirements; docs
* english
* doc PR suggestions
* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* changed remote code path to generate/generate.py
* model repo has custom generate -> override base generate
* check for proper inheritance
* some doc updates (missing: tag-related docs)
* update docs to model repo
* nit
* nit
* nits
* Update src/transformers/dynamic_module_utils.py
* Apply suggestions from code review
* Update docs/source/en/generation_strategies.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* trust remote code is required
* use new import utils for requirements version parsing
* use org examples
* add tests
* Apply suggestions from code review

Co-authored-by: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com>

* ascii file structure; tag instructions on readme.md

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com>
This commit is contained in:
@@ -4104,6 +4104,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
gguf_file = kwargs.pop("gguf_file", None)
|
||||
tp_plan = kwargs.pop("tp_plan", None)
|
||||
tp_size = kwargs.pop("tp_size", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
|
||||
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
|
||||
if any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
|
||||
@@ -4113,7 +4114,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
# Not used anymore -- remove them from the kwargs
|
||||
_ = kwargs.pop("resume_download", None)
|
||||
_ = kwargs.pop("trust_remote_code", None)
|
||||
_ = kwargs.pop("mirror", None)
|
||||
_ = kwargs.pop("_fast_init", True)
|
||||
_ = kwargs.pop("low_cpu_mem_usage", None)
|
||||
@@ -4591,30 +4591,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
|
||||
# If it is a model with generation capabilities, attempt to load the generation config
|
||||
# If it is a model with generation capabilities, attempt to load generation files (generation config,
|
||||
# custom generate function)
|
||||
if model.can_generate() and generation_config is not None:
|
||||
logger.info("The user-defined `generation_config` will be used to override the default generation config.")
|
||||
model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
|
||||
elif model.can_generate() and pretrained_model_name_or_path is not None:
|
||||
repo_loading_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"local_files_only": local_files_only,
|
||||
"token": token,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
**kwargs,
|
||||
}
|
||||
# Load generation config
|
||||
try:
|
||||
model.generation_config = GenerationConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_from_auto=from_auto_class,
|
||||
_from_pipeline=from_pipeline,
|
||||
**kwargs,
|
||||
**repo_loading_kwargs,
|
||||
)
|
||||
except OSError:
|
||||
logger.info(
|
||||
"Generation config file not found, using a generation config created from the model config."
|
||||
)
|
||||
pass
|
||||
# Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
|
||||
if hasattr(model, "load_custom_generate"):
|
||||
try:
|
||||
custom_generate = model.load_custom_generate(
|
||||
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
|
||||
)
|
||||
model.generate = functools.partial(custom_generate, model=model)
|
||||
except OSError: # there is no custom generate function
|
||||
pass
|
||||
|
||||
# Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
|
||||
# harms performance)
|
||||
|
||||
Reference in New Issue
Block a user