Correct the new defaults (#34377)
* Correct the new defaults
* CIs
* add check
* Update utils.py
* Update utils.py
* Add the max_length in generate test checking shape without passing length
* style
* CIs
* fix fx CI issue
This commit is contained in:
@@ -1440,8 +1440,11 @@ class GenerationMixin:
|
||||
and not self.config.is_encoder_decoder
|
||||
):
|
||||
generation_config.max_length -= inputs_tensor.shape[1]
|
||||
else: # by default let's always generate 10 new tokens
|
||||
elif has_default_max_length: # by default let's always generate 20 new tokens
|
||||
generation_config.max_length = generation_config.max_length + input_ids_length
|
||||
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
|
||||
if max_position_embeddings is not None:
|
||||
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
|
||||
|
||||
# same for min length
|
||||
if generation_config.min_new_tokens is not None:
|
||||
|
||||
@@ -488,7 +488,9 @@ class EncoderDecoderMixin:
|
||||
|
||||
# Bert does not have a bos token id, so use pad_token_id instead
|
||||
generated_output = enc_dec_model.generate(
|
||||
input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
||||
input_ids,
|
||||
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
|
||||
max_length=decoder_config.max_length,
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
|
||||
@@ -362,7 +362,9 @@ class EncoderDecoderMixin:
|
||||
|
||||
# Bert does not have a bos token id, so use pad_token_id instead
|
||||
generated_output = enc_dec_model.generate(
|
||||
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
||||
inputs,
|
||||
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
|
||||
max_length=decoder_config.max_length,
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
|
||||
@@ -306,7 +306,9 @@ class EncoderDecoderMixin:
|
||||
|
||||
# Bert does not have a bos token id, so use pad_token_id instead
|
||||
generated_output = enc_dec_model.generate(
|
||||
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
||||
inputs,
|
||||
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
|
||||
max_length=decoder_config.max_length,
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
@@ -873,6 +875,7 @@ class LayoutLMv32TrOCR(EncoderDecoderMixin, unittest.TestCase):
|
||||
generated_output = enc_dec_model.generate(
|
||||
pixel_values=pixel_values,
|
||||
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
|
||||
max_length=decoder_config.max_length,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
|
||||
@@ -990,6 +993,7 @@ class VIT2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
|
||||
generated_output = enc_dec_model.generate(
|
||||
pixel_values=pixel_values,
|
||||
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
|
||||
max_length=decoder_config.max_length,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
|
||||
@@ -1107,6 +1111,7 @@ class Donut2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
|
||||
generated_output = enc_dec_model.generate(
|
||||
pixel_values=pixel_values,
|
||||
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
|
||||
max_length=decoder_config.max_length,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
Reference in New Issue
Block a user