[Bug Fix] Beam search example in docs fails & a fix (integrating max_length in BeamScorer.finalize()) (#15555)
* added the test and fix * had left out a comment
This commit is contained in:
@@ -2315,6 +2315,48 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
@slow
|
||||
def test_beam_search_example_integration(self):
|
||||
# exactly the example provided in the docstrings of beam search, which previously
|
||||
# failed after directly copying from it. Refer to PR #15555
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
||||
|
||||
encoder_input_str = "translate English to German: How old are you?"
|
||||
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||
|
||||
# lets run beam search using 3 beams
|
||||
num_beams = 3
|
||||
# define decoder start token ids
|
||||
input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
|
||||
input_ids = input_ids * model.config.decoder_start_token_id
|
||||
|
||||
# add encoder_outputs to model keyword arguments
|
||||
model_kwargs = {
|
||||
"encoder_outputs": model.get_encoder()(
|
||||
encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
|
||||
)
|
||||
}
|
||||
|
||||
# instantiate beam scorer
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=1,
|
||||
num_beams=num_beams,
|
||||
device=model.device,
|
||||
)
|
||||
|
||||
# instantiate logits processors
|
||||
logits_processor = LogitsProcessorList(
|
||||
[
|
||||
MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
|
||||
]
|
||||
)
|
||||
|
||||
outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
|
||||
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(outputs, ["Wie alt bist du?"])
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user