From d4dbd7ca59bd50dd034e7995cb36e5efed3d9512 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 5 Sep 2022 19:35:47 +0100 Subject: [PATCH] Generate: get the correct beam index on eos token (#18851) --- src/transformers/generation_beam_search.py | 2 +- tests/generation/test_generation_beam_search.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index e0514edafb..7c50c0d7ac 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -259,7 +259,7 @@ class BeamSearchScorer(BeamScorer): continue if beam_indices is not None: beam_index = beam_indices[batch_beam_idx] - beam_index = beam_index + (next_index,) + beam_index = beam_index + (batch_beam_idx,) else: beam_index = None diff --git a/tests/generation/test_generation_beam_search.py b/tests/generation/test_generation_beam_search.py index 885cefa62c..66bfc29b54 100644 --- a/tests/generation/test_generation_beam_search.py +++ b/tests/generation/test_generation_beam_search.py @@ -172,7 +172,7 @@ class BeamSearchTester: input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist() ) self.parent.assertListEqual( - expected_beam_indices + [next_indices[batch_idx, 1].item()], + expected_beam_indices + [correct_idx], torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(), )