From c9cf33777285681c62c0fc12a4d0afb50c82a9dc Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Fri, 2 Jun 2023 15:26:59 +0100 Subject: [PATCH] [Whisper Tokenizer] Skip special tokens when decoding with timestamps (#23945) --- .../models/whisper/tokenization_whisper.py | 10 ++++-- .../whisper/tokenization_whisper_fast.py | 10 ++++-- .../whisper/test_tokenization_whisper.py | 32 +++++++++++++++++++ 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 428254a26a..b62d9cbcb7 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -491,7 +491,7 @@ class WhisperTokenizer(PreTrainedTokenizer): normalizer = EnglishTextNormalizer(self.english_spelling_normalizer) return normalizer(text) - def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str: + def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str: """ Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". @@ -505,7 +505,9 @@ class WhisperTokenizer(PreTrainedTokenizer): outputs.append([]) else: outputs[-1].append(token) - outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs] + outputs = [ + s if isinstance(s, str) else self.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs + ] return "".join(outputs) def _compute_offsets(self, token_ids, time_precision=0.02): @@ -593,7 +595,9 @@ class WhisperTokenizer(PreTrainedTokenizer): **kwargs, ) if decode_with_timestamps: - text = self._decode_with_timestamps(token_ids, time_precision=time_precision) + text = self._decode_with_timestamps( + token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens + ) # retrieve offsets if output_offsets: offsets = None diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index a31fe00056..642be2e4cb 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -199,7 +199,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): return super()._encode_plus(*args, **kwargs) # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._decode_with_timestamps - def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str: + def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str: """ Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". @@ -213,7 +213,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): outputs.append([]) else: outputs[-1].append(token) - outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs] + outputs = [ + s if isinstance(s, str) else self.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs + ] return "".join(outputs) # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets @@ -303,7 +305,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): **kwargs, ) if decode_with_timestamps: - text = self._decode_with_timestamps(token_ids, time_precision=time_precision) + text = self._decode_with_timestamps( + token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens + ) # retrieve offsets if output_offsets: offsets = None diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index 09c98db317..aea86525f5 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py @@ -213,6 +213,38 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens ) + def test_skip_special_tokens_with_timestamps(self): + tokenizer = self.get_tokenizer() + rust_tokenizer = self.get_rust_tokenizer() + + # fmt: off + encoded_input = [ + 50258, 50363, 50364, 634, 575, 12525, 22618, 1968, 6144, + 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, + 439, 293, 50676, 50676, 393, 4411, 294, 309, 457, + 707, 295, 33301, 286, 392, 6628, 13, 50836, 50257, + ] + # fmt: on + + expected_with_special_tokens = "<|startoftranscript|><|notimestamps|><|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and<|6.24|><|6.24|> can discover in it but little of rocky Ithaca.<|9.44|><|endoftext|>" + expected_without_special_tokens = "<|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and<|6.24|><|6.24|> can discover in it but little of rocky Ithaca.<|9.44|>" + self.assertEqual( + tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=False), + expected_with_special_tokens, + ) + self.assertEqual( + tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=True), + expected_without_special_tokens, + ) + self.assertEqual( + rust_tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=False), + expected_with_special_tokens, + ) + self.assertEqual( + rust_tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=True), + expected_without_special_tokens, + ) + def test_fast_tokenizer_get_prompt_ids(self): tokenizer = self.get_tokenizer() rust_tokenizer = self.get_rust_tokenizer()