From b47d9b2f8a75127f5bae4e1e430d0103b262b2e9 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 20 Mar 2025 10:58:54 +0000 Subject: [PATCH] [generate] clarify docstrings: when to inherit `GenerationMixin` (#36605) --- src/transformers/generation/utils.py | 17 ++++++++++++++++- src/transformers/modeling_utils.py | 3 +++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 654edb1f95..aa6c8fb203 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -343,7 +343,22 @@ GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput] class GenerationMixin: """ - A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. + A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes. + Inheriting from this class causes the model to have special generation-related behavior, such as loading a + `GenerationConfig` at initialization time or ensuring `generate`-related tests are run in `transformers` CI. + + A model class should inherit from `GenerationMixin` to enable calling methods like `generate`, or when it + has defined a custom `generate` method that relies on `GenerationMixin`, directly or indirectly, which + approximately shares the same interface to public methods like `generate`. Three examples: + - `LlamaForCausalLM` should inherit from `GenerationMixin` to enable calling `generate` and other public + methods in the mixin; + - `BlipForQuestionAnswering` has a custom `generate` method that approximately shares the same interface as + `GenerationMixin.generate` (it has a few extra arguments, and the same output). That function also calls + `GenerationMixin.generate` indirectly, through an inner model. As such, `BlipForQuestionAnswering` should + inherit from `GenerationMixin` to benefit from all generation-related automation in our codebase; + - `BarkModel` has a custom `generate` method and one of its inner models calls `GenerationMixin.generate`. + However, its `generate` does not share the same interface as `GenerationMixin.generate`. In this case, + `BarkModel` shoud NOT inherit from `GenerationMixin`, as it breaks the `generate` interface. The class exposes [`~generation.GenerationMixin.generate`], which can be used for: - *greedy decoding* if `num_beams=1` and `do_sample=False` diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 120af2b842..a4b13ca0f2 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2178,6 +2178,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix """ Returns whether this model can generate sequences with `.generate()` from the `GenerationMixin`. + Under the hood, on classes where this function returns True, some generation-specific changes are triggered: + for instance, the model instance will have a populated `generation_config` attribute. + Returns: `bool`: Whether this model can generate sequences with `.generate()`. """