From bff4313b37960751d88ddba0a9fdb9ca5d32a64e Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 4 Aug 2023 13:35:10 +0100 Subject: [PATCH] Generate: get generation mode as an enum (#25292) --- src/transformers/generation/utils.py | 133 ++++++++++++++------------- 1 file changed, 68 insertions(+), 65 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index aa1f377b28..75b2c4f714 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -33,7 +33,7 @@ from ..models.auto import ( MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING, ) -from ..utils import ModelOutput, logging +from ..utils import ExplicitEnum, ModelOutput, logging from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .configuration_utils import GenerationConfig @@ -468,6 +468,23 @@ ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, Contrasti GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput] +class GenerationMode(ExplicitEnum): + """ + Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method. + """ + + # Non-beam methods + CONTRASTIVE_SEARCH = "contrastive_search" + GREEDY_SEARCH = "greedy_search" + SAMPLE = "sample" + ASSISTED_GENERATION = "assisted_generation" + # Beam methods + BEAM_SEARCH = "beam_search" + BEAM_SAMPLE = "beam_sample" + CONSTRAINED_BEAM_SEARCH = "constrained_beam_search" + GROUP_BEAM_SEARCH = "group_beam_search" + + class GenerationMixin: """ A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. @@ -829,6 +846,46 @@ class GenerationMixin: warpers.append(LogitNormalization()) return warpers + def _get_generation_mode( + self, generation_config: GenerationConfig, assistant_model: Optional["PreTrainedModel"] + ) -> GenerationMode: + """ + Returns the generation mode triggered by a [`GenerationConfig`] instance. + """ + if generation_config.constraints is not None or generation_config.force_words_ids is not None: + generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH + elif generation_config.num_beams == 1: + if generation_config.do_sample is False: + if ( + generation_config.top_k is not None + and generation_config.top_k > 1 + and generation_config.penalty_alpha is not None + and generation_config.penalty_alpha > 0 + ): + generation_mode = GenerationMode.CONTRASTIVE_SEARCH + else: + generation_mode = GenerationMode.GREEDY_SEARCH + else: + generation_mode = GenerationMode.SAMPLE + else: + if generation_config.num_beam_groups > 1: + generation_mode = GenerationMode.GROUP_BEAM_SEARCH + elif generation_config.do_sample is True: + generation_mode = GenerationMode.BEAM_SAMPLE + else: + generation_mode = GenerationMode.BEAM_SEARCH + + # Assisted generation may extend some generation modes + if assistant_model is not None: + if generation_mode in ("greedy_search", "sample"): + generation_mode = GenerationMode.ASSISTED_GENERATION + else: + raise ValueError( + "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate " + "is only supported with Greedy Search and Sample." + ) + return generation_mode + def _get_logits_processor( self, generation_config: GenerationConfig, @@ -1422,65 +1479,11 @@ class GenerationMixin: ) # 7. determine generation mode - is_constraint_gen_mode = ( - generation_config.constraints is not None or generation_config.force_words_ids is not None - ) - - is_contrastive_search_gen_mode = ( - (generation_config.num_beams == 1) - and generation_config.top_k is not None - and generation_config.top_k > 1 - and generation_config.do_sample is False - and generation_config.penalty_alpha is not None - and generation_config.penalty_alpha > 0 - ) - - is_greedy_gen_mode = ( - (generation_config.num_beams == 1) - and (generation_config.num_beam_groups == 1) - and generation_config.do_sample is False - and not is_constraint_gen_mode - and not is_contrastive_search_gen_mode - ) - is_sample_gen_mode = ( - (generation_config.num_beams == 1) - and (generation_config.num_beam_groups == 1) - and generation_config.do_sample is True - and not is_constraint_gen_mode - and not is_contrastive_search_gen_mode - ) - is_beam_gen_mode = ( - (generation_config.num_beams > 1) - and (generation_config.num_beam_groups == 1) - and generation_config.do_sample is False - and not is_constraint_gen_mode - and not is_contrastive_search_gen_mode - ) - is_beam_sample_gen_mode = ( - (generation_config.num_beams > 1) - and (generation_config.num_beam_groups == 1) - and generation_config.do_sample is True - and not is_constraint_gen_mode - and not is_contrastive_search_gen_mode - ) - is_group_beam_gen_mode = ( - (generation_config.num_beams > 1) - and (generation_config.num_beam_groups > 1) - and not is_constraint_gen_mode - and not is_contrastive_search_gen_mode - ) - is_assisted_gen_mode = False - if assistant_model is not None: - if not (is_greedy_gen_mode or is_sample_gen_mode): - raise ValueError( - "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate " - "is only supported with Greedy Search and Sample." - ) - is_assisted_gen_mode = True + generation_mode = self._get_generation_mode(generation_config, assistant_model) if generation_config.num_beam_groups > generation_config.num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") - if is_group_beam_gen_mode and generation_config.do_sample is True: + if generation_mode == GenerationMode.GROUP_BEAM_SEARCH and generation_config.do_sample is True: raise ValueError( "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." ) @@ -1515,7 +1518,7 @@ class GenerationMixin: generation_config=generation_config, stopping_criteria=stopping_criteria ) # 10. go into different generation modes - if is_assisted_gen_mode: + if generation_mode == GenerationMode.ASSISTED_GENERATION: if generation_config.num_return_sequences > 1: raise ValueError( "num_return_sequences has to be 1 when doing assisted generate, " @@ -1553,7 +1556,7 @@ class GenerationMixin: streamer=streamer, **model_kwargs, ) - if is_greedy_gen_mode: + if generation_mode == GenerationMode.GREEDY_SEARCH: if generation_config.num_return_sequences > 1: raise ValueError( "num_return_sequences has to be 1 when doing greedy search, " @@ -1574,7 +1577,7 @@ class GenerationMixin: **model_kwargs, ) - elif is_contrastive_search_gen_mode: + elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH: if generation_config.num_return_sequences > 1: raise ValueError( "num_return_sequences has to be 1 when doing contrastive search, " @@ -1599,7 +1602,7 @@ class GenerationMixin: **model_kwargs, ) - elif is_sample_gen_mode: + elif generation_mode == GenerationMode.SAMPLE: # 11. prepare logits warper logits_warper = self._get_logits_warper(generation_config) @@ -1626,7 +1629,7 @@ class GenerationMixin: **model_kwargs, ) - elif is_beam_gen_mode: + elif generation_mode == GenerationMode.BEAM_SEARCH: if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") @@ -1664,7 +1667,7 @@ class GenerationMixin: **model_kwargs, ) - elif is_beam_sample_gen_mode: + elif generation_mode == GenerationMode.BEAM_SAMPLE: # 11. prepare logits warper logits_warper = self._get_logits_warper(generation_config) @@ -1703,7 +1706,7 @@ class GenerationMixin: **model_kwargs, ) - elif is_group_beam_gen_mode: + elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH: if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") @@ -1754,7 +1757,7 @@ class GenerationMixin: **model_kwargs, ) - elif is_constraint_gen_mode: + elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH: if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")