[generate] clarify docstrings: when to inherit GenerationMixin (#36605)
@@ -343,7 +343,22 @@ GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
 
 class GenerationMixin:
     """
-    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
+    A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes.
+    Inheriting from this class causes the model to have special generation-related behavior, such as loading a
+    `GenerationConfig` at initialization time or ensuring `generate`-related tests are run in `transformers` CI.
+
+    A model class should inherit from `GenerationMixin` to enable calling methods like `generate`, or when it
+    has defined a custom `generate` method that relies on `GenerationMixin`, directly or indirectly, which
+    approximately shares the same interface to public methods like `generate`. Three examples:
+        - `LlamaForCausalLM` should inherit from `GenerationMixin` to enable calling `generate` and other public
+            methods in the mixin;
+        - `BlipForQuestionAnswering` has a custom `generate` method that approximately shares the same interface as
+            `GenerationMixin.generate` (it has a few extra arguments, and the same output). That function also calls
+            `GenerationMixin.generate` indirectly, through an inner model. As such, `BlipForQuestionAnswering` should
+            inherit from `GenerationMixin` to benefit from all generation-related automation in our codebase;
+        - `BarkModel` has a custom `generate` method and one of its inner models calls `GenerationMixin.generate`.
+            However, its `generate` does not share the same interface as `GenerationMixin.generate`. In this case,
+            `BarkModel` should NOT inherit from `GenerationMixin`, as it breaks the `generate` interface.
 
     The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
         - *greedy decoding* if `num_beams=1` and `do_sample=False`
@@ -2178,6 +2178,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         """
         Returns whether this model can generate sequences with `.generate()` from the `GenerationMixin`.
 
+        Under the hood, on classes where this function returns True, some generation-specific changes are triggered:
+        for instance, the model instance will have a populated `generation_config` attribute.
+
         Returns:
             `bool`: Whether this model can generate sequences with `.generate()`.
         """
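The inheritance rule described in the two docstrings can be illustrated with a small, self-contained sketch. Everything in it is a hypothetical stand-in, not the `transformers` implementation: the `GenerationConfig`, `GenerationMixin`, and `BaseModel` classes below are toy versions, the three `Toy*` models only mirror the roles of `LlamaForCausalLM`, `BlipForQuestionAnswering`, and `BarkModel`, and the `can_generate` check is assumed here to mean nothing more than "inherits the mixin".

```python
# Toy stand-ins only: none of these classes are the real transformers code.


class GenerationConfig:
    """Hypothetical minimal generation config."""

    def __init__(self, max_new_tokens=3):
        self.max_new_tokens = max_new_tokens


class GenerationMixin:
    """Hypothetical mixin exposing a toy greedy `generate`."""

    def generate(self, input_ids, **kwargs):
        steps = kwargs.get("max_new_tokens", self.generation_config.max_new_tokens)
        output = list(input_ids)
        for _ in range(steps):
            output.append(self.next_token(output))
        return output


class BaseModel:
    """Hypothetical base class mimicking the behaviour described for `can_generate`."""

    def __init__(self):
        # Only classes that can generate get a populated generation_config.
        self.generation_config = GenerationConfig() if self.can_generate() else None

    @classmethod
    def can_generate(cls):
        # Assumption for this sketch: "can generate" == "inherits the mixin".
        return issubclass(cls, GenerationMixin)


class ToyCausalLM(BaseModel, GenerationMixin):
    """Case 1: inherits the mixin to get `generate` as-is."""

    def next_token(self, tokens):
        return tokens[-1] + 1


class ToyQuestionAnswering(BaseModel, GenerationMixin):
    """Case 2: custom `generate`, nearly the same interface, same output type."""

    def __init__(self):
        super().__init__()
        self.text_decoder = ToyCausalLM()

    def generate(self, input_ids, question_ids=(), **kwargs):
        # One extra argument; the call reaches the mixin through an inner model.
        return self.text_decoder.generate(list(question_ids) + list(input_ids), **kwargs)


class ToySpeechModel(BaseModel):
    """Case 3: different `generate` interface, so it does NOT inherit the mixin."""

    def __init__(self):
        super().__init__()
        self.semantic = ToyCausalLM()

    def generate(self, text, voice="default"):
        tokens = self.semantic.generate([len(text)])
        return {"voice": voice, "audio": [t / 10 for t in tokens]}


print(ToyCausalLM().generate([1, 2]))
print(ToyQuestionAnswering().generate([5], question_ids=[3]))
print(ToySpeechModel.can_generate(), ToySpeechModel().generation_config)
```

In this sketch the first two models report `can_generate()` as true and receive a populated `generation_config`, while the third keeps its incompatible `generate` signature outside the mixin and only uses generation through its inner model.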