[Whisper Tokenizer] Make more user-friendly (#19921)

* [Whisper Tokenizer] Make more user-friendly

* use property

* make indexing rigorous

* small clean-up

* tests

* skip seq2seq tests

* remove multilingual arg

* reorder args

* collapse to one function

Co-authored-by: ArthurZucker <arthur@huggingface.co>

* option to override attributes

Co-authored-by: ArthurZucker <arthur@huggingface.co>

* add to docs

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* make comment more clear

Co-authored-by: sgugger <sylvain@huggingface.co>

* don't add special tokens in get_decoder_prompt_ids

* add test for set_prefix_tokens

Co-authored-by: ArthurZucker <arthur@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: sgugger <sylvain@huggingface.co>
This commit is contained in:
Sanchit Gandhi
2022-11-03 14:22:40 +00:00
committed by GitHub
parent 790ff2544a
commit 06d488061f
4 changed files with 277 additions and 59 deletions

View File

@@ -20,14 +20,20 @@ from transformers.testing_utils import slow
from ...test_tokenization_common import TokenizerTesterMixin
EN_CODE = 50258
ES_CODE = 50256
ES_CODE = 50262
EN_CODE = 50259
END_OF_TRANSCRIPT = 50257
START_OF_TRANSCRIPT = 50258
TRANSLATE = 50358
TRANSCRIBE = 50359
NOTIMESTAMPS = 50363
class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = WhisperTokenizer
test_rust_tokenizer = False
test_sentencepiece = False
test_seq2seq = False
def setUp(self):
super().setUp()
@@ -101,13 +107,6 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"
transcript = (
"'<|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|> Nor is Mr. Quilters manner less interesting"
" than his matter.<|endoftext|>'"
)
clean_transcript = " Nor is Mr. Quilters manner less interesting than his matter."
french_text = "Bonjour! Il me semble que Mrs Quilters n'était pas présente"
@classmethod
def setUpClass(cls):
cls.tokenizer: WhisperTokenizer = WhisperTokenizer.from_pretrained(cls.checkpoint_name)
@@ -115,15 +114,15 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
def test_tokenizer_equivalence(self):
text = "다람쥐 헌 쳇바퀴에 타고파"
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="ko")
gpt2_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="korean")
monolingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
gpt2_tokens = gpt2_tokenizer.encode(text)
multilingual_tokens = multilingual_tokenizer.encode(text)
monolingual_tokens = monolingual_tokenizer.encode(text, add_special_tokens=False)
multilingual_tokens = multilingual_tokenizer.encode(text, add_special_tokens=False)
assert gpt2_tokenizer.decode(gpt2_tokens) == text
assert monolingual_tokenizer.decode(monolingual_tokens) == text
assert multilingual_tokenizer.decode(multilingual_tokens) == text
assert len(gpt2_tokens) > len(multilingual_tokens)
assert len(monolingual_tokens) > len(multilingual_tokens)
# fmt: off
EXPECTED_ENG = [
@@ -138,35 +137,42 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
]
# fmt: on
self.assertListEqual(gpt2_tokens, EXPECTED_ENG)
self.assertListEqual(monolingual_tokens, EXPECTED_ENG)
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)
def test_tokenizer_special(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
text = "<|startoftranscript|>Hey! How are you feeling? J'ai l'impression que 郷さん est prêt<|endoftext|>"
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
"openai/whisper-tiny", language="english", task="transcribe"
)
text = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
multilingual_tokens = multilingual_tokenizer.encode(text)
# fmt: off
# format: <|startoftranscript|> <|lang-id|> <|task|> <|notimestamps|> ... transcription ids ... <|endoftext|>
EXPECTED_MULTI = [
50257, 10814, 0, 1374, 389, 345, 4203, 30, 449, 6,
1872, 300, 6, 11011, 2234, 8358, 16268, 225, 115, 43357,
22174, 1556, 778, 25792, 83, 50256
START_OF_TRANSCRIPT, EN_CODE, TRANSCRIBE, NOTIMESTAMPS, 7057, 0, 1012, 366, 291,
2633, 30, 508, 6, 1301, 287, 6, 36107, 631, 220, 11178,
115, 15567, 871, 44393, END_OF_TRANSCRIPT
]
EXPECTED_SPECIAL_TEXT = (
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|>Hey! How are you feeling? "
"J'ai l'impression que 郷さん est prêt<|endoftext|>"
)
# fmt: on
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)
self.assertEqual(text, multilingual_tokenizer.decode(multilingual_tokens))
special_transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=False)
self.assertEqual(special_transcript, EXPECTED_SPECIAL_TEXT)
transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=True)
EXPECTED_JAP = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
self.assertEqual(transcript, EXPECTED_JAP)
self.assertEqual(transcript, text)
def test_vocab_size(self):
self.assertEqual(self.tokenizer.vocab_size, 50257)
# Copied from transformers.tests.speech_to_test.test_tokenization_speech_to_text.py
def test_tokenizer_decode_ignores_language_codes(self):
self.assertIn(ES_CODE, self.tokenizer.all_special_ids)
generated_ids = [ES_CODE, 4, 1601, 47, 7647, 2]
@@ -176,15 +182,48 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
self.assertNotIn(self.tokenizer.eos_token, result)
def test_batch_encoding(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
batch = ["<|en|><|notimestamps|>", "<|en|><|notimestamps|>I am sure that"]
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
"openai/whisper-tiny", language="spanish", task="translate"
)
batch = ["El gato ", "El gato se sentó"]
batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
# fmt: off
EXPECTED_MULTI = [
[50258, 50362, 50256, 50256, 50256, 50256],
[50258, 50362, 40, 716, 1654, 326]
[START_OF_TRANSCRIPT, ES_CODE, TRANSLATE, NOTIMESTAMPS, 17356, 290, 2513, 220,
END_OF_TRANSCRIPT, END_OF_TRANSCRIPT, END_OF_TRANSCRIPT],
[START_OF_TRANSCRIPT, ES_CODE, TRANSLATE, NOTIMESTAMPS, 17356, 290, 2513, 369,
2279, 812, END_OF_TRANSCRIPT]
]
# fmt: on
self.assertListEqual(batch_output, EXPECTED_MULTI)
def test_set_prefix_tokens(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
"openai/whisper-tiny", language="spanish", task="translate"
)
# change the language prefix token from Spanish to English
multilingual_tokenizer.set_prefix_tokens(language="english")
batch = ["the cat", "the cat sat"]
batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
# fmt: off
EXPECTED_MULTI = [
[START_OF_TRANSCRIPT, EN_CODE, TRANSLATE, NOTIMESTAMPS, 3322, 3857,
END_OF_TRANSCRIPT, END_OF_TRANSCRIPT],
[START_OF_TRANSCRIPT, EN_CODE, TRANSLATE, NOTIMESTAMPS, 3322, 3857,
3227, END_OF_TRANSCRIPT]
]
# fmt: on
self.assertListEqual(batch_output, EXPECTED_MULTI)
def test_batch_encoding_decoding(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish")
batch = ["hola güey", "que onda"]
batch_encoding = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True)
self.assertListEqual(batch, transcription)