Generate: min length can't be larger than max length (#16668)
* min length must be smaller than max length
* Update min_length in tests
This commit is contained in:
@@ -259,6 +259,7 @@ class FlaxGenerationMixin:
|
||||
```"""
|
||||
# set init values
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
@@ -269,6 +270,11 @@ class FlaxGenerationMixin:
|
||||
|
||||
if decoder_start_token_id is None and self.config.is_encoder_decoder:
|
||||
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
|
||||
if min_length is not None and min_length > max_length:
|
||||
raise ValueError(
|
||||
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
|
||||
f"length ({max_length})"
|
||||
)
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
# add encoder_outputs to model_kwargs
|
||||
@@ -389,7 +395,6 @@ class FlaxGenerationMixin:
|
||||
no_repeat_ngram_size = (
|
||||
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
||||
)
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
forced_bos_token_id = (
|
||||
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
|
||||
|
||||
@@ -1489,6 +1489,11 @@ class TFGenerationMixin:
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence")
|
||||
pad_token_id = eos_token_id
|
||||
if min_length is not None and min_length > max_length:
|
||||
raise ValueError(
|
||||
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
|
||||
f"length ({max_length})"
|
||||
)
|
||||
|
||||
# 2. Define model inputs
|
||||
input_ids = self._prepare_model_inputs(input_ids, bos_token_id)
|
||||
|
||||
@@ -700,7 +700,6 @@ class GenerationMixin:
|
||||
else self.config.encoder_no_repeat_ngram_size
|
||||
)
|
||||
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
|
||||
forced_bos_token_id = (
|
||||
@@ -1185,7 +1184,13 @@ class GenerationMixin:
|
||||
)
|
||||
# default to config if still None
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
|
||||
if min_length is not None and min_length > max_length:
|
||||
raise ValueError(
|
||||
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
|
||||
f"length ({max_length})"
|
||||
)
|
||||
if input_ids_seq_length >= max_length:
|
||||
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
||||
logger.warning(
|
||||
|
||||
@@ -102,7 +102,7 @@ class GenerationTesterMixin:
|
||||
diversity_penalty=None,
|
||||
):
|
||||
process_kwargs = {
|
||||
"min_length": input_length + 1,
|
||||
"min_length": input_length + 1 if max_length is None else max_length - 1,
|
||||
"bad_words_ids": [[1, 0]],
|
||||
"no_repeat_ngram_size": 2,
|
||||
"repetition_penalty": 1.2,
|
||||
|
||||
Reference in New Issue
Block a user