Generate: get the correct beam index on eos token (#18851)

This commit is contained in:
Joao Gante
2022-09-05 19:35:47 +01:00
committed by GitHub
parent c6d3daba54
commit d4dbd7ca59
2 changed files with 2 additions and 2 deletions

View File

@@ -259,7 +259,7 @@ class BeamSearchScorer(BeamScorer):
continue
if beam_indices is not None:
beam_index = beam_indices[batch_beam_idx]
beam_index = beam_index + (next_index,)
beam_index = beam_index + (batch_beam_idx,)
else:
beam_index = None

View File

@@ -172,7 +172,7 @@ class BeamSearchTester:
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
)
self.parent.assertListEqual(
expected_beam_indices + [next_indices[batch_idx, 1].item()],
expected_beam_indices + [correct_idx],
torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
)