[generate] clarify docstrings: when to inherit GenerationMixin (#36605)

This commit is contained in:
Joao Gante
2025-03-20 10:58:54 +00:00
committed by GitHub
parent 8e97b44087
commit b47d9b2f8a
2 changed files with 19 additions and 1 deletions

View File

@@ -343,7 +343,22 @@ GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
class GenerationMixin:
"""
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes.
Inheriting from this class causes the model to have special generation-related behavior, such as loading a
`GenerationConfig` at initialization time or ensuring `generate`-related tests are run in `transformers` CI.
A model class should inherit from `GenerationMixin` to enable calling methods like `generate`, or when it
has defined a custom `generate` method that relies on `GenerationMixin`, directly or indirectly, which
approximately shares the same interface to public methods like `generate`. Three examples:
- `LlamaForCausalLM` should inherit from `GenerationMixin` to enable calling `generate` and other public
methods in the mixin;
- `BlipForQuestionAnswering` has a custom `generate` method that approximately shares the same interface as
`GenerationMixin.generate` (it has a few extra arguments, and the same output). That function also calls
`GenerationMixin.generate` indirectly, through an inner model. As such, `BlipForQuestionAnswering` should
inherit from `GenerationMixin` to benefit from all generation-related automation in our codebase;
- `BarkModel` has a custom `generate` method and one of its inner models calls `GenerationMixin.generate`.
However, its `generate` does not share the same interface as `GenerationMixin.generate`. In this case,
`BarkModel` shoud NOT inherit from `GenerationMixin`, as it breaks the `generate` interface.
The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
- *greedy decoding* if `num_beams=1` and `do_sample=False`

View File

@@ -2178,6 +2178,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"""
Returns whether this model can generate sequences with `.generate()` from the `GenerationMixin`.
Under the hood, on classes where this function returns True, some generation-specific changes are triggered:
for instance, the model instance will have a populated `generation_config` attribute.
Returns:
`bool`: Whether this model can generate sequences with `.generate()`.
"""