Generate: deprecate default max_length (#18018)

This commit is contained in:
Joao Gante
2022-07-23 18:02:03 +01:00
committed by GitHub
parent 8e8384663d
commit 7e44226fc7
4 changed files with 166 additions and 93 deletions

View File

@@ -2023,8 +2023,8 @@ class GenerationIntegrationTests(unittest.TestCase):
# 1 BOS + 20 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
# max_new_tokens and max_length serve the same purpose and should not be used together.
with self.assertWarns(UserWarning):
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
def test_max_new_tokens_decoder_only(self):
@@ -2050,8 +2050,8 @@ class GenerationIntegrationTests(unittest.TestCase):
# 1 BOS token + 23 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
# max_new_tokens and max_length serve the same purpose and should not be used together.
with self.assertWarns(UserWarning):
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
def test_encoder_decoder_generate_with_inputs_embeds(self):