Generate: get generation mode as an enum (#25292)
This commit is contained in:
@@ -33,7 +33,7 @@ from ..models.auto import (
|
|||||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||||
)
|
)
|
||||||
from ..utils import ModelOutput, logging
|
from ..utils import ExplicitEnum, ModelOutput, logging
|
||||||
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
||||||
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||||
from .configuration_utils import GenerationConfig
|
from .configuration_utils import GenerationConfig
|
||||||
@@ -468,6 +468,23 @@ ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, Contrasti
|
|||||||
GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput]
|
GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput]
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationMode(ExplicitEnum):
|
||||||
|
"""
|
||||||
|
Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Non-beam methods
|
||||||
|
CONTRASTIVE_SEARCH = "contrastive_search"
|
||||||
|
GREEDY_SEARCH = "greedy_search"
|
||||||
|
SAMPLE = "sample"
|
||||||
|
ASSISTED_GENERATION = "assisted_generation"
|
||||||
|
# Beam methods
|
||||||
|
BEAM_SEARCH = "beam_search"
|
||||||
|
BEAM_SAMPLE = "beam_sample"
|
||||||
|
CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
|
||||||
|
GROUP_BEAM_SEARCH = "group_beam_search"
|
||||||
|
|
||||||
|
|
||||||
class GenerationMixin:
|
class GenerationMixin:
|
||||||
"""
|
"""
|
||||||
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
|
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
|
||||||
@@ -829,6 +846,46 @@ class GenerationMixin:
|
|||||||
warpers.append(LogitNormalization())
|
warpers.append(LogitNormalization())
|
||||||
return warpers
|
return warpers
|
||||||
|
|
||||||
|
def _get_generation_mode(
|
||||||
|
self, generation_config: GenerationConfig, assistant_model: Optional["PreTrainedModel"]
|
||||||
|
) -> GenerationMode:
|
||||||
|
"""
|
||||||
|
Returns the generation mode triggered by a [`GenerationConfig`] instance.
|
||||||
|
"""
|
||||||
|
if generation_config.constraints is not None or generation_config.force_words_ids is not None:
|
||||||
|
generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
|
||||||
|
elif generation_config.num_beams == 1:
|
||||||
|
if generation_config.do_sample is False:
|
||||||
|
if (
|
||||||
|
generation_config.top_k is not None
|
||||||
|
and generation_config.top_k > 1
|
||||||
|
and generation_config.penalty_alpha is not None
|
||||||
|
and generation_config.penalty_alpha > 0
|
||||||
|
):
|
||||||
|
generation_mode = GenerationMode.CONTRASTIVE_SEARCH
|
||||||
|
else:
|
||||||
|
generation_mode = GenerationMode.GREEDY_SEARCH
|
||||||
|
else:
|
||||||
|
generation_mode = GenerationMode.SAMPLE
|
||||||
|
else:
|
||||||
|
if generation_config.num_beam_groups > 1:
|
||||||
|
generation_mode = GenerationMode.GROUP_BEAM_SEARCH
|
||||||
|
elif generation_config.do_sample is True:
|
||||||
|
generation_mode = GenerationMode.BEAM_SAMPLE
|
||||||
|
else:
|
||||||
|
generation_mode = GenerationMode.BEAM_SEARCH
|
||||||
|
|
||||||
|
# Assisted generation may extend some generation modes
|
||||||
|
if assistant_model is not None:
|
||||||
|
if generation_mode in ("greedy_search", "sample"):
|
||||||
|
generation_mode = GenerationMode.ASSISTED_GENERATION
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
|
||||||
|
"is only supported with Greedy Search and Sample."
|
||||||
|
)
|
||||||
|
return generation_mode
|
||||||
|
|
||||||
def _get_logits_processor(
|
def _get_logits_processor(
|
||||||
self,
|
self,
|
||||||
generation_config: GenerationConfig,
|
generation_config: GenerationConfig,
|
||||||
@@ -1422,65 +1479,11 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 7. determine generation mode
|
# 7. determine generation mode
|
||||||
is_constraint_gen_mode = (
|
generation_mode = self._get_generation_mode(generation_config, assistant_model)
|
||||||
generation_config.constraints is not None or generation_config.force_words_ids is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
is_contrastive_search_gen_mode = (
|
|
||||||
(generation_config.num_beams == 1)
|
|
||||||
and generation_config.top_k is not None
|
|
||||||
and generation_config.top_k > 1
|
|
||||||
and generation_config.do_sample is False
|
|
||||||
and generation_config.penalty_alpha is not None
|
|
||||||
and generation_config.penalty_alpha > 0
|
|
||||||
)
|
|
||||||
|
|
||||||
is_greedy_gen_mode = (
|
|
||||||
(generation_config.num_beams == 1)
|
|
||||||
and (generation_config.num_beam_groups == 1)
|
|
||||||
and generation_config.do_sample is False
|
|
||||||
and not is_constraint_gen_mode
|
|
||||||
and not is_contrastive_search_gen_mode
|
|
||||||
)
|
|
||||||
is_sample_gen_mode = (
|
|
||||||
(generation_config.num_beams == 1)
|
|
||||||
and (generation_config.num_beam_groups == 1)
|
|
||||||
and generation_config.do_sample is True
|
|
||||||
and not is_constraint_gen_mode
|
|
||||||
and not is_contrastive_search_gen_mode
|
|
||||||
)
|
|
||||||
is_beam_gen_mode = (
|
|
||||||
(generation_config.num_beams > 1)
|
|
||||||
and (generation_config.num_beam_groups == 1)
|
|
||||||
and generation_config.do_sample is False
|
|
||||||
and not is_constraint_gen_mode
|
|
||||||
and not is_contrastive_search_gen_mode
|
|
||||||
)
|
|
||||||
is_beam_sample_gen_mode = (
|
|
||||||
(generation_config.num_beams > 1)
|
|
||||||
and (generation_config.num_beam_groups == 1)
|
|
||||||
and generation_config.do_sample is True
|
|
||||||
and not is_constraint_gen_mode
|
|
||||||
and not is_contrastive_search_gen_mode
|
|
||||||
)
|
|
||||||
is_group_beam_gen_mode = (
|
|
||||||
(generation_config.num_beams > 1)
|
|
||||||
and (generation_config.num_beam_groups > 1)
|
|
||||||
and not is_constraint_gen_mode
|
|
||||||
and not is_contrastive_search_gen_mode
|
|
||||||
)
|
|
||||||
is_assisted_gen_mode = False
|
|
||||||
if assistant_model is not None:
|
|
||||||
if not (is_greedy_gen_mode or is_sample_gen_mode):
|
|
||||||
raise ValueError(
|
|
||||||
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
|
|
||||||
"is only supported with Greedy Search and Sample."
|
|
||||||
)
|
|
||||||
is_assisted_gen_mode = True
|
|
||||||
|
|
||||||
if generation_config.num_beam_groups > generation_config.num_beams:
|
if generation_config.num_beam_groups > generation_config.num_beams:
|
||||||
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
||||||
if is_group_beam_gen_mode and generation_config.do_sample is True:
|
if generation_mode == GenerationMode.GROUP_BEAM_SEARCH and generation_config.do_sample is True:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
|
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
|
||||||
)
|
)
|
||||||
@@ -1515,7 +1518,7 @@ class GenerationMixin:
|
|||||||
generation_config=generation_config, stopping_criteria=stopping_criteria
|
generation_config=generation_config, stopping_criteria=stopping_criteria
|
||||||
)
|
)
|
||||||
# 10. go into different generation modes
|
# 10. go into different generation modes
|
||||||
if is_assisted_gen_mode:
|
if generation_mode == GenerationMode.ASSISTED_GENERATION:
|
||||||
if generation_config.num_return_sequences > 1:
|
if generation_config.num_return_sequences > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"num_return_sequences has to be 1 when doing assisted generate, "
|
"num_return_sequences has to be 1 when doing assisted generate, "
|
||||||
@@ -1553,7 +1556,7 @@ class GenerationMixin:
|
|||||||
streamer=streamer,
|
streamer=streamer,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
if is_greedy_gen_mode:
|
if generation_mode == GenerationMode.GREEDY_SEARCH:
|
||||||
if generation_config.num_return_sequences > 1:
|
if generation_config.num_return_sequences > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"num_return_sequences has to be 1 when doing greedy search, "
|
"num_return_sequences has to be 1 when doing greedy search, "
|
||||||
@@ -1574,7 +1577,7 @@ class GenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif is_contrastive_search_gen_mode:
|
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
|
||||||
if generation_config.num_return_sequences > 1:
|
if generation_config.num_return_sequences > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"num_return_sequences has to be 1 when doing contrastive search, "
|
"num_return_sequences has to be 1 when doing contrastive search, "
|
||||||
@@ -1599,7 +1602,7 @@ class GenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif is_sample_gen_mode:
|
elif generation_mode == GenerationMode.SAMPLE:
|
||||||
# 11. prepare logits warper
|
# 11. prepare logits warper
|
||||||
logits_warper = self._get_logits_warper(generation_config)
|
logits_warper = self._get_logits_warper(generation_config)
|
||||||
|
|
||||||
@@ -1626,7 +1629,7 @@ class GenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif is_beam_gen_mode:
|
elif generation_mode == GenerationMode.BEAM_SEARCH:
|
||||||
if generation_config.num_return_sequences > generation_config.num_beams:
|
if generation_config.num_return_sequences > generation_config.num_beams:
|
||||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||||
|
|
||||||
@@ -1664,7 +1667,7 @@ class GenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif is_beam_sample_gen_mode:
|
elif generation_mode == GenerationMode.BEAM_SAMPLE:
|
||||||
# 11. prepare logits warper
|
# 11. prepare logits warper
|
||||||
logits_warper = self._get_logits_warper(generation_config)
|
logits_warper = self._get_logits_warper(generation_config)
|
||||||
|
|
||||||
@@ -1703,7 +1706,7 @@ class GenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif is_group_beam_gen_mode:
|
elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
|
||||||
if generation_config.num_return_sequences > generation_config.num_beams:
|
if generation_config.num_return_sequences > generation_config.num_beams:
|
||||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||||
|
|
||||||
@@ -1754,7 +1757,7 @@ class GenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif is_constraint_gen_mode:
|
elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
|
||||||
if generation_config.num_return_sequences > generation_config.num_beams:
|
if generation_config.num_return_sequences > generation_config.num_beams:
|
||||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user