[generate] clarify docstrings: when to inherit GenerationMixin (#36605)
@@ -343,7 +343,22 @@ GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
 
 class GenerationMixin:
     """
-    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
+    A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes.
+    Inheriting from this class causes the model to have special generation-related behavior, such as loading a
+    `GenerationConfig` at initialization time or ensuring `generate`-related tests are run in `transformers` CI.
+
+    A model class should inherit from `GenerationMixin` to enable calling methods like `generate`, or when it
+    has defined a custom `generate` method that relies on `GenerationMixin`, directly or indirectly, which
+    approximately shares the same interface to public methods like `generate`. Three examples:
+        - `LlamaForCausalLM` should inherit from `GenerationMixin` to enable calling `generate` and other public
+          methods in the mixin;
+        - `BlipForQuestionAnswering` has a custom `generate` method that approximately shares the same interface as
+          `GenerationMixin.generate` (it has a few extra arguments, and the same output). That function also calls
+          `GenerationMixin.generate` indirectly, through an inner model. As such, `BlipForQuestionAnswering` should
+          inherit from `GenerationMixin` to benefit from all generation-related automation in our codebase;
+        - `BarkModel` has a custom `generate` method and one of its inner models calls `GenerationMixin.generate`.
+          However, its `generate` does not share the same interface as `GenerationMixin.generate`. In this case,
+          `BarkModel` should NOT inherit from `GenerationMixin`, as it breaks the `generate` interface.
 
     The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
         - *greedy decoding* if `num_beams=1` and `do_sample=False`
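The three cases in the `GenerationMixin` docstring can be illustrated with a small, library-free sketch. Every class below (`ToyGenerationMixin`, `ToyCausalLM`, `ToyQuestionAnswering`, `ToyAudioModel`) is a hypothetical stand-in invented for illustration, not part of `transformers`, and the toy `generate` loop only loosely mimics greedy decoding.

```python
class ToyGenerationMixin:
    """Hypothetical stand-in for a generation mixin (not the transformers class)."""

    def generate(self, inputs, max_new_tokens=3):
        # Toy greedy loop: repeatedly append the token chosen by the model.
        tokens = list(inputs)
        for _ in range(max_new_tokens):
            tokens.append(self.next_token(tokens))
        return tokens


class ToyCausalLM(ToyGenerationMixin):
    """Case 1: inherits the mixin simply to gain `generate`."""

    def next_token(self, tokens):
        return tokens[-1] + 1


class ToyQuestionAnswering(ToyGenerationMixin):
    """Case 2: custom `generate` with roughly the same interface.

    It accepts one extra argument, returns the same kind of output, and
    delegates to an inner model, so it still inherits the mixin.
    """

    def __init__(self):
        self.inner = ToyCausalLM()

    def generate(self, inputs, max_new_tokens=3, question_offset=0):
        shifted = [t + question_offset for t in inputs]
        return self.inner.generate(shifted, max_new_tokens=max_new_tokens)


class ToyAudioModel:
    """Case 3: custom `generate` with a different interface.

    An inner model uses the mixin, but the outer signature and return type
    differ, so the outer class does not inherit the mixin.
    """

    def __init__(self):
        self.inner = ToyCausalLM()

    def generate(self, prompt_text):
        tokens = self.inner.generate([len(prompt_text)], max_new_tokens=2)
        return {"audio_codes": tokens}
```

In this sketch, code that only knows about the mixin interface can call `generate(inputs, max_new_tokens=...)` on the first two classes interchangeably, while the third would break that contract if it claimed to be a mixin subclass.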
@@ -2178,6 +2178,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         """
         Returns whether this model can generate sequences with `.generate()` from the `GenerationMixin`.
 
+        Under the hood, on classes where this function returns True, some generation-specific changes are triggered:
+        for instance, the model instance will have a populated `generation_config` attribute.
+
         Returns:
             `bool`: Whether this model can generate sequences with `.generate()`.
         """
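The effect described in this docstring, a generation-capability check that decides whether `generation_config` is populated, might look roughly like the sketch below. The classes and the `issubclass` check are assumptions made for illustration: the diff does not show how the real check is implemented, and since the hunk header shows `PreTrainedModel` itself listing `GenerationMixin` among its bases, the real check is presumably different. The config contents are a placeholder dictionary.

```python
class ToyGenerationMixin:
    """Hypothetical stand-in for a generation mixin (not the transformers class)."""

    def generate(self, inputs):
        return list(inputs)


class ToyPreTrainedModel:
    """Hypothetical base class; unlike the real one, it does not inherit the mixin."""

    @classmethod
    def can_generate(cls):
        # Assumption for this sketch: a class can generate if it inherits the mixin.
        return issubclass(cls, ToyGenerationMixin)

    def __init__(self):
        # Generation-specific setup is triggered only when can_generate() is True.
        if self.can_generate():
            self.generation_config = {"max_new_tokens": 20}  # placeholder values
        else:
            self.generation_config = None


class ToyEncoder(ToyPreTrainedModel):
    """Does not inherit the mixin, so no generation config is created."""


class ToyDecoder(ToyPreTrainedModel, ToyGenerationMixin):
    """Inherits the mixin, so a generation config is populated at init."""
```

Here `ToyDecoder().generation_config` is populated while `ToyEncoder().generation_config` stays `None`, mirroring the behavior the docstring describes.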