Correct the new defaults (#34377)
* Correct the new defaults * CIs * add check * Update utils.py * Update utils.py * Add the max_length in generate test checking shape without passing length * style * CIs * fix fx CI issue
This commit is contained in:
@@ -1440,8 +1440,11 @@ class GenerationMixin:
|
|||||||
and not self.config.is_encoder_decoder
|
and not self.config.is_encoder_decoder
|
||||||
):
|
):
|
||||||
generation_config.max_length -= inputs_tensor.shape[1]
|
generation_config.max_length -= inputs_tensor.shape[1]
|
||||||
else: # by default let's always generate 10 new tokens
|
elif has_default_max_length: # by default let's always generate 20 new tokens
|
||||||
generation_config.max_length = generation_config.max_length + input_ids_length
|
generation_config.max_length = generation_config.max_length + input_ids_length
|
||||||
|
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
|
||||||
|
if max_position_embeddings is not None:
|
||||||
|
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
|
||||||
|
|
||||||
# same for min length
|
# same for min length
|
||||||
if generation_config.min_new_tokens is not None:
|
if generation_config.min_new_tokens is not None:
|
||||||
|
|||||||
@@ -488,7 +488,9 @@ class EncoderDecoderMixin:
|
|||||||
|
|
||||||
# Bert does not have a bos token id, so use pad_token_id instead
|
# Bert does not have a bos token id, so use pad_token_id instead
|
||||||
generated_output = enc_dec_model.generate(
|
generated_output = enc_dec_model.generate(
|
||||||
input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
input_ids,
|
||||||
|
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
|
||||||
|
max_length=decoder_config.max_length,
|
||||||
)
|
)
|
||||||
self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
|
self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
|
||||||
|
|
||||||
|
|||||||
@@ -362,7 +362,9 @@ class EncoderDecoderMixin:
|
|||||||
|
|
||||||
# Bert does not have a bos token id, so use pad_token_id instead
|
# Bert does not have a bos token id, so use pad_token_id instead
|
||||||
generated_output = enc_dec_model.generate(
|
generated_output = enc_dec_model.generate(
|
||||||
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
inputs,
|
||||||
|
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
|
||||||
|
max_length=decoder_config.max_length,
|
||||||
)
|
)
|
||||||
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
|
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
|
||||||
|
|
||||||
|
|||||||
@@ -306,7 +306,9 @@ class EncoderDecoderMixin:
|
|||||||
|
|
||||||
# Bert does not have a bos token id, so use pad_token_id instead
|
# Bert does not have a bos token id, so use pad_token_id instead
|
||||||
generated_output = enc_dec_model.generate(
|
generated_output = enc_dec_model.generate(
|
||||||
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
inputs,
|
||||||
|
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
|
||||||
|
max_length=decoder_config.max_length,
|
||||||
)
|
)
|
||||||
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
|
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
|
||||||
|
|
||||||
@@ -873,6 +875,7 @@ class LayoutLMv32TrOCR(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
generated_output = enc_dec_model.generate(
|
generated_output = enc_dec_model.generate(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
|
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
|
||||||
|
max_length=decoder_config.max_length,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
|
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
|
||||||
@@ -990,6 +993,7 @@ class VIT2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
generated_output = enc_dec_model.generate(
|
generated_output = enc_dec_model.generate(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
|
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
|
||||||
|
max_length=decoder_config.max_length,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
|
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
|
||||||
@@ -1107,6 +1111,7 @@ class Donut2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
generated_output = enc_dec_model.generate(
|
generated_output = enc_dec_model.generate(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
|
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
|
||||||
|
max_length=decoder_config.max_length,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
|
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
|
||||||
|
|||||||
Reference in New Issue
Block a user