Generate: use GenerationConfig as the basis for .generate() parametrization (#20388)

* generate from config mvp

* fix failing tests

* max_time test

* Load default gen config at model load time; Update docs

* further documentation; add tests

* adapt rag to the new structure

* handle models not instantiated with from_pretained (like in tests)

* better default generation config

* add can_generate fn

* handle legacy use case of ad hoc model config changes

* initialize gen config from config in individual methods, if gen config is none

* fix _get_decoder_start_token_id when called outside GenerationMixin

* correct model config load order (set attr > model config > decoder config)

* update rag to match latest changes

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* load gen config from model config in model.from_pretrained

* fix can_generate fn

* handle generate calls without a previous from_pretrained (e.g. tests)

* add legacy behavior (and a warning)

* lower logger severity

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Joao Gante
2022-12-15 18:27:20 +00:00
committed by GitHub
parent b1706f6908
commit 4bc723f87d
6 changed files with 694 additions and 733 deletions

View File

@@ -18,12 +18,79 @@ Each framework has a generate method for auto-regressive text generation impleme
- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].
<!--- TODO: add a brief description of GenerationConfig (with examples) when it becomes usable with generate --->
Regardless of your framework of choice, you can parameterize the generate method with a [`~generation.GenerationConfig`]
class instance. Please refer to this class for the complete list of generation parameters, which control the behavior
of the generation method.
All models have a default generation configuration that will be used if you don't provide one. If you have a loaded
model instance `model`, you can inspect the default generation configuration with `model.generation_config`. If you'd
like to set a new default generation configuration, you can create a new [`~generation.GenerationConfig`] instance and
store it with `save_pretrained`, making sure to leave its `config_file_name` argument empty.
```python
from transformers import AutoModelForCausalLM, GenerationConfig
model = AutoModelForCausalLM.from_pretrained("my_account/my_model")
# Inspect the default generation configuration
print(model.generation_config)
# Set a new default generation configuration
generation_config = GenerationConfig(
max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
)
generation_config.save_pretrained("my_account/my_model", push_to_hub=True)
```
<Tip>
If you inspect a serialized [`~generation.GenerationConfig`] file or print a class instance, you will notice that
default values are omitted. Some attributes, like `max_length`, have a conservative default value, to avoid running
into resource limitations. Make sure you double-check the defaults in the documentation.
</Tip>
You can also store several generation parametrizations in a single directory, making use of the `config_file_name`
argument in `save_pretrained`. You can latter instantiate them with `from_pretrained`. This is useful if you want to
store several generation configurations for a single model (e.g. one for creative text generation with sampling, and
other for summarization with beam search).
```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
translation_generation_config = GenerationConfig(
num_beams=4,
early_stopping=True,
decoder_start_token_id=0,
eos_token_id=model.config.eos_token_id,
pad_token=model.config.pad_token_id,
)
# If you were working on a model for which your had the right Hub permissions, you could store a named generation
# config as follows
translation_generation_config.save_pretrained("t5-small", "translation_generation_config.json", push_to_hub=True)
# You could then use the named generation config file to parameterize generation
generation_config = GenerationConfig.from_pretrained("t5-small", "translation_generation_config.json")
inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
# ['Les fichiers de configuration sont faciles à utiliser !']
```
Finally, you can specify ad hoc modifications to the used generation configuration by passing the attribute you
wish to override directly to the generate method (e.g. `model.generate(inputs, max_new_tokens=512)`). Each
framework's `generate` method docstring (available below) has a few illustrative examples on the different strategies
to parameterize it.
## GenerationConfig
[[autodoc]] generation.GenerationConfig
- from_pretrained
- from_model_config
- save_pretrained
## GenerationMixin