Fix length related warnings in speculative decoding (#29585)
* avoid generation length warning
* add tests
* Update src/transformers/generation/candidate_generator.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* add tests and minor fixes
* refine `min_new_tokens`
* Update src/transformers/generation/candidate_generator.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* add method to prepare length arguments
* add test for min length
* Update src/transformers/generation/candidate_generator.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* fix variable naming
* empty commit for tests
* trigger tests (empty)
---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
6cdbd73e01
commit
41579763ee
@@ -1977,6 +1977,20 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_length=max_length)
|
||||
self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])
|
||||
|
||||
def test_min_length_if_input_embeds(self):
    """Compare `min_length` generation from `input_ids` and from `inputs_embeds`.

    Runs `generate` once from `input_ids` and once from the embeddings of
    those same ids, both with the same `min_length`, and asserts that the
    first output is exactly `input_len` positions longer than the second.
    """
    # PT-only test: TF doesn't have StoppingCriteria
    checkpoint = "hf-internal-testing/tiny-random-gpt2"
    prompt = "Today a dragon flew over Paris."
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device)
    inputs_embeds = model.get_input_embeddings()(input_ids)
    input_len = input_ids.shape[-1]

    # Both calls share the exact same length constraint.
    length_kwargs = {"min_length": 10}
    out_gen = model.generate(input_ids=input_ids, **length_kwargs)
    out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, **length_kwargs)

    self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])
|
||||
|
||||
def test_custom_stopping_criteria_overload_error(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
@@ -2539,6 +2553,56 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
model.generate(input_ids)
|
||||
self.assertEqual(len(warning_list), 0)
|
||||
|
||||
def test_length_warning_assisted_generation(self):
    """Assert that assisted generation records no warnings for these length args.

    Generates with an assistant model using `min_new_tokens=10` and
    `max_length=20`, and checks that nothing was captured by
    `warnings.catch_warnings(record=True)`.
    """
    # PT-only test: TF doesn't support assisted decoding yet.
    checkpoint = "hf-internal-testing/tiny-random-gpt2"
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
    assistant = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    for lm in (model, assistant):
        lm.config.pad_token_id = tokenizer.eos_token_id

    input_ids = tokenizer(["Hello world"], return_tensors="pt").input_ids.to(torch_device)

    # This should not raise any warning that min length is not feasible in candidate generation
    generate_kwargs = {
        "assistant_model": assistant,
        "min_new_tokens": 10,
        "max_length": 20,
    }
    with warnings.catch_warnings(record=True) as warning_list:
        model.generate(input_ids, **generate_kwargs)
        self.assertEqual(len(warning_list), 0)
|
||||
|
||||
def test_generated_length_assisted_generation(self):
    """Check that assisted generation respects the requested length bounds.

    Two scenarios are exercised with the same prompt and assistant model:

    1. `min_new_tokens=10` and `max_new_tokens=20`: the output length must
       lie in `[input_length + 10, input_length + 20]`.
    2. `min_new_tokens=10` only: the output length must lie in
       `[input_length + 10, 20]`.

    Each bound is asserted separately so a failure reports the actual
    output length instead of a bare "False is not true".
    """
    # PT-only test: TF doesn't support assisted decoding yet.
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
    assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    model.config.pad_token_id = tokenizer.eos_token_id
    assistant.config.pad_token_id = tokenizer.eos_token_id

    text = "Hello world"
    tokenized_inputs = tokenizer([text], return_tensors="pt")
    input_ids = tokenized_inputs.input_ids.to(torch_device)
    input_length = input_ids.shape[-1]

    # Scenario 1: both bounds are expressed in *new* tokens.
    out = model.generate(
        input_ids,
        assistant_model=assistant,
        min_new_tokens=10,
        max_new_tokens=20,
    )
    self.assertGreaterEqual(out.shape[-1], input_length + 10)
    self.assertLessEqual(out.shape[-1], input_length + 20)

    # Scenario 2: no explicit upper bound is passed.
    out = model.generate(
        input_ids,
        assistant_model=assistant,
        min_new_tokens=10,
    )
    self.assertGreaterEqual(out.shape[-1], input_length + 10)
    # NOTE(review): 20 is presumably the default `max_length` (a total length,
    # prompt included) from the generation config -- confirm against the config.
    self.assertLessEqual(out.shape[-1], 20)
|
||||
|
||||
def test_model_kwarg_assisted_decoding_decoder_only(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user