feat: Sequential beam search (#26304)

This commit is contained in:
Saibo-creator
2024-01-19 12:36:54 +01:00
committed by GitHub
parent 268fc1fdfa
commit d4fc1eb498
3 changed files with 235 additions and 43 deletions

View File

@@ -1539,6 +1539,39 @@ class GenerationTesterMixin:
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
def test_beam_search_low_memory(self):
# Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
self.skipTest("Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in [
"bloom",
"ctrl",
"gptbigcode",
"transo_xl",
"xlnet",
"cpm",
]
):
self.skipTest("May fix in the future: need model-specific fixes")
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=2)
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
config.use_cache = True
config.is_decoder = True
# test output equality of low versus high memory
model = model_class(config).to(torch_device).eval()
low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True)
high_output = model.generate(
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
def test_assisted_decoding_matches_greedy_search(self):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
@@ -2766,6 +2799,19 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
def test_beam_search_low_memory(self):
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
model_inputs = tokenizer("I", return_tensors="pt")["input_ids"]
low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True)
high_output = model.generate(
model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
@slow
def test_beam_search_example_integration(self):
# PT-only test: TF doesn't have a BeamSearchScorer