[generate] ✨ vectorized beam search ✨ (#35802)
This commit is contained in:
@@ -1099,70 +1099,6 @@ class GenerationTesterMixin:
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_beam_search_low_memory(self):
|
||||
# Check that choosing 'low_memory' does not change the model output
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest(reason="May fix in the future: need custom cache handling")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in [
|
||||
"ctrl",
|
||||
"gptbigcode",
|
||||
"transo_xl",
|
||||
"xlnet",
|
||||
"cpm",
|
||||
"jamba",
|
||||
]
|
||||
):
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
|
||||
set_model_tester_for_less_flaky_test(self)
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
set_config_for_less_flaky_test(config)
|
||||
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
# test output equality of low versus high memory
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
set_model_for_less_flaky_test(model)
|
||||
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
|
||||
|
||||
low_output = model.generate(
|
||||
**inputs_dict,
|
||||
max_new_tokens=8,
|
||||
num_beams=5,
|
||||
early_stopping=True,
|
||||
low_memory=True,
|
||||
use_cache=True,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
return_dict_in_generate=True,
|
||||
**logits_processor_kwargs,
|
||||
)
|
||||
|
||||
high_output = model.generate(
|
||||
**inputs_dict,
|
||||
max_new_tokens=8,
|
||||
num_beams=5,
|
||||
early_stopping=True,
|
||||
low_memory=False,
|
||||
use_cache=True,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
return_dict_in_generate=True,
|
||||
**logits_processor_kwargs,
|
||||
)
|
||||
# The two outputs must match and their shape must be as expected
|
||||
self._check_similar_generate_outputs(low_output, high_output)
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
@@ -2964,19 +2900,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
torch.testing.assert_close(transition_scores_sum, outputs.sequences_scores, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_beam_search_low_memory(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
model_inputs = tokenizer("I", return_tensors="pt")["input_ids"]
|
||||
|
||||
low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True)
|
||||
|
||||
high_output = model.generate(
|
||||
model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@slow
|
||||
def test_green_red_watermark_generation(self):
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
@@ -4311,6 +4234,42 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
self.assertEqual(decoded_assisted, [expected_output])
|
||||
|
||||
@slow
|
||||
def test_beam_search_advanced_stopping_criteria(self):
|
||||
"""
|
||||
Tests that beam search works with a stopping criteria that is not max length or EOS token. Prior to the beam
|
||||
search vectorization PR (#35802), beam search was not accepting other stopping criteria. Test inspired on
|
||||
the original issue (#34843).
|
||||
"""
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct").to(torch_device)
|
||||
|
||||
prompt = (
|
||||
"Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. "
|
||||
"How many clips did Natalia sell altogether in April and May?"
|
||||
)
|
||||
tokens = tokenizer(prompt, return_tensors="pt").to(torch_device)
|
||||
generation_config = GenerationConfig(num_beams=3, do_sample=False, length_penalty=1.0, max_new_tokens=100)
|
||||
|
||||
# This particular prompt should result in a ":" being present in the answer
|
||||
out = model.generate(**tokens, generation_config=generation_config, tokenizer=tokenizer)
|
||||
output_text = tokenizer.decode(out[0], skip_special_tokens=True)
|
||||
last_non_special_token_decoded = tokenizer.decode(out[out != tokenizer.pad_token_id][-1])
|
||||
self.assertTrue(":" in output_text)
|
||||
self.assertFalse(":" in output_text[-5:])
|
||||
self.assertFalse(":" in last_non_special_token_decoded)
|
||||
|
||||
# Adding an advanced stopping criteria: text generation should stop when a ":" is generated.
|
||||
# Note that:
|
||||
# 1 - the text up to ":" doesn't have to be the same, it can belong to a different beam
|
||||
# 2 - ":" may not be the last char, but it must be in the last non-special token
|
||||
generation_config.stop_strings = ":"
|
||||
out = model.generate(**tokens, generation_config=generation_config, tokenizer=tokenizer)
|
||||
output_text = tokenizer.decode(out[0], skip_special_tokens=True)
|
||||
last_non_special_token_decoded = tokenizer.decode(out[out != tokenizer.pad_token_id][-1])
|
||||
self.assertTrue(":" in output_text)
|
||||
self.assertTrue(":" in output_text[-5:])
|
||||
self.assertTrue(":" in last_non_special_token_decoded)
|
||||
|
||||
def test_max_time(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
||||
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
|
||||
|
||||
@@ -104,10 +104,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cohere2 has HybridCache and doesn't support low_memory generation")
|
||||
def test_beam_search_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
||||
@@ -119,10 +119,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support low_memory generation")
|
||||
def test_beam_search_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
||||
@@ -332,12 +332,6 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
||||
def test_model_is_small(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Qwen2.5-VL can't do low-memory generation because position IDs have extra dimension and split function doesn't work for that"
|
||||
)
|
||||
def test_beam_search_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs"
|
||||
)
|
||||
|
||||
@@ -344,12 +344,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
def test_model_is_small(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Qwen2-VL can't do low-memory generation because position IDs have extra dimension and split function doesn't work for that"
|
||||
)
|
||||
def test_beam_search_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user