Generate: fix default max length warning (#25539)
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -2844,3 +2845,28 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
with self.assertRaises(TypeError):
|
||||
# FakeEncoder.forward() accepts **kwargs -> no filtering -> type error due to unexpected input "foo"
|
||||
bart_model.generate(input_ids, foo="bar")
|
||||
|
||||
def test_default_max_length_warning(self):
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model.config.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
text = "Hello world"
|
||||
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||
|
||||
# Default generation config value of 20 -> emits warning
|
||||
with self.assertWarns(UserWarning):
|
||||
model.generate(input_ids)
|
||||
|
||||
# Explicitly setting max_length to 20 -> no warning
|
||||
with warnings.catch_warnings(record=True) as warning_list:
|
||||
model.generate(input_ids, max_length=20)
|
||||
self.assertEqual(len(warning_list), 0)
|
||||
|
||||
# Generation config max_length != 20 -> no warning
|
||||
with warnings.catch_warnings(record=True) as warning_list:
|
||||
model.generation_config.max_length = 10
|
||||
model.generation_config._from_model_config = False # otherwise model.config.max_length=20 takes precedence
|
||||
model.generate(input_ids)
|
||||
self.assertEqual(len(warning_list), 0)
|
||||
|
||||
Reference in New Issue
Block a user