Generate: get the correct beam index on eos token (#18851)
This commit is contained in:
@@ -172,7 +172,7 @@ class BeamSearchTester:
|
||||
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
expected_beam_indices + [next_indices[batch_idx, 1].item()],
|
||||
expected_beam_indices + [correct_idx],
|
||||
torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user