Correct the new defaults (#34377)

* Correct the new defaults

* CIs

* add check

* Update utils.py

* Update utils.py

* Add the max_length in generate test checking shape without passing length

* style

* CIs

* fix fx CI issue
This commit is contained in:
Cyril Vallez
2024-10-24 18:42:03 +02:00
committed by GitHub
parent 1c5918d910
commit 4c6e0c9252
4 changed files with 16 additions and 4 deletions

View File

@@ -362,7 +362,9 @@ class EncoderDecoderMixin:
# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
inputs,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
)
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))