Correct the new defaults (#34377)
* Correct the new defaults * CIs * add check * Update utils.py * Update utils.py * Add the max_length in generate test checking shape without passing length * style * CIs * fix fx CI issue
This commit is contained in:
@@ -306,7 +306,9 @@ class EncoderDecoderMixin:
|
||||
|
||||
# Bert does not have a bos token id, so use pad_token_id instead
|
||||
generated_output = enc_dec_model.generate(
|
||||
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
||||
inputs,
|
||||
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
|
||||
max_length=decoder_config.max_length,
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
@@ -873,6 +875,7 @@ class LayoutLMv32TrOCR(EncoderDecoderMixin, unittest.TestCase):
|
||||
generated_output = enc_dec_model.generate(
|
||||
pixel_values=pixel_values,
|
||||
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
|
||||
max_length=decoder_config.max_length,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
|
||||
@@ -990,6 +993,7 @@ class VIT2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
|
||||
generated_output = enc_dec_model.generate(
|
||||
pixel_values=pixel_values,
|
||||
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
|
||||
max_length=decoder_config.max_length,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
|
||||
@@ -1107,6 +1111,7 @@ class Donut2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
|
||||
generated_output = enc_dec_model.generate(
|
||||
pixel_values=pixel_values,
|
||||
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
|
||||
max_length=decoder_config.max_length,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
Reference in New Issue
Block a user