Correct the new defaults (#34377)
* Correct the new defaults * CIs * add check * Update utils.py * Update utils.py * Add the max_length in generate test checking shape without passing length * style * CIs * fix fx CI issue
This commit is contained in:
@@ -362,7 +362,9 @@ class EncoderDecoderMixin:
|
||||
|
||||
# Bert does not have a bos token id, so use pad_token_id instead
|
||||
generated_output = enc_dec_model.generate(
|
||||
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
||||
inputs,
|
||||
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
|
||||
max_length=decoder_config.max_length,
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user