[Whisper] Fix forced decoder ids (#20652)

* [Whisper] Fix forced decoder ids

* fix test
This commit is contained in:
Sanchit Gandhi
2022-12-07 16:44:13 +00:00
committed by GitHub
parent 7c5eaf9e5a
commit 77382e918d
2 changed files with 7 additions and 3 deletions

View File

@@ -26,7 +26,6 @@ if is_speech_available():
from transformers import WhisperFeatureExtractor, WhisperProcessor
START_OF_TRANSCRIPT = 50257
TRANSCRIBE = 50358
NOTIMESTAMPS = 50362
@@ -145,5 +144,5 @@ class WhisperProcessorTest(unittest.TestCase):
for ids in forced_decoder_ids:
self.assertIsInstance(ids, (list, tuple))
expected_ids = [START_OF_TRANSCRIPT, TRANSCRIBE, NOTIMESTAMPS]
expected_ids = [TRANSCRIBE, NOTIMESTAMPS]
self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)