From b0bf3011c113a0e3408776d0ad71cae4b0e70c36 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 11 Apr 2022 11:55:30 +0100 Subject: [PATCH] Generate: min length can't be larger than max length (#16668) * min length must be smaller than max length * Update min_length in tests --- src/transformers/generation_flax_utils.py | 7 ++++++- src/transformers/generation_tf_utils.py | 5 +++++ src/transformers/generation_utils.py | 7 ++++++- tests/generation/test_generation_utils.py | 2 +- 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py index 75fa54bdce..e6ca8b0fcc 100644 --- a/src/transformers/generation_flax_utils.py +++ b/src/transformers/generation_flax_utils.py @@ -259,6 +259,7 @@ class FlaxGenerationMixin: ```""" # set init values max_length = max_length if max_length is not None else self.config.max_length + min_length = min_length if min_length is not None else self.config.min_length bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id @@ -269,6 +270,11 @@ class FlaxGenerationMixin: if decoder_start_token_id is None and self.config.is_encoder_decoder: raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.") + if min_length is not None and min_length > max_length: + raise ValueError( + f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum " + f"length ({max_length})" + ) if self.config.is_encoder_decoder: # add encoder_outputs to model_kwargs @@ -389,7 +395,6 @@ class FlaxGenerationMixin: no_repeat_ngram_size = ( no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size ) - min_length = min_length if min_length is not None else self.config.min_length eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id forced_bos_token_id = ( forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 8668e6f8dc..b9ae5bd77b 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -1489,6 +1489,11 @@ class TFGenerationMixin: if pad_token_id is None and eos_token_id is not None: logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence") pad_token_id = eos_token_id + if min_length is not None and min_length > max_length: + raise ValueError( + f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum " + f"length ({max_length})" + ) # 2. Define model inputs input_ids = self._prepare_model_inputs(input_ids, bos_token_id) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 1472052d2b..b37e11af03 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -700,7 +700,6 @@ class GenerationMixin: else self.config.encoder_no_repeat_ngram_size ) bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids - min_length = min_length if min_length is not None else self.config.min_length eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty forced_bos_token_id = ( @@ -1185,7 +1184,13 @@ class GenerationMixin: ) # default to config if still None max_length = max_length if max_length is not None else self.config.max_length + min_length = min_length if min_length is not None else self.config.min_length + if min_length is not None and min_length > max_length: + raise ValueError( + f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum " + f"length ({max_length})" + ) if input_ids_seq_length >= max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" logger.warning( diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index 031061041e..6006dbe21c 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -102,7 +102,7 @@ class GenerationTesterMixin: diversity_penalty=None, ): process_kwargs = { - "min_length": input_length + 1, + "min_length": input_length + 1 if max_length is None else max_length - 1, "bad_words_ids": [[1, 0]], "no_repeat_ngram_size": 2, "repetition_penalty": 1.2,