Generate: get generation mode as an enum (#25292)
This commit is contained in:
@@ -33,7 +33,7 @@ from ..models.auto import (
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
)
|
||||
from ..utils import ModelOutput, logging
|
||||
from ..utils import ExplicitEnum, ModelOutput, logging
|
||||
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
||||
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from .configuration_utils import GenerationConfig
|
||||
@@ -468,6 +468,23 @@ ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, Contrasti
|
||||
GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput]
|
||||
|
||||
|
||||
class GenerationMode(ExplicitEnum):
|
||||
"""
|
||||
Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
|
||||
"""
|
||||
|
||||
# Non-beam methods
|
||||
CONTRASTIVE_SEARCH = "contrastive_search"
|
||||
GREEDY_SEARCH = "greedy_search"
|
||||
SAMPLE = "sample"
|
||||
ASSISTED_GENERATION = "assisted_generation"
|
||||
# Beam methods
|
||||
BEAM_SEARCH = "beam_search"
|
||||
BEAM_SAMPLE = "beam_sample"
|
||||
CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
|
||||
GROUP_BEAM_SEARCH = "group_beam_search"
|
||||
|
||||
|
||||
class GenerationMixin:
|
||||
"""
|
||||
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
|
||||
@@ -829,6 +846,46 @@ class GenerationMixin:
|
||||
warpers.append(LogitNormalization())
|
||||
return warpers
|
||||
|
||||
def _get_generation_mode(
|
||||
self, generation_config: GenerationConfig, assistant_model: Optional["PreTrainedModel"]
|
||||
) -> GenerationMode:
|
||||
"""
|
||||
Returns the generation mode triggered by a [`GenerationConfig`] instance.
|
||||
"""
|
||||
if generation_config.constraints is not None or generation_config.force_words_ids is not None:
|
||||
generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
|
||||
elif generation_config.num_beams == 1:
|
||||
if generation_config.do_sample is False:
|
||||
if (
|
||||
generation_config.top_k is not None
|
||||
and generation_config.top_k > 1
|
||||
and generation_config.penalty_alpha is not None
|
||||
and generation_config.penalty_alpha > 0
|
||||
):
|
||||
generation_mode = GenerationMode.CONTRASTIVE_SEARCH
|
||||
else:
|
||||
generation_mode = GenerationMode.GREEDY_SEARCH
|
||||
else:
|
||||
generation_mode = GenerationMode.SAMPLE
|
||||
else:
|
||||
if generation_config.num_beam_groups > 1:
|
||||
generation_mode = GenerationMode.GROUP_BEAM_SEARCH
|
||||
elif generation_config.do_sample is True:
|
||||
generation_mode = GenerationMode.BEAM_SAMPLE
|
||||
else:
|
||||
generation_mode = GenerationMode.BEAM_SEARCH
|
||||
|
||||
# Assisted generation may extend some generation modes
|
||||
if assistant_model is not None:
|
||||
if generation_mode in ("greedy_search", "sample"):
|
||||
generation_mode = GenerationMode.ASSISTED_GENERATION
|
||||
else:
|
||||
raise ValueError(
|
||||
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
|
||||
"is only supported with Greedy Search and Sample."
|
||||
)
|
||||
return generation_mode
|
||||
|
||||
def _get_logits_processor(
|
||||
self,
|
||||
generation_config: GenerationConfig,
|
||||
@@ -1422,65 +1479,11 @@ class GenerationMixin:
|
||||
)
|
||||
|
||||
# 7. determine generation mode
|
||||
is_constraint_gen_mode = (
|
||||
generation_config.constraints is not None or generation_config.force_words_ids is not None
|
||||
)
|
||||
|
||||
is_contrastive_search_gen_mode = (
|
||||
(generation_config.num_beams == 1)
|
||||
and generation_config.top_k is not None
|
||||
and generation_config.top_k > 1
|
||||
and generation_config.do_sample is False
|
||||
and generation_config.penalty_alpha is not None
|
||||
and generation_config.penalty_alpha > 0
|
||||
)
|
||||
|
||||
is_greedy_gen_mode = (
|
||||
(generation_config.num_beams == 1)
|
||||
and (generation_config.num_beam_groups == 1)
|
||||
and generation_config.do_sample is False
|
||||
and not is_constraint_gen_mode
|
||||
and not is_contrastive_search_gen_mode
|
||||
)
|
||||
is_sample_gen_mode = (
|
||||
(generation_config.num_beams == 1)
|
||||
and (generation_config.num_beam_groups == 1)
|
||||
and generation_config.do_sample is True
|
||||
and not is_constraint_gen_mode
|
||||
and not is_contrastive_search_gen_mode
|
||||
)
|
||||
is_beam_gen_mode = (
|
||||
(generation_config.num_beams > 1)
|
||||
and (generation_config.num_beam_groups == 1)
|
||||
and generation_config.do_sample is False
|
||||
and not is_constraint_gen_mode
|
||||
and not is_contrastive_search_gen_mode
|
||||
)
|
||||
is_beam_sample_gen_mode = (
|
||||
(generation_config.num_beams > 1)
|
||||
and (generation_config.num_beam_groups == 1)
|
||||
and generation_config.do_sample is True
|
||||
and not is_constraint_gen_mode
|
||||
and not is_contrastive_search_gen_mode
|
||||
)
|
||||
is_group_beam_gen_mode = (
|
||||
(generation_config.num_beams > 1)
|
||||
and (generation_config.num_beam_groups > 1)
|
||||
and not is_constraint_gen_mode
|
||||
and not is_contrastive_search_gen_mode
|
||||
)
|
||||
is_assisted_gen_mode = False
|
||||
if assistant_model is not None:
|
||||
if not (is_greedy_gen_mode or is_sample_gen_mode):
|
||||
raise ValueError(
|
||||
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
|
||||
"is only supported with Greedy Search and Sample."
|
||||
)
|
||||
is_assisted_gen_mode = True
|
||||
generation_mode = self._get_generation_mode(generation_config, assistant_model)
|
||||
|
||||
if generation_config.num_beam_groups > generation_config.num_beams:
|
||||
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
||||
if is_group_beam_gen_mode and generation_config.do_sample is True:
|
||||
if generation_mode == GenerationMode.GROUP_BEAM_SEARCH and generation_config.do_sample is True:
|
||||
raise ValueError(
|
||||
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
|
||||
)
|
||||
@@ -1515,7 +1518,7 @@ class GenerationMixin:
|
||||
generation_config=generation_config, stopping_criteria=stopping_criteria
|
||||
)
|
||||
# 10. go into different generation modes
|
||||
if is_assisted_gen_mode:
|
||||
if generation_mode == GenerationMode.ASSISTED_GENERATION:
|
||||
if generation_config.num_return_sequences > 1:
|
||||
raise ValueError(
|
||||
"num_return_sequences has to be 1 when doing assisted generate, "
|
||||
@@ -1553,7 +1556,7 @@ class GenerationMixin:
|
||||
streamer=streamer,
|
||||
**model_kwargs,
|
||||
)
|
||||
if is_greedy_gen_mode:
|
||||
if generation_mode == GenerationMode.GREEDY_SEARCH:
|
||||
if generation_config.num_return_sequences > 1:
|
||||
raise ValueError(
|
||||
"num_return_sequences has to be 1 when doing greedy search, "
|
||||
@@ -1574,7 +1577,7 @@ class GenerationMixin:
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
elif is_contrastive_search_gen_mode:
|
||||
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
|
||||
if generation_config.num_return_sequences > 1:
|
||||
raise ValueError(
|
||||
"num_return_sequences has to be 1 when doing contrastive search, "
|
||||
@@ -1599,7 +1602,7 @@ class GenerationMixin:
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
elif is_sample_gen_mode:
|
||||
elif generation_mode == GenerationMode.SAMPLE:
|
||||
# 11. prepare logits warper
|
||||
logits_warper = self._get_logits_warper(generation_config)
|
||||
|
||||
@@ -1626,7 +1629,7 @@ class GenerationMixin:
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
elif is_beam_gen_mode:
|
||||
elif generation_mode == GenerationMode.BEAM_SEARCH:
|
||||
if generation_config.num_return_sequences > generation_config.num_beams:
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
|
||||
@@ -1664,7 +1667,7 @@ class GenerationMixin:
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
elif is_beam_sample_gen_mode:
|
||||
elif generation_mode == GenerationMode.BEAM_SAMPLE:
|
||||
# 11. prepare logits warper
|
||||
logits_warper = self._get_logits_warper(generation_config)
|
||||
|
||||
@@ -1703,7 +1706,7 @@ class GenerationMixin:
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
elif is_group_beam_gen_mode:
|
||||
elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
|
||||
if generation_config.num_return_sequences > generation_config.num_beams:
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
|
||||
@@ -1754,7 +1757,7 @@ class GenerationMixin:
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
elif is_constraint_gen_mode:
|
||||
elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
|
||||
if generation_config.num_return_sequences > generation_config.num_beams:
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user