Correct the new defaults (#34377)

* Correct the new defaults

* CIs

* add check

* Update utils.py

* Update utils.py

* Add the max_length in generate test checking shape without passing length

* style

* CIs

* fix fx CI issue
This commit is contained in:
Cyril Vallez
2024-10-24 18:42:03 +02:00
committed by GitHub
parent 1c5918d910
commit 4c6e0c9252
4 changed files with 16 additions and 4 deletions

View File

@@ -488,7 +488,9 @@ class EncoderDecoderMixin:
# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
input_ids,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
)
self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))

View File

@@ -362,7 +362,9 @@ class EncoderDecoderMixin:
# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
inputs,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
)
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))

View File

@@ -306,7 +306,9 @@ class EncoderDecoderMixin:
# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
inputs,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
)
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
@@ -873,6 +875,7 @@ class LayoutLMv32TrOCR(EncoderDecoderMixin, unittest.TestCase):
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
@@ -990,6 +993,7 @@ class VIT2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
@@ -1107,6 +1111,7 @@ class Donut2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))