diff --git a/docs/source/en/main_classes/text_generation.md b/docs/source/en/main_classes/text_generation.md index a43519d5a0..dec524d257 100644 --- a/docs/source/en/main_classes/text_generation.md +++ b/docs/source/en/main_classes/text_generation.md @@ -37,6 +37,9 @@ like token streaming. - from_pretrained - from_model_config - save_pretrained + - update + - validate + - get_generation_mode ## GenerationMixin diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index e45f546cdc..178be03861 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -18,7 +18,7 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab _import_structure = { - "configuration_utils": ["GenerationConfig"], + "configuration_utils": ["GenerationConfig", "GenerationMode"], "streamers": ["TextIteratorStreamer", "TextStreamer"], } @@ -172,7 +172,7 @@ else: ] if TYPE_CHECKING: - from .configuration_utils import GenerationConfig + from .configuration_utils import GenerationConfig, GenerationMode from .streamers import TextIteratorStreamer, TextStreamer try: diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index cacc2dc8e8..b937b59733 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -18,12 +18,13 @@ import copy import json import os import warnings -from typing import Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from .. 
def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = None) -> GenerationMode:
    """
    Returns the generation mode triggered by the [`GenerationConfig`] instance.

    Args:
        assistant_model (`PreTrainedModel`, *optional*):
            The assistant model to be used for assisted generation. If set, assisted generation is requested.

    Returns:
        `GenerationMode`: The generation mode triggered by the instance.

    Raises:
        ValueError: If assisted generation is requested (through `assistant_model` or
            `prompt_lookup_num_tokens`) while the configuration selects a mode other than greedy search or
            sample.
    """
    # TODO joao: find out a way of not depending on external fields (e.g. `assistant_model`), then make this a
    # property and part of the `__repr__`

    # Constraints / forced words take precedence over every other setting.
    if self.constraints is not None or self.force_words_ids is not None:
        generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
    elif self.num_beams == 1:
        if self.do_sample is False:
            # Contrastive search needs both `top_k > 1` and `penalty_alpha > 0`; otherwise fall back to greedy.
            if (
                self.top_k is not None
                and self.top_k > 1
                and self.penalty_alpha is not None
                and self.penalty_alpha > 0
            ):
                generation_mode = GenerationMode.CONTRASTIVE_SEARCH
            else:
                generation_mode = GenerationMode.GREEDY_SEARCH
        else:
            generation_mode = GenerationMode.SAMPLE
    else:
        if self.num_beam_groups > 1:
            generation_mode = GenerationMode.GROUP_BEAM_SEARCH
        elif self.do_sample is True:
            generation_mode = GenerationMode.BEAM_SAMPLE
        else:
            generation_mode = GenerationMode.BEAM_SEARCH

    # Assisted generation may extend some generation modes
    if assistant_model is not None or self.prompt_lookup_num_tokens is not None:
        # Compare against enum members rather than raw strings, so the check does not depend on the enum
        # being string-comparable.
        if generation_mode in (GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE):
            generation_mode = GenerationMode.ASSISTED_GENERATION
        else:
            # Either of the two settings can lead here, so the message names both.
            raise ValueError(
                "You've set `assistant_model` or `prompt_lookup_num_tokens`, which triggers assisted generate. "
                "Currently, assisted generate is only supported with Greedy Search and Sample."
            )
    return generation_mode
class GenerationMode(ExplicitEnum):
    """
    Enumeration of the generation modes that can be selected downstream of the
    [`~generation.GenerationMixin.generate`] method.
    """

    # Modes that do not use beams
    CONTRASTIVE_SEARCH = "contrastive_search"
    GREEDY_SEARCH = "greedy_search"
    SAMPLE = "sample"
    ASSISTED_GENERATION = "assisted_generation"

    # Modes that use beams
    BEAM_SEARCH = "beam_search"
    BEAM_SAMPLE = "beam_sample"
    CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
    GROUP_BEAM_SEARCH = "group_beam_search"
def _get_generation_mode(
    self, generation_config: GenerationConfig, assistant_model: Optional["PreTrainedModel"]
) -> GenerationMode:
    """
    Works out which generation mode a [`GenerationConfig`] instance selects.

    The base mode is derived from the configuration alone; it is then promoted to assisted generation when
    `assistant_model` or `generation_config.prompt_lookup_num_tokens` is set. A `ValueError` is raised if that
    promotion is requested for a base mode other than greedy search or sample.
    """
    cfg = generation_config

    # Step 1: base mode, from the configuration fields only.
    if cfg.constraints is not None or cfg.force_words_ids is not None:
        mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
    elif cfg.num_beams != 1:
        # Beam-based family
        if cfg.num_beam_groups > 1:
            mode = GenerationMode.GROUP_BEAM_SEARCH
        elif cfg.do_sample is True:
            mode = GenerationMode.BEAM_SAMPLE
        else:
            mode = GenerationMode.BEAM_SEARCH
    elif cfg.do_sample is not False:
        mode = GenerationMode.SAMPLE
    else:
        is_contrastive = (
            cfg.top_k is not None
            and cfg.top_k > 1
            and cfg.penalty_alpha is not None
            and cfg.penalty_alpha > 0
        )
        mode = GenerationMode.CONTRASTIVE_SEARCH if is_contrastive else GenerationMode.GREEDY_SEARCH

    # Step 2: optional promotion to assisted generation.
    assistance_requested = assistant_model is not None or cfg.prompt_lookup_num_tokens is not None
    if not assistance_requested:
        return mode
    if mode not in ("greedy_search", "sample"):
        raise ValueError(
            "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
            "is only supported with Greedy Search and Sample."
        )
    return GenerationMode.ASSISTED_GENERATION
def test_generation_mode(self):
    """Tests that the `get_generation_mode` method is working as expected."""
    # Each entry: (GenerationConfig kwargs, get_generation_mode kwargs, expected mode).
    # Cases run in order and stop at the first failing assertion.
    scenarios = (
        ({}, {}, GenerationMode.GREEDY_SEARCH),
        ({"do_sample": True}, {}, GenerationMode.SAMPLE),
        ({"num_beams": 2}, {}, GenerationMode.BEAM_SEARCH),
        ({"top_k": 10, "do_sample": False, "penalty_alpha": 0.6}, {}, GenerationMode.CONTRASTIVE_SEARCH),
        # Any non-None `assistant_model` is enough to request assisted generation.
        ({}, {"assistant_model": "foo"}, GenerationMode.ASSISTED_GENERATION),
    )
    for config_kwargs, mode_kwargs, expected_mode in scenarios:
        config = GenerationConfig(**config_kwargs)
        self.assertEqual(config.get_generation_mode(**mode_kwargs), expected_mode)