[BUG] BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch (#28201)
fix(generation/logits_process.py): BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch Co-authored-by: chenhanhui <chenhanhui@kanzhun.com>
This commit is contained in:
@@ -2138,6 +2138,7 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
|
|||||||
early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]
|
early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]
|
||||||
|
|
||||||
do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
|
do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
|
||||||
|
do_early_stop = torch.any(do_early_stop, dim=1, keepdim=True)
|
||||||
scores = torch.where(do_early_stop, early_stop_scores, scores)
|
scores = torch.where(do_early_stop, early_stop_scores, scores)
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|||||||
@@ -824,3 +824,19 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
[float("-inf"), float("-inf"), scores[0][0], float("-inf")],
|
[float("-inf"), float("-inf"), scores[0][0], float("-inf")],
|
||||||
]
|
]
|
||||||
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
|
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
|
||||||
|
|
||||||
|
def test_early_stop_processor_multi_eos(self):
|
||||||
|
input_ids = None
|
||||||
|
eos_token_id = [2, 3]
|
||||||
|
min_eos_p = 0.1 ## some small float
|
||||||
|
|
||||||
|
scores = self._get_uniform_logits(2, 4)
|
||||||
|
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
|
||||||
|
|
||||||
|
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
|
||||||
|
actual_scores = esp(input_ids, scores)
|
||||||
|
expected_scores_list = [
|
||||||
|
scores[0].tolist(),
|
||||||
|
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
|
||||||
|
]
|
||||||
|
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
|
||||||
|
|||||||
Reference in New Issue
Block a user