Generate: fix default max length warning (#25539)

This commit is contained in:
Joao Gante
2023-08-16 15:30:54 +01:00
committed by GitHub
parent e13d5b6048
commit 3f9cb33504
5 changed files with 30 additions and 4 deletions

View File

@@ -16,6 +16,7 @@
import inspect
import unittest
import warnings
import numpy as np
@@ -2844,3 +2845,28 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
with self.assertRaises(TypeError):
# FakeEncoder.forward() accepts **kwargs -> no filtering -> type error due to unexpected input "foo"
bart_model.generate(input_ids, foo="bar")
def test_default_max_length_warning(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.config.pad_token_id = tokenizer.eos_token_id
text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
# Default generation config value of 20 -> emits warning
with self.assertWarns(UserWarning):
model.generate(input_ids)
# Explicitly setting max_length to 20 -> no warning
with warnings.catch_warnings(record=True) as warning_list:
model.generate(input_ids, max_length=20)
self.assertEqual(len(warning_list), 0)
# Generation config max_length != 20 -> no warning
with warnings.catch_warnings(record=True) as warning_list:
model.generation_config.max_length = 10
model.generation_config._from_model_config = False # otherwise model.config.max_length=20 takes precedence
model.generate(input_ids)
self.assertEqual(len(warning_list), 0)