[custom_generate] don't forward custom_generate and trust_remote_code (#38304)
* prevent infinite loops
* docs
* more links to custom generation methods
@@ -327,7 +327,6 @@ We enable custom decoding methods through model repositories, assuming a specifi
 
 If a model repository holds a custom decoding method, the easiest way to try it out is to load the model and generate with it:
 
-<!-- TODO before merging: 1) better repo name (use a `generate-community` org?) 2) prettify the repo -->
 ```py
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -430,7 +429,7 @@ This is the core of your decoding method. It *must* contain a method named `gene
 > [!WARNING]
 > `generate.py` must be placed in a folder named `custom_generate`, and not at the root level of the repository. The file paths for this feature are hardcoded.
 
-Under the hood, when the base [`~GenerationMixin.generate`] method is called with a `custom_generate` argument, it first checks its Python requirements (if any), then locates the custom `generate` method in `generate.py`, and finally calls the custom `generate`. All received arguments and `model` are forwarded to your custom `generate` method.
+Under the hood, when the base [`~GenerationMixin.generate`] method is called with a `custom_generate` argument, it first checks its Python requirements (if any), then locates the custom `generate` method in `generate.py`, and finally calls the custom `generate`. All received arguments and `model` are forwarded to your custom `generate` method, with the exception of the arguments used to trigger the custom generation (`trust_remote_code` and `custom_generate`).
 
 This means your `generate` can have a mix of original and custom arguments (as well as a different output type) as shown below.
 
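Reviewer note: to make the forwarding rule in the hunk above concrete, here is a minimal sketch of what a `custom_generate/generate.py` could look like. It is not taken from this commit. `DummyModel`, its `next_token` helper, the pad token id `0`, and the `left_padding` argument are assumptions made for illustration; a real implementation would operate on tensors and a real model.

```python
# Hypothetical contents of `custom_generate/generate.py` (illustrative only).


def generate(model, input_ids, max_new_tokens=3, left_padding=None, **kwargs):
    """Toy greedy loop mixing an original argument with a custom one.

    `max_new_tokens` mirrors an argument of the base `generate`;
    `left_padding` is a custom argument that only this function understands.
    `custom_generate` and `trust_remote_code` are consumed by the base method,
    so they are not expected to show up in `kwargs`.
    """
    tokens = list(input_ids)
    if left_padding is not None:
        # 0 is an assumed pad token id, chosen only for this sketch
        tokens = [0] * left_padding + tokens
    for _ in range(max_new_tokens):
        tokens.append(model.next_token(tokens))
    # A custom `generate` may return a different output type; here, a plain list.
    return tokens


class DummyModel:
    """Stand-in for a real model; `next_token` is a hypothetical helper."""

    def next_token(self, tokens):
        return (tokens[-1] + 1) % 100


output = generate(DummyModel(), [5, 6], max_new_tokens=2, left_padding=1)
print(output)
```

The point of the sketch is the signature: `model` arrives in place of `self`, known arguments keep their names, and anything else lands in `**kwargs`.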
@@ -84,14 +84,17 @@ GenerationConfig {
 }
 ```
 
-You can customize [`~GenerationMixin.generate`] by overriding the parameters and values in [`GenerationConfig`]. Some of the most commonly adjusted parameters are [max_new_tokens](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.max_new_tokens), [num_beams](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.num_beams), [do_sample](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.do_sample), and [num_return_sequences](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.num_return_sequences).
+You can customize [`~GenerationMixin.generate`] by overriding the parameters and values in [`GenerationConfig`]. See [this section below](#common-options) for commonly adjusted parameters.
 
 ```py
 # enable beam search sampling strategy
 model.generate(**inputs, num_beams=4, do_sample=True)
 ```
 
-[`~GenerationMixin.generate`] can also be extended with external libraries or custom code. The `logits_processor` parameter accepts custom [`LogitsProcessor`] instances for manipulating the next token probability distribution. `stopping_criteria` supports custom [`StoppingCriteria`] to stop text generation. Check out the [logits-processor-zoo](https://github.com/NVIDIA/logits-processor-zoo) for more examples of external [`~GenerationMixin.generate`]-compatible extensions.
+[`~GenerationMixin.generate`] can also be extended with external libraries or custom code:
+1. the `logits_processor` parameter accepts custom [`LogitsProcessor`] instances for manipulating the next token probability distribution;
+2. the `stopping_criteria` parameters supports custom [`StoppingCriteria`] to stop text generation;
+3. other custom generation methods can be loaded through the `custom_generate` flag ([docs](generation_strategies.md/#custom-decoding-methods)).
 
 Refer to the [Generation strategies](./generation_strategies) guide to learn more about search, sampling, and decoding strategies.
 
@@ -2347,9 +2347,15 @@ class GenerationMixin(ContinuousMixin):
         if custom_generate is not None:
             trust_remote_code = kwargs.pop("trust_remote_code", None)
             # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
-            # they receive the same inputs as `generate`, only with `model` instead of `self`. They can access to
-            # methods from `GenerationMixin` through `model`.
-            global_keys_to_exclude = {"self", "kwargs"}
+            # they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to
+            # trigger the custom generation. They can access to methods from `GenerationMixin` through `model`.
+            global_keys_to_exclude = {
+                "self",
+                "kwargs",
+                "global_keys_to_exclude",
+                "trust_remote_code",
+                "custom_generate",
+            }
             generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude}
             generate_arguments.update(kwargs)
 
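Reviewer note: the filtering in the hunk above can be reproduced in isolation. The sketch below is an assumption-laden toy, not the library code: `ToyMixin` and `echo_arguments` are hypothetical names, and `custom_generate` is passed as a callable here, whereas the real method resolves it from a repository. It shows which names `locals()` holds at that point and what survives the exclusion set.

```python
def echo_arguments(model, **generate_arguments):
    """Hypothetical custom function: returns what it was given, for inspection."""
    return generate_arguments


class ToyMixin:
    """Toy stand-in for the dispatch logic; not the real `GenerationMixin`."""

    def generate(self, inputs=None, custom_generate=None, **kwargs):
        if custom_generate is not None:
            # Popped so that it is not forwarded through `kwargs` either.
            trust_remote_code = kwargs.pop("trust_remote_code", None)
            # `locals()` also contains the exclusion set itself once it is bound,
            # which is why its own name has to be listed.
            global_keys_to_exclude = {
                "self",
                "kwargs",
                "global_keys_to_exclude",
                "trust_remote_code",
                "custom_generate",
            }
            generate_arguments = {
                key: value for key, value in locals().items() if key not in global_keys_to_exclude
            }
            generate_arguments.update(kwargs)
            # `model` replaces `self`; the triggering arguments are left out.
            return custom_generate(model=self, **generate_arguments)
        return "default path"


received = ToyMixin().generate(
    inputs=[1, 2],
    custom_generate=echo_arguments,
    trust_remote_code=True,
    max_new_tokens=4,
)
print(received)
```

With `custom_generate` excluded, a custom function that calls `model.generate(**kwargs)` internally falls through to the default path instead of dispatching to itself again, which is presumably the infinite loop the commit message refers to.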