From ef9c3ca348cb11c4ba951f8bdab068ecd8b847ce Mon Sep 17 00:00:00 2001 From: Chan Woo Kim Date: Mon, 7 Mar 2022 17:10:18 +0900 Subject: [PATCH] [Bug Fix] Beam search example in docs fails & a fix (integrating `max_length` in `BeamScorer.finalize()`) (#15555) * added the test and fix * had left out a comment --- src/transformers/generation_beam_search.py | 5 +-- tests/generation/test_generation_utils.py | 42 ++++++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 8fd3f94f35..1980a5efb9 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -332,7 +332,8 @@ class BeamSearchScorer(BeamScorer): best_scores[i * self.num_beam_hyps_to_keep + j] = best_score # prepare for adding eos - sent_max_len = min(sent_lengths.max().item() + 1, max_length) + sent_lengths_max = sent_lengths.max().item() + 1 + sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) # shorter batches are padded if needed if sent_lengths.min().item() != sent_lengths.max().item(): @@ -341,7 +342,7 @@ class BeamSearchScorer(BeamScorer): # fill with hypotheses and eos_token_id if the latter fits in for i, hypo in enumerate(best): decoded[i, : sent_lengths[i]] = hypo - if sent_lengths[i] < max_length: + if sent_lengths[i] < sent_max_len: decoded[i, sent_lengths[i]] = eos_token_id return UserDict( diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index 9057691a20..818cbfe17e 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -2315,6 +2315,48 @@ class GenerationIntegrationTests(unittest.TestCase): self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3)) + @slow + def test_beam_search_example_integration(self): + # exactly the example provided in the docstrings of beam search, which previously + # failed after directly copying from it. Refer to PR #15555 + tokenizer = AutoTokenizer.from_pretrained("t5-base") + model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + encoder_input_str = "translate English to German: How old are you?" + encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + # lets run beam search using 3 beams + num_beams = 3 + # define decoder start token ids + input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + input_ids = input_ids * model.config.decoder_start_token_id + + # add encoder_outputs to model keyword arguments + model_kwargs = { + "encoder_outputs": model.get_encoder()( + encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True + ) + } + + # instantiate beam scorer + beam_scorer = BeamSearchScorer( + batch_size=1, + num_beams=num_beams, + device=model.device, + ) + + # instantiate logits processors + logits_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), + ] + ) + + outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual(outputs, ["Wie alt bist du?"]) + @slow def test_constrained_beam_search(self): model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)