diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md index c6cb322e88..9e2cbf485c 100644 --- a/docs/source/en/generation_strategies.md +++ b/docs/source/en/generation_strategies.md @@ -327,7 +327,6 @@ We enable custom decoding methods through model repositories, assuming a specifi If a model repository holds a custom decoding method, the easiest way to try it out is to load the model and generate with it: - ```py from transformers import AutoModelForCausalLM, AutoTokenizer @@ -430,7 +429,7 @@ This is the core of your decoding method. It *must* contain a method named `gene > [!WARNING] > `generate.py` must be placed in a folder named `custom_generate`, and not at the root level of the repository. The file paths for this feature are hardcoded. -Under the hood, when the base [`~GenerationMixin.generate`] method is called with a `custom_generate` argument, it first checks its Python requirements (if any), then locates the custom `generate` method in `generate.py`, and finally calls the custom `generate`. All received arguments and `model` are forwarded to your custom `generate` method. +Under the hood, when the base [`~GenerationMixin.generate`] method is called with a `custom_generate` argument, it first checks its Python requirements (if any), then locates the custom `generate` method in `generate.py`, and finally calls the custom `generate`. All received arguments and `model` are forwarded to your custom `generate` method, with the exception of the arguments used to trigger the custom generation (`trust_remote_code` and `custom_generate`). This means your `generate` can have a mix of original and custom arguments (as well as a different output type) as shown below. diff --git a/docs/source/en/llm_tutorial.md b/docs/source/en/llm_tutorial.md index a191cdb463..1283e8b6a4 100644 --- a/docs/source/en/llm_tutorial.md +++ b/docs/source/en/llm_tutorial.md @@ -84,14 +84,17 @@ GenerationConfig { } ``` -You can customize [`~GenerationMixin.generate`] by overriding the parameters and values in [`GenerationConfig`]. Some of the most commonly adjusted parameters are [max_new_tokens](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.max_new_tokens), [num_beams](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.num_beams), [do_sample](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.do_sample), and [num_return_sequences](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.num_return_sequences). +You can customize [`~GenerationMixin.generate`] by overriding the parameters and values in [`GenerationConfig`]. See [this section below](#common-options) for commonly adjusted parameters. ```py # enable beam search sampling strategy model.generate(**inputs, num_beams=4, do_sample=True) ``` -[`~GenerationMixin.generate`] can also be extended with external libraries or custom code. The `logits_processor` parameter accepts custom [`LogitsProcessor`] instances for manipulating the next token probability distribution. `stopping_criteria` supports custom [`StoppingCriteria`] to stop text generation. Check out the [logits-processor-zoo](https://github.com/NVIDIA/logits-processor-zoo) for more examples of external [`~GenerationMixin.generate`]-compatible extensions. +[`~GenerationMixin.generate`] can also be extended with external libraries or custom code: +1. the `logits_processor` parameter accepts custom [`LogitsProcessor`] instances for manipulating the next token probability distribution; +2. the `stopping_criteria` parameters supports custom [`StoppingCriteria`] to stop text generation; +3. other custom generation methods can be loaded through the `custom_generate` flag ([docs](generation_strategies.md/#custom-decoding-methods)). Refer to the [Generation strategies](./generation_strategies) guide to learn more about search, sampling, and decoding strategies. diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 914069ce6b..8ea1e7e760 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2347,9 +2347,15 @@ class GenerationMixin(ContinuousMixin): if custom_generate is not None: trust_remote_code = kwargs.pop("trust_remote_code", None) # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them: - # they receive the same inputs as `generate`, only with `model` instead of `self`. They can access to - # methods from `GenerationMixin` through `model`. - global_keys_to_exclude = {"self", "kwargs"} + # they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to + # trigger the custom generation. They can access to methods from `GenerationMixin` through `model`. + global_keys_to_exclude = { + "self", + "kwargs", + "global_keys_to_exclude", + "trust_remote_code", + "custom_generate", + } generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude} generate_arguments.update(kwargs)