Generate: get the correct beam index on eos token (#18851)
This commit is contained in:
@@ -259,7 +259,7 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
continue
|
continue
|
||||||
if beam_indices is not None:
|
if beam_indices is not None:
|
||||||
beam_index = beam_indices[batch_beam_idx]
|
beam_index = beam_indices[batch_beam_idx]
|
||||||
beam_index = beam_index + (next_index,)
|
beam_index = beam_index + (batch_beam_idx,)
|
||||||
else:
|
else:
|
||||||
beam_index = None
|
beam_index = None
|
||||||
|
|
||||||
|
|||||||
@@ -172,7 +172,7 @@ class BeamSearchTester:
|
|||||||
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
|
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
|
||||||
)
|
)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
expected_beam_indices + [next_indices[batch_idx, 1].item()],
|
expected_beam_indices + [correct_idx],
|
||||||
torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
|
torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user