[Whisper Tokenizer] Make more user-friendly (#19921)
* [Whisper Tokenizer] Make more user-friendly * use property * make indexing rigorous * small clean-up * tests * skip seq2seq tests * remove multilingual arg * reorder args * collapse to one function Co-authored-by: ArthurZucker <arthur@huggingface.co> * option to override attributes Co-authored-by: ArthurZucker <arthur@huggingface.co> * add to docs * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * make comment more clear Co-authored-by: sgugger <sylvain@huggingface.co> * don't add special tokens in get_decoder_prompt_ids * add test for set_prefix_tokens Co-authored-by: ArthurZucker <arthur@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: sgugger <sylvain@huggingface.co>
This commit is contained in:
@@ -20,14 +20,20 @@ from transformers.testing_utils import slow
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
EN_CODE = 50258
|
||||
ES_CODE = 50256
|
||||
ES_CODE = 50262
|
||||
EN_CODE = 50259
|
||||
END_OF_TRANSCRIPT = 50257
|
||||
START_OF_TRANSCRIPT = 50258
|
||||
TRANSLATE = 50358
|
||||
TRANSCRIBE = 50359
|
||||
NOTIMESTAMPS = 50363
|
||||
|
||||
|
||||
class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer_class = WhisperTokenizer
|
||||
test_rust_tokenizer = False
|
||||
test_sentencepiece = False
|
||||
test_seq2seq = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@@ -101,13 +107,6 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||
checkpoint_name = "openai/whisper-small.en"
|
||||
|
||||
transcript = (
|
||||
"'<|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|> Nor is Mr. Quilters manner less interesting"
|
||||
" than his matter.<|endoftext|>'"
|
||||
)
|
||||
clean_transcript = " Nor is Mr. Quilters manner less interesting than his matter."
|
||||
french_text = "Bonjour! Il me semble que Mrs Quilters n'était pas présente"
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.tokenizer: WhisperTokenizer = WhisperTokenizer.from_pretrained(cls.checkpoint_name)
|
||||
@@ -115,15 +114,15 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||
|
||||
def test_tokenizer_equivalence(self):
|
||||
text = "다람쥐 헌 쳇바퀴에 타고파"
|
||||
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="ko")
|
||||
gpt2_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
|
||||
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="korean")
|
||||
monolingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
|
||||
|
||||
gpt2_tokens = gpt2_tokenizer.encode(text)
|
||||
multilingual_tokens = multilingual_tokenizer.encode(text)
|
||||
monolingual_tokens = monolingual_tokenizer.encode(text, add_special_tokens=False)
|
||||
multilingual_tokens = multilingual_tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
assert gpt2_tokenizer.decode(gpt2_tokens) == text
|
||||
assert monolingual_tokenizer.decode(monolingual_tokens) == text
|
||||
assert multilingual_tokenizer.decode(multilingual_tokens) == text
|
||||
assert len(gpt2_tokens) > len(multilingual_tokens)
|
||||
assert len(monolingual_tokens) > len(multilingual_tokens)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_ENG = [
|
||||
@@ -138,35 +137,42 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
self.assertListEqual(gpt2_tokens, EXPECTED_ENG)
|
||||
self.assertListEqual(monolingual_tokens, EXPECTED_ENG)
|
||||
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)
|
||||
|
||||
def test_tokenizer_special(self):
|
||||
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
|
||||
text = "<|startoftranscript|>Hey! How are you feeling? J'ai l'impression que 郷さん est prêt<|endoftext|>"
|
||||
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
|
||||
"openai/whisper-tiny", language="english", task="transcribe"
|
||||
)
|
||||
text = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
|
||||
|
||||
multilingual_tokens = multilingual_tokenizer.encode(text)
|
||||
|
||||
# fmt: off
|
||||
# format: <|startoftranscript|> <|lang-id|> <|task|> <|notimestamps|> ... transcription ids ... <|endoftext|>
|
||||
EXPECTED_MULTI = [
|
||||
50257, 10814, 0, 1374, 389, 345, 4203, 30, 449, 6,
|
||||
1872, 300, 6, 11011, 2234, 8358, 16268, 225, 115, 43357,
|
||||
22174, 1556, 778, 25792, 83, 50256
|
||||
START_OF_TRANSCRIPT, EN_CODE, TRANSCRIBE, NOTIMESTAMPS, 7057, 0, 1012, 366, 291,
|
||||
2633, 30, 508, 6, 1301, 287, 6, 36107, 631, 220, 11178,
|
||||
115, 15567, 871, 44393, END_OF_TRANSCRIPT
|
||||
]
|
||||
EXPECTED_SPECIAL_TEXT = (
|
||||
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|>Hey! How are you feeling? "
|
||||
"J'ai l'impression que 郷さん est prêt<|endoftext|>"
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)
|
||||
|
||||
self.assertEqual(text, multilingual_tokenizer.decode(multilingual_tokens))
|
||||
special_transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=False)
|
||||
self.assertEqual(special_transcript, EXPECTED_SPECIAL_TEXT)
|
||||
|
||||
transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_JAP = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
|
||||
self.assertEqual(transcript, EXPECTED_JAP)
|
||||
self.assertEqual(transcript, text)
|
||||
|
||||
def test_vocab_size(self):
|
||||
self.assertEqual(self.tokenizer.vocab_size, 50257)
|
||||
|
||||
# Copied from transformers.tests.speech_to_test.test_tokenization_speech_to_text.py
|
||||
def test_tokenizer_decode_ignores_language_codes(self):
|
||||
self.assertIn(ES_CODE, self.tokenizer.all_special_ids)
|
||||
generated_ids = [ES_CODE, 4, 1601, 47, 7647, 2]
|
||||
@@ -176,15 +182,48 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||
self.assertNotIn(self.tokenizer.eos_token, result)
|
||||
|
||||
def test_batch_encoding(self):
|
||||
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
|
||||
batch = ["<|en|><|notimestamps|>", "<|en|><|notimestamps|>I am sure that"]
|
||||
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
|
||||
"openai/whisper-tiny", language="spanish", task="translate"
|
||||
)
|
||||
batch = ["El gato ", "El gato se sentó"]
|
||||
batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_MULTI = [
|
||||
[50258, 50362, 50256, 50256, 50256, 50256],
|
||||
[50258, 50362, 40, 716, 1654, 326]
|
||||
[START_OF_TRANSCRIPT, ES_CODE, TRANSLATE, NOTIMESTAMPS, 17356, 290, 2513, 220,
|
||||
END_OF_TRANSCRIPT, END_OF_TRANSCRIPT, END_OF_TRANSCRIPT],
|
||||
[START_OF_TRANSCRIPT, ES_CODE, TRANSLATE, NOTIMESTAMPS, 17356, 290, 2513, 369,
|
||||
2279, 812, END_OF_TRANSCRIPT]
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
self.assertListEqual(batch_output, EXPECTED_MULTI)
|
||||
|
||||
def test_set_prefix_tokens(self):
|
||||
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
|
||||
"openai/whisper-tiny", language="spanish", task="translate"
|
||||
)
|
||||
|
||||
# change the language prefix token from Spanish to English
|
||||
multilingual_tokenizer.set_prefix_tokens(language="english")
|
||||
|
||||
batch = ["the cat", "the cat sat"]
|
||||
batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_MULTI = [
|
||||
[START_OF_TRANSCRIPT, EN_CODE, TRANSLATE, NOTIMESTAMPS, 3322, 3857,
|
||||
END_OF_TRANSCRIPT, END_OF_TRANSCRIPT],
|
||||
[START_OF_TRANSCRIPT, EN_CODE, TRANSLATE, NOTIMESTAMPS, 3322, 3857,
|
||||
3227, END_OF_TRANSCRIPT]
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
self.assertListEqual(batch_output, EXPECTED_MULTI)
|
||||
|
||||
def test_batch_encoding_decoding(self):
|
||||
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish")
|
||||
batch = ["hola güey", "que onda"]
|
||||
batch_encoding = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
|
||||
transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True)
|
||||
self.assertListEqual(batch, transcription)
|
||||
|
||||
Reference in New Issue
Block a user