Generate: deprecate default max_length (#18018)
This commit is contained in:
@@ -2023,8 +2023,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# 1 BOS + 20 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
# max_new_tokens and max_length serve the same purpose and should not be used together.
|
||||
with self.assertWarns(UserWarning):
|
||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
||||
with self.assertRaises(ValueError):
|
||||
bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
|
||||
|
||||
def test_max_new_tokens_decoder_only(self):
|
||||
@@ -2050,8 +2050,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# 1 BOS token + 23 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
# max_new_tokens and max_length serve the same purpose and should not be used together.
|
||||
with self.assertWarns(UserWarning):
|
||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
||||
with self.assertRaises(ValueError):
|
||||
gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
|
||||
|
||||
def test_encoder_decoder_generate_with_inputs_embeds(self):
|
||||
|
||||
Reference in New Issue
Block a user