[Bug Fix] Beam search example in docs fails & a fix (integrating max_length in BeamScorer.finalize()) (#15555)
* added the test and fix * had left out a comment
This commit is contained in:
@@ -332,7 +332,8 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
||||||
|
|
||||||
# prepare for adding eos
|
# prepare for adding eos
|
||||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
sent_lengths_max = sent_lengths.max().item() + 1
|
||||||
|
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
|
||||||
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
||||||
# shorter batches are padded if needed
|
# shorter batches are padded if needed
|
||||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||||
@@ -341,7 +342,7 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
# fill with hypotheses and eos_token_id if the latter fits in
|
# fill with hypotheses and eos_token_id if the latter fits in
|
||||||
for i, hypo in enumerate(best):
|
for i, hypo in enumerate(best):
|
||||||
decoded[i, : sent_lengths[i]] = hypo
|
decoded[i, : sent_lengths[i]] = hypo
|
||||||
if sent_lengths[i] < max_length:
|
if sent_lengths[i] < sent_max_len:
|
||||||
decoded[i, sent_lengths[i]] = eos_token_id
|
decoded[i, sent_lengths[i]] = eos_token_id
|
||||||
|
|
||||||
return UserDict(
|
return UserDict(
|
||||||
|
|||||||
@@ -2315,6 +2315,48 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_beam_search_example_integration(self):
|
||||||
|
# exactly the example provided in the docstrings of beam search, which previously
|
||||||
|
# failed after directly copying from it. Refer to PR #15555
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
||||||
|
|
||||||
|
encoder_input_str = "translate English to German: How old are you?"
|
||||||
|
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||||
|
|
||||||
|
# lets run beam search using 3 beams
|
||||||
|
num_beams = 3
|
||||||
|
# define decoder start token ids
|
||||||
|
input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
|
||||||
|
input_ids = input_ids * model.config.decoder_start_token_id
|
||||||
|
|
||||||
|
# add encoder_outputs to model keyword arguments
|
||||||
|
model_kwargs = {
|
||||||
|
"encoder_outputs": model.get_encoder()(
|
||||||
|
encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
# instantiate beam scorer
|
||||||
|
beam_scorer = BeamSearchScorer(
|
||||||
|
batch_size=1,
|
||||||
|
num_beams=num_beams,
|
||||||
|
device=model.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# instantiate logits processors
|
||||||
|
logits_processor = LogitsProcessorList(
|
||||||
|
[
|
||||||
|
MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
|
||||||
|
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
|
||||||
|
self.assertListEqual(outputs, ["Wie alt bist du?"])
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_constrained_beam_search(self):
|
def test_constrained_beam_search(self):
|
||||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user