Generate: get generation mode from the generation config instance 🧼 (#29441)
This commit is contained in:
@@ -37,6 +37,9 @@ like token streaming.
|
|||||||
- from_pretrained
|
- from_pretrained
|
||||||
- from_model_config
|
- from_model_config
|
||||||
- save_pretrained
|
- save_pretrained
|
||||||
|
- update
|
||||||
|
- validate
|
||||||
|
- get_generation_mode
|
||||||
|
|
||||||
## GenerationMixin
|
## GenerationMixin
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_utils": ["GenerationConfig"],
|
"configuration_utils": ["GenerationConfig", "GenerationMode"],
|
||||||
"streamers": ["TextIteratorStreamer", "TextStreamer"],
|
"streamers": ["TextIteratorStreamer", "TextStreamer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,7 +172,7 @@ else:
|
|||||||
]
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_utils import GenerationConfig
|
from .configuration_utils import GenerationConfig, GenerationMode
|
||||||
from .streamers import TextIteratorStreamer, TextStreamer
|
from .streamers import TextIteratorStreamer, TextStreamer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -18,12 +18,13 @@ import copy
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||||
|
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from ..configuration_utils import PretrainedConfig
|
from ..configuration_utils import PretrainedConfig
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
GENERATION_CONFIG_NAME,
|
GENERATION_CONFIG_NAME,
|
||||||
|
ExplicitEnum,
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
cached_file,
|
cached_file,
|
||||||
download_url,
|
download_url,
|
||||||
@@ -33,10 +34,31 @@ from ..utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
|
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationMode(ExplicitEnum):
|
||||||
|
"""
|
||||||
|
Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Non-beam methods
|
||||||
|
CONTRASTIVE_SEARCH = "contrastive_search"
|
||||||
|
GREEDY_SEARCH = "greedy_search"
|
||||||
|
SAMPLE = "sample"
|
||||||
|
ASSISTED_GENERATION = "assisted_generation"
|
||||||
|
# Beam methods
|
||||||
|
BEAM_SEARCH = "beam_search"
|
||||||
|
BEAM_SAMPLE = "beam_sample"
|
||||||
|
CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
|
||||||
|
GROUP_BEAM_SEARCH = "group_beam_search"
|
||||||
|
|
||||||
|
|
||||||
class GenerationConfig(PushToHubMixin):
|
class GenerationConfig(PushToHubMixin):
|
||||||
# no-format
|
# no-format
|
||||||
r"""
|
r"""
|
||||||
@@ -376,13 +398,65 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{self.__class__.__name__} {self.to_json_string(ignore_metadata=True)}"
|
return f"{self.__class__.__name__} {self.to_json_string(ignore_metadata=True)}"
|
||||||
|
|
||||||
|
def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = None) -> GenerationMode:
|
||||||
|
"""
|
||||||
|
Returns the generation mode triggered by the [`GenerationConfig`] instance.
|
||||||
|
|
||||||
|
Arg:
|
||||||
|
assistant_model (`PreTrainedModel`, *optional*):
|
||||||
|
The assistant model to be used for assisted generation. If set, the generation mode will be
|
||||||
|
assisted generation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`GenerationMode`: The generation mode triggered by the instance.
|
||||||
|
"""
|
||||||
|
# TODO joao: find out a way of not depending on external fields (e.g. `assistant_model`), then make this a
|
||||||
|
# property and part of the `__repr__`
|
||||||
|
if self.constraints is not None or self.force_words_ids is not None:
|
||||||
|
generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
|
||||||
|
elif self.num_beams == 1:
|
||||||
|
if self.do_sample is False:
|
||||||
|
if (
|
||||||
|
self.top_k is not None
|
||||||
|
and self.top_k > 1
|
||||||
|
and self.penalty_alpha is not None
|
||||||
|
and self.penalty_alpha > 0
|
||||||
|
):
|
||||||
|
generation_mode = GenerationMode.CONTRASTIVE_SEARCH
|
||||||
|
else:
|
||||||
|
generation_mode = GenerationMode.GREEDY_SEARCH
|
||||||
|
else:
|
||||||
|
generation_mode = GenerationMode.SAMPLE
|
||||||
|
else:
|
||||||
|
if self.num_beam_groups > 1:
|
||||||
|
generation_mode = GenerationMode.GROUP_BEAM_SEARCH
|
||||||
|
elif self.do_sample is True:
|
||||||
|
generation_mode = GenerationMode.BEAM_SAMPLE
|
||||||
|
else:
|
||||||
|
generation_mode = GenerationMode.BEAM_SEARCH
|
||||||
|
|
||||||
|
# Assisted generation may extend some generation modes
|
||||||
|
if assistant_model is not None or self.prompt_lookup_num_tokens is not None:
|
||||||
|
if generation_mode in ("greedy_search", "sample"):
|
||||||
|
generation_mode = GenerationMode.ASSISTED_GENERATION
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
|
||||||
|
"is only supported with Greedy Search and Sample."
|
||||||
|
)
|
||||||
|
return generation_mode
|
||||||
|
|
||||||
def validate(self, is_init=False):
|
def validate(self, is_init=False):
|
||||||
"""
|
"""
|
||||||
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
|
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
|
||||||
of parameterization that can be detected as incorrect from the configuration instance alone.
|
of parameterization that can be detected as incorrect from the configuration instance alone.
|
||||||
|
|
||||||
Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the
|
Note that some parameters not validated here are best validated at generate runtime, as they may depend on
|
||||||
model, such as parameters related to the generation length.
|
other inputs and/or the model, such as parameters related to the generation length.
|
||||||
|
|
||||||
|
Arg:
|
||||||
|
is_init (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether the validation is performed during the initialization of the instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Validation of individual attributes
|
# Validation of individual attributes
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from ..models.auto import (
|
|||||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||||
)
|
)
|
||||||
from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging
|
from ..utils import ModelOutput, is_accelerate_available, logging
|
||||||
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
||||||
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||||
from .candidate_generator import (
|
from .candidate_generator import (
|
||||||
@@ -45,7 +45,7 @@ from .candidate_generator import (
|
|||||||
_prepare_attention_mask,
|
_prepare_attention_mask,
|
||||||
_prepare_token_type_ids,
|
_prepare_token_type_ids,
|
||||||
)
|
)
|
||||||
from .configuration_utils import GenerationConfig
|
from .configuration_utils import GenerationConfig, GenerationMode
|
||||||
from .logits_process import (
|
from .logits_process import (
|
||||||
EncoderNoRepeatNGramLogitsProcessor,
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
EncoderRepetitionPenaltyLogitsProcessor,
|
EncoderRepetitionPenaltyLogitsProcessor,
|
||||||
@@ -325,23 +325,6 @@ GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDec
|
|||||||
GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
|
GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
|
||||||
|
|
||||||
|
|
||||||
class GenerationMode(ExplicitEnum):
|
|
||||||
"""
|
|
||||||
Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Non-beam methods
|
|
||||||
CONTRASTIVE_SEARCH = "contrastive_search"
|
|
||||||
GREEDY_SEARCH = "greedy_search"
|
|
||||||
SAMPLE = "sample"
|
|
||||||
ASSISTED_GENERATION = "assisted_generation"
|
|
||||||
# Beam methods
|
|
||||||
BEAM_SEARCH = "beam_search"
|
|
||||||
BEAM_SAMPLE = "beam_sample"
|
|
||||||
CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
|
|
||||||
GROUP_BEAM_SEARCH = "group_beam_search"
|
|
||||||
|
|
||||||
|
|
||||||
class GenerationMixin:
|
class GenerationMixin:
|
||||||
"""
|
"""
|
||||||
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
|
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
|
||||||
@@ -764,46 +747,6 @@ class GenerationMixin:
|
|||||||
warpers.append(LogitNormalization())
|
warpers.append(LogitNormalization())
|
||||||
return warpers
|
return warpers
|
||||||
|
|
||||||
def _get_generation_mode(
|
|
||||||
self, generation_config: GenerationConfig, assistant_model: Optional["PreTrainedModel"]
|
|
||||||
) -> GenerationMode:
|
|
||||||
"""
|
|
||||||
Returns the generation mode triggered by a [`GenerationConfig`] instance.
|
|
||||||
"""
|
|
||||||
if generation_config.constraints is not None or generation_config.force_words_ids is not None:
|
|
||||||
generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
|
|
||||||
elif generation_config.num_beams == 1:
|
|
||||||
if generation_config.do_sample is False:
|
|
||||||
if (
|
|
||||||
generation_config.top_k is not None
|
|
||||||
and generation_config.top_k > 1
|
|
||||||
and generation_config.penalty_alpha is not None
|
|
||||||
and generation_config.penalty_alpha > 0
|
|
||||||
):
|
|
||||||
generation_mode = GenerationMode.CONTRASTIVE_SEARCH
|
|
||||||
else:
|
|
||||||
generation_mode = GenerationMode.GREEDY_SEARCH
|
|
||||||
else:
|
|
||||||
generation_mode = GenerationMode.SAMPLE
|
|
||||||
else:
|
|
||||||
if generation_config.num_beam_groups > 1:
|
|
||||||
generation_mode = GenerationMode.GROUP_BEAM_SEARCH
|
|
||||||
elif generation_config.do_sample is True:
|
|
||||||
generation_mode = GenerationMode.BEAM_SAMPLE
|
|
||||||
else:
|
|
||||||
generation_mode = GenerationMode.BEAM_SEARCH
|
|
||||||
|
|
||||||
# Assisted generation may extend some generation modes
|
|
||||||
if assistant_model is not None or generation_config.prompt_lookup_num_tokens is not None:
|
|
||||||
if generation_mode in ("greedy_search", "sample"):
|
|
||||||
generation_mode = GenerationMode.ASSISTED_GENERATION
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
|
|
||||||
"is only supported with Greedy Search and Sample."
|
|
||||||
)
|
|
||||||
return generation_mode
|
|
||||||
|
|
||||||
def _get_logits_processor(
|
def _get_logits_processor(
|
||||||
self,
|
self,
|
||||||
generation_config: GenerationConfig,
|
generation_config: GenerationConfig,
|
||||||
@@ -1474,7 +1417,7 @@ class GenerationMixin:
|
|||||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||||
|
|
||||||
# 7. determine generation mode
|
# 7. determine generation mode
|
||||||
generation_mode = self._get_generation_mode(generation_config, assistant_model)
|
generation_mode = generation_config.get_generation_mode(assistant_model)
|
||||||
|
|
||||||
if streamer is not None and (generation_config.num_beams > 1):
|
if streamer is not None and (generation_config.num_beams > 1):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from parameterized import parameterized
|
|||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
from transformers import AutoConfig, GenerationConfig
|
from transformers import AutoConfig, GenerationConfig
|
||||||
|
from transformers.generation import GenerationMode
|
||||||
from transformers.testing_utils import TOKEN, USER, is_staging_test
|
from transformers.testing_utils import TOKEN, USER, is_staging_test
|
||||||
|
|
||||||
|
|
||||||
@@ -202,6 +203,23 @@ class GenerationConfigTest(unittest.TestCase):
|
|||||||
self.assertEqual(len(captured_warnings), 0)
|
self.assertEqual(len(captured_warnings), 0)
|
||||||
self.assertTrue(len(os.listdir(tmp_dir)) == 1)
|
self.assertTrue(len(os.listdir(tmp_dir)) == 1)
|
||||||
|
|
||||||
|
def test_generation_mode(self):
|
||||||
|
"""Tests that the `get_generation_mode` method is working as expected."""
|
||||||
|
config = GenerationConfig()
|
||||||
|
self.assertEqual(config.get_generation_mode(), GenerationMode.GREEDY_SEARCH)
|
||||||
|
|
||||||
|
config = GenerationConfig(do_sample=True)
|
||||||
|
self.assertEqual(config.get_generation_mode(), GenerationMode.SAMPLE)
|
||||||
|
|
||||||
|
config = GenerationConfig(num_beams=2)
|
||||||
|
self.assertEqual(config.get_generation_mode(), GenerationMode.BEAM_SEARCH)
|
||||||
|
|
||||||
|
config = GenerationConfig(top_k=10, do_sample=False, penalty_alpha=0.6)
|
||||||
|
self.assertEqual(config.get_generation_mode(), GenerationMode.CONTRASTIVE_SEARCH)
|
||||||
|
|
||||||
|
config = GenerationConfig()
|
||||||
|
self.assertEqual(config.get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION)
|
||||||
|
|
||||||
|
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
class ConfigPushToHubTester(unittest.TestCase):
|
class ConfigPushToHubTester(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user