From 166e823f770477b17988020b2476a796d49836a6 Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com> Date: Fri, 20 Jun 2025 18:36:57 +0200 Subject: [PATCH] Fix custom generate from local directory (#38916) Fix custom generate from local directory: 1. Create parent dirs before copying files (custom_generate dir) 2. Correctly copy relative imports to the submodule file. 3. Update docs. --- docs/source/en/generation_strategies.md | 10 +++++++++- src/transformers/dynamic_module_utils.py | 3 ++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md index 9e2cbf485c..6453669f68 100644 --- a/docs/source/en/generation_strategies.md +++ b/docs/source/en/generation_strategies.md @@ -468,9 +468,17 @@ def generate(model, input_ids, generation_config=None, left_padding=None, **kwar Follow the recommended practices below to ensure your custom decoding method works as expected. - Feel free to reuse the logic for validation and input preparation in the original [`~GenerationMixin.generate`]. - Pin the `transformers` version in the requirements if you use any private method/attribute in `model`. -- You can add other files in the `custom_generate` folder, and use relative imports. - Consider adding model validation, input validation, or even a separate test file to help users sanity-check your code in their environment. +Your custom `generate` method can use relative imports to load code from the `custom_generate` folder. For example, if you have a `utils.py` file, you can import it like this: + +```py +from .utils import some_function +``` + +Only relative imports from the same-level `custom_generate` folder are supported. Parent/sibling folder imports are not valid. The `custom_generate` argument also works locally with any directory that contains a `custom_generate` structure. This is the recommended workflow for developing your custom decoding method. 
+ + #### requirements.txt You can optionally specify additional Python requirements in a `requirements.txt` file inside the `custom_generate` folder. These are checked at runtime and an exception will be thrown if they're missing, nudging users to update their environment accordingly. diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index 6a88859e0a..7a498721a9 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -402,10 +402,11 @@ def get_cached_module_file( if not (submodule_path / module_file).exists() or not filecmp.cmp( resolved_module_file, str(submodule_path / module_file) ): + (submodule_path / module_file).parent.mkdir(parents=True, exist_ok=True) shutil.copy(resolved_module_file, submodule_path / module_file) importlib.invalidate_caches() for module_needed in modules_needed: - module_needed = f"{module_needed}.py" + module_needed = Path(module_file).parent / f"{module_needed}.py" module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed) if not (submodule_path / module_needed).exists() or not filecmp.cmp( module_needed_file, str(submodule_path / module_needed)