[generate] vectorized beam search (#35802)

This commit is contained in:
Joao Gante
2025-03-18 18:39:36 +00:00
committed by GitHub
parent 12f2ebef63
commit 179d02ffb8
8 changed files with 426 additions and 271 deletions

View File

@@ -1099,70 +1099,6 @@ class GenerationTesterMixin:
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
@pytest.mark.generate
def test_beam_search_low_memory(self):
# Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest(reason="May fix in the future: need custom cache handling")
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
self.skipTest(reason="Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in [
"ctrl",
"gptbigcode",
"transo_xl",
"xlnet",
"cpm",
"jamba",
]
):
self.skipTest(reason="May fix in the future: need model-specific fixes")
set_model_tester_for_less_flaky_test(self)
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
set_config_for_less_flaky_test(config)
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
config.use_cache = True
config.is_decoder = True
# test output equality of low versus high memory
model = model_class(config).to(torch_device).eval()
set_model_for_less_flaky_test(model)
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
low_output = model.generate(
**inputs_dict,
max_new_tokens=8,
num_beams=5,
early_stopping=True,
low_memory=True,
use_cache=True,
output_scores=True,
output_logits=True,
return_dict_in_generate=True,
**logits_processor_kwargs,
)
high_output = model.generate(
**inputs_dict,
max_new_tokens=8,
num_beams=5,
early_stopping=True,
low_memory=False,
use_cache=True,
output_scores=True,
output_logits=True,
return_dict_in_generate=True,
**logits_processor_kwargs,
)
# The two outputs must match and their shape must be as expected
self._check_similar_generate_outputs(low_output, high_output)
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
@@ -2964,19 +2900,6 @@ class GenerationIntegrationTests(unittest.TestCase):
torch.testing.assert_close(transition_scores_sum, outputs.sequences_scores, rtol=1e-3, atol=1e-3)
def test_beam_search_low_memory(self):
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
model_inputs = tokenizer("I", return_tensors="pt")["input_ids"]
low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True)
high_output = model.generate(
model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
@slow
def test_green_red_watermark_generation(self):
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
@@ -4311,6 +4234,42 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertEqual(decoded_assisted, [expected_output])
@slow
def test_beam_search_advanced_stopping_criteria(self):
"""
Tests that beam search works with a stopping criteria that is not max length or EOS token. Prior to the beam
search vectorization PR (#35802), beam search was not accepting other stopping criteria. Test inspired on
the original issue (#34843).
"""
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct").to(torch_device)
prompt = (
"Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. "
"How many clips did Natalia sell altogether in April and May?"
)
tokens = tokenizer(prompt, return_tensors="pt").to(torch_device)
generation_config = GenerationConfig(num_beams=3, do_sample=False, length_penalty=1.0, max_new_tokens=100)
# This particular prompt should result in a ":" being present in the answer
out = model.generate(**tokens, generation_config=generation_config, tokenizer=tokenizer)
output_text = tokenizer.decode(out[0], skip_special_tokens=True)
last_non_special_token_decoded = tokenizer.decode(out[out != tokenizer.pad_token_id][-1])
self.assertTrue(":" in output_text)
self.assertFalse(":" in output_text[-5:])
self.assertFalse(":" in last_non_special_token_decoded)
# Adding an advanced stopping criteria: text generation should stop when a ":" is generated.
# Note that:
# 1 - the text up to ":" doesn't have to be the same, it can belong to a different beam
# 2 - ":" may not be the last char, but it must be in the last non-special token
generation_config.stop_strings = ":"
out = model.generate(**tokens, generation_config=generation_config, tokenizer=tokenizer)
output_text = tokenizer.decode(out[0], skip_special_tokens=True)
last_non_special_token_decoded = tokenizer.decode(out[out != tokenizer.pad_token_id][-1])
self.assertTrue(":" in output_text)
self.assertTrue(":" in output_text[-5:])
self.assertTrue(":" in last_non_special_token_decoded)
def test_max_time(self):
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")