[BUG] BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch (#28201)
fix(generation/logits_process.py): BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch Co-authored-by: chenhanhui <chenhanhui@kanzhun.com>
This commit is contained in:
@@ -824,3 +824,19 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
[float("-inf"), float("-inf"), scores[0][0], float("-inf")],
|
||||
]
|
||||
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
|
||||
|
||||
def test_early_stop_processor_multi_eos(self):
|
||||
input_ids = None
|
||||
eos_token_id = [2, 3]
|
||||
min_eos_p = 0.1 ## some small float
|
||||
|
||||
scores = self._get_uniform_logits(2, 4)
|
||||
scores[0][eos_token_id] = -6 ## less than log(min_eos_p)
|
||||
|
||||
esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p)
|
||||
actual_scores = esp(input_ids, scores)
|
||||
expected_scores_list = [
|
||||
scores[0].tolist(),
|
||||
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
|
||||
]
|
||||
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
|
||||
|
||||
Reference in New Issue
Block a user