[Whisper] Fix forced decoder ids (#20652)
* [Whisper] Fix forced decoder ids
* fix test
This commit is contained in:
@@ -584,5 +584,10 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
|
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
|
||||||
self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
|
self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
|
||||||
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(self.prefix_tokens)]
|
# prefix tokens are of the form: <|startoftranscript|> <|lang_id|> <|task|> <|notimestamps|>
|
||||||
|
# we don't want to force the bos token at position 1, as this is the starting token
|
||||||
|
# when we generate, so we slice the prefix tokens to: <|lang_id|> <|task|> <|notimestamps|>
|
||||||
|
# to get the forced tokens
|
||||||
|
forced_tokens = self.prefix_tokens[1:]
|
||||||
|
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)]
|
||||||
return forced_decoder_ids
|
return forced_decoder_ids
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ if is_speech_available():
|
|||||||
from transformers import WhisperFeatureExtractor, WhisperProcessor
|
from transformers import WhisperFeatureExtractor, WhisperProcessor
|
||||||
|
|
||||||
|
|
||||||
START_OF_TRANSCRIPT = 50257
|
|
||||||
TRANSCRIBE = 50358
|
TRANSCRIBE = 50358
|
||||||
NOTIMESTAMPS = 50362
|
NOTIMESTAMPS = 50362
|
||||||
|
|
||||||
@@ -145,5 +144,5 @@ class WhisperProcessorTest(unittest.TestCase):
|
|||||||
for ids in forced_decoder_ids:
|
for ids in forced_decoder_ids:
|
||||||
self.assertIsInstance(ids, (list, tuple))
|
self.assertIsInstance(ids, (list, tuple))
|
||||||
|
|
||||||
expected_ids = [START_OF_TRANSCRIPT, TRANSCRIBE, NOTIMESTAMPS]
|
expected_ids = [TRANSCRIBE, NOTIMESTAMPS]
|
||||||
self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)
|
self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user