Generate: get generation mode from the generation config instance 🧼 (#29441)
This commit is contained in:
@@ -37,6 +37,9 @@ like token streaming.
|
||||
- from_pretrained
|
||||
- from_model_config
|
||||
- save_pretrained
|
||||
- update
|
||||
- validate
|
||||
- get_generation_mode
|
||||
|
||||
## GenerationMixin
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_utils": ["GenerationConfig"],
|
||||
"configuration_utils": ["GenerationConfig", "GenerationMode"],
|
||||
"streamers": ["TextIteratorStreamer", "TextStreamer"],
|
||||
}
|
||||
|
||||
@@ -172,7 +172,7 @@ else:
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_utils import GenerationConfig
|
||||
from .configuration_utils import GenerationConfig, GenerationMode
|
||||
from .streamers import TextIteratorStreamer, TextStreamer
|
||||
|
||||
try:
|
||||
|
||||
@@ -18,12 +18,13 @@ import copy
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
|
||||
from .. import __version__
|
||||
from ..configuration_utils import PretrainedConfig
|
||||
from ..utils import (
|
||||
GENERATION_CONFIG_NAME,
|
||||
ExplicitEnum,
|
||||
PushToHubMixin,
|
||||
cached_file,
|
||||
download_url,
|
||||
@@ -33,10 +34,31 @@ from ..utils import (
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
|
||||
|
||||
|
||||
class GenerationMode(ExplicitEnum):
    """
    Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.

    Each member's value is the snake_case string name of the corresponding decoding strategy.
    """

    # Non-beam methods
    CONTRASTIVE_SEARCH = "contrastive_search"
    GREEDY_SEARCH = "greedy_search"
    SAMPLE = "sample"
    ASSISTED_GENERATION = "assisted_generation"
    # Beam methods
    BEAM_SEARCH = "beam_search"
    BEAM_SAMPLE = "beam_sample"
    CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
    GROUP_BEAM_SEARCH = "group_beam_search"
|
||||
|
||||
|
||||
class GenerationConfig(PushToHubMixin):
|
||||
# no-format
|
||||
r"""
|
||||
@@ -376,13 +398,65 @@ class GenerationConfig(PushToHubMixin):
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__} {self.to_json_string(ignore_metadata=True)}"
|
||||
|
||||
def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = None) -> GenerationMode:
    """
    Returns the generation mode triggered by the [`GenerationConfig`] instance.

    Args:
        assistant_model (`PreTrainedModel`, *optional*):
            The assistant model to be used for assisted generation. If set, the generation mode will be
            assisted generation.

    Returns:
        `GenerationMode`: The generation mode triggered by the instance.

    Raises:
        ValueError: If assisted generation is requested (through `assistant_model` or
            `prompt_lookup_num_tokens`) while the configuration selects a mode other than greedy search
            or sample.
    """
    # TODO joao: find out a way of not depending on external fields (e.g. `assistant_model`), then make this a
    # property and part of the `__repr__`
    if self.constraints is not None or self.force_words_ids is not None:
        # Constraints are checked first, so they win over the beam/sampling settings below.
        generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
    elif self.num_beams == 1:
        if self.do_sample is False:
            # Contrastive search needs both `top_k > 1` and a strictly positive `penalty_alpha`.
            if (
                self.top_k is not None
                and self.top_k > 1
                and self.penalty_alpha is not None
                and self.penalty_alpha > 0
            ):
                generation_mode = GenerationMode.CONTRASTIVE_SEARCH
            else:
                generation_mode = GenerationMode.GREEDY_SEARCH
        else:
            generation_mode = GenerationMode.SAMPLE
    else:
        if self.num_beam_groups > 1:
            generation_mode = GenerationMode.GROUP_BEAM_SEARCH
        elif self.do_sample is True:
            generation_mode = GenerationMode.BEAM_SAMPLE
        else:
            generation_mode = GenerationMode.BEAM_SEARCH

    # Assisted generation may extend some generation modes
    if assistant_model is not None or self.prompt_lookup_num_tokens is not None:
        # Compare against the enum members themselves rather than their raw string values.
        if generation_mode in (GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE):
            generation_mode = GenerationMode.ASSISTED_GENERATION
        else:
            # Either trigger can reach this branch, so the message names both of them.
            raise ValueError(
                "You've set `assistant_model` or `prompt_lookup_num_tokens`, which triggers assisted generate. "
                "Currently, assisted generate is only supported with Greedy Search and Sample."
            )
    return generation_mode
|
||||
|
||||
def validate(self, is_init=False):
|
||||
"""
|
||||
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
|
||||
of parameterization that can be detected as incorrect from the configuration instance alone.
|
||||
|
||||
Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the
|
||||
model, such as parameters related to the generation length.
|
||||
Note that some parameters not validated here are best validated at generate runtime, as they may depend on
|
||||
other inputs and/or the model, such as parameters related to the generation length.
|
||||
|
||||
Arg:
|
||||
is_init (`bool`, *optional*, defaults to `False`):
|
||||
Whether the validation is performed during the initialization of the instance.
|
||||
"""
|
||||
|
||||
# Validation of individual attributes
|
||||
|
||||
@@ -34,7 +34,7 @@ from ..models.auto import (
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
)
|
||||
from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging
|
||||
from ..utils import ModelOutput, is_accelerate_available, logging
|
||||
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
||||
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from .candidate_generator import (
|
||||
@@ -45,7 +45,7 @@ from .candidate_generator import (
|
||||
_prepare_attention_mask,
|
||||
_prepare_token_type_ids,
|
||||
)
|
||||
from .configuration_utils import GenerationConfig
|
||||
from .configuration_utils import GenerationConfig, GenerationMode
|
||||
from .logits_process import (
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
EncoderRepetitionPenaltyLogitsProcessor,
|
||||
@@ -325,23 +325,6 @@ GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDec
|
||||
GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
|
||||
|
||||
|
||||
class GenerationMode(ExplicitEnum):
    """
    Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.

    Each member's value is the snake_case string name of the corresponding decoding strategy.
    """

    # Non-beam methods
    CONTRASTIVE_SEARCH = "contrastive_search"
    GREEDY_SEARCH = "greedy_search"
    SAMPLE = "sample"
    ASSISTED_GENERATION = "assisted_generation"
    # Beam methods
    BEAM_SEARCH = "beam_search"
    BEAM_SAMPLE = "beam_sample"
    CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
    GROUP_BEAM_SEARCH = "group_beam_search"
|
||||
|
||||
|
||||
class GenerationMixin:
|
||||
"""
|
||||
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
|
||||
@@ -764,46 +747,6 @@ class GenerationMixin:
|
||||
warpers.append(LogitNormalization())
|
||||
return warpers
|
||||
|
||||
def _get_generation_mode(
    self, generation_config: GenerationConfig, assistant_model: Optional["PreTrainedModel"]
) -> GenerationMode:
    """
    Returns the generation mode triggered by a [`GenerationConfig`] instance.

    The decision logic lives in [`GenerationConfig.get_generation_mode`]; this method only forwards to it, so
    the mode is resolved in a single place.

    Args:
        generation_config (`GenerationConfig`):
            The configuration whose attributes determine the generation mode.
        assistant_model (`PreTrainedModel`, *optional*):
            The assistant model to be used for assisted generation, forwarded unchanged.

    Returns:
        `GenerationMode`: The generation mode triggered by `generation_config`.

    Raises:
        ValueError: Propagated from `generation_config.get_generation_mode` when assisted generation is
            requested together with a mode other than greedy search or sample.
    """
    return generation_config.get_generation_mode(assistant_model)
|
||||
|
||||
def _get_logits_processor(
|
||||
self,
|
||||
generation_config: GenerationConfig,
|
||||
@@ -1474,7 +1417,7 @@ class GenerationMixin:
|
||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||
|
||||
# 7. determine generation mode
|
||||
generation_mode = self._get_generation_mode(generation_config, assistant_model)
|
||||
generation_mode = generation_config.get_generation_mode(assistant_model)
|
||||
|
||||
if streamer is not None and (generation_config.num_beams > 1):
|
||||
raise ValueError(
|
||||
|
||||
@@ -24,6 +24,7 @@ from parameterized import parameterized
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from transformers import AutoConfig, GenerationConfig
|
||||
from transformers.generation import GenerationMode
|
||||
from transformers.testing_utils import TOKEN, USER, is_staging_test
|
||||
|
||||
|
||||
@@ -202,6 +203,23 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
self.assertTrue(len(os.listdir(tmp_dir)) == 1)
|
||||
|
||||
def test_generation_mode(self):
    """Tests that the `get_generation_mode` method is working as expected."""
    # Default configuration
    config = GenerationConfig()
    self.assertEqual(config.get_generation_mode(), GenerationMode.GREEDY_SEARCH)

    # Sampling enabled, single beam
    config = GenerationConfig(do_sample=True)
    self.assertEqual(config.get_generation_mode(), GenerationMode.SAMPLE)

    # More than one beam, no sampling
    config = GenerationConfig(num_beams=2)
    self.assertEqual(config.get_generation_mode(), GenerationMode.BEAM_SEARCH)

    # `top_k > 1` together with a positive `penalty_alpha` and no sampling
    config = GenerationConfig(top_k=10, do_sample=False, penalty_alpha=0.6)
    self.assertEqual(config.get_generation_mode(), GenerationMode.CONTRASTIVE_SEARCH)

    # Any non-None assistant upgrades the default (greedy) mode to assisted generation;
    # a placeholder string stands in for the model here.
    config = GenerationConfig()
    self.assertEqual(config.get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION)
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ConfigPushToHubTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user