Correct the new defaults (#34377)

* Correct the new defaults

* CIs

* add check

* Update utils.py

* Update utils.py

* Add the max_length in generate test checking shape without passing length

* style

* CIs

* fix fx CI issue
This commit is contained in:
Cyril Vallez
2024-10-24 18:42:03 +02:00
committed by GitHub
parent 1c5918d910
commit 4c6e0c9252
4 changed files with 16 additions and 4 deletions

View File

@@ -1440,8 +1440,11 @@ class GenerationMixin:
and not self.config.is_encoder_decoder
):
generation_config.max_length -= inputs_tensor.shape[1]
else: # by default let's always generate 10 new tokens
elif has_default_max_length: # by default let's always generate 20 new tokens
generation_config.max_length = generation_config.max_length + input_ids_length
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
if max_position_embeddings is not None:
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
# same for min length
if generation_config.min_new_tokens is not None:

View File

@@ -488,7 +488,9 @@ class EncoderDecoderMixin:
# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
input_ids,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
)
self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))

View File

@@ -362,7 +362,9 @@ class EncoderDecoderMixin:
# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
inputs,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
)
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))

View File

@@ -306,7 +306,9 @@ class EncoderDecoderMixin:
# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
inputs,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
)
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
@@ -873,6 +875,7 @@ class LayoutLMv32TrOCR(EncoderDecoderMixin, unittest.TestCase):
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
@@ -990,6 +993,7 @@ class VIT2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
@@ -1107,6 +1111,7 @@ class Donut2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))