feat: Sequential beam search (#26304)
This commit is contained in:
@@ -1539,6 +1539,39 @@ class GenerationTesterMixin:
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
def test_beam_search_low_memory(self):
|
||||
# Check that choosing 'low_memory' does not change the model output
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in [
|
||||
"bloom",
|
||||
"ctrl",
|
||||
"gptbigcode",
|
||||
"transo_xl",
|
||||
"xlnet",
|
||||
"cpm",
|
||||
]
|
||||
):
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=2)
|
||||
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
# test output equality of low versus high memory
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True)
|
||||
|
||||
high_output = model.generate(
|
||||
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
|
||||
def test_assisted_decoding_matches_greedy_search(self):
|
||||
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||
@@ -2766,6 +2799,19 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
def test_beam_search_low_memory(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
model_inputs = tokenizer("I", return_tensors="pt")["input_ids"]
|
||||
|
||||
low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True)
|
||||
|
||||
high_output = model.generate(
|
||||
model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@slow
|
||||
def test_beam_search_example_integration(self):
|
||||
# PT-only test: TF doesn't have a BeamSearchScorer
|
||||
|
||||
Reference in New Issue
Block a user