From 4df1d696340794cc9c03f4d14482375207f1e7a7 Mon Sep 17 00:00:00 2001 From: HanHui Date: Wed, 10 Jan 2024 18:46:49 +0800 Subject: [PATCH] [BUG] BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch (#28201) fix(generation/logits_process.py): BarkEosPrioritizerLogitsProcessor eos_token_id use list, tensor size mismatch Co-authored-by: chenhanhui --- src/transformers/generation/logits_process.py | 1 + tests/generation/test_logits_process.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index dea6f44c3a..2b1b9f5a50 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -2138,6 +2138,7 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor): early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id] do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p + do_early_stop = torch.any(do_early_stop, dim=1, keepdim=True) scores = torch.where(do_early_stop, early_stop_scores, scores) return scores diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index b1b3602c92..95150a9c33 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -824,3 +824,19 @@ class LogitsProcessorTest(unittest.TestCase): [float("-inf"), float("-inf"), scores[0][0], float("-inf")], ] self.assertListEqual(actual_scores.tolist(), expected_scores_list) + + def test_early_stop_processor_multi_eos(self): + input_ids = None + eos_token_id = [2, 3] + min_eos_p = 0.1 ## some small float + + scores = self._get_uniform_logits(2, 4) + scores[0][eos_token_id] = -6 ## less than log(min_eos_p) + + esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) + actual_scores = esp(input_ids, scores) + expected_scores_list = [ + scores[0].tolist(), + [float("-inf"), float("-inf"), scores[0][0], scores[0][0]], + ] + self.assertListEqual(actual_scores.tolist(), expected_scores_list)