[Whisper Tokenizer] Skip special tokens when decoding with timestamps (#23945)
This commit is contained in:
@@ -491,7 +491,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
|
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
|
||||||
return normalizer(text)
|
return normalizer(text)
|
||||||
|
|
||||||
def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str:
|
def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
|
||||||
"""
|
"""
|
||||||
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
|
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
|
||||||
given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
|
given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
|
||||||
@@ -505,7 +505,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
outputs.append([])
|
outputs.append([])
|
||||||
else:
|
else:
|
||||||
outputs[-1].append(token)
|
outputs[-1].append(token)
|
||||||
outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs]
|
outputs = [
|
||||||
|
s if isinstance(s, str) else self.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs
|
||||||
|
]
|
||||||
return "".join(outputs)
|
return "".join(outputs)
|
||||||
|
|
||||||
def _compute_offsets(self, token_ids, time_precision=0.02):
|
def _compute_offsets(self, token_ids, time_precision=0.02):
|
||||||
@@ -593,7 +595,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
if decode_with_timestamps:
|
if decode_with_timestamps:
|
||||||
text = self._decode_with_timestamps(token_ids, time_precision=time_precision)
|
text = self._decode_with_timestamps(
|
||||||
|
token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
|
||||||
|
)
|
||||||
# retrieve offsets
|
# retrieve offsets
|
||||||
if output_offsets:
|
if output_offsets:
|
||||||
offsets = None
|
offsets = None
|
||||||
|
|||||||
@@ -199,7 +199,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
return super()._encode_plus(*args, **kwargs)
|
return super()._encode_plus(*args, **kwargs)
|
||||||
|
|
||||||
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._decode_with_timestamps
|
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._decode_with_timestamps
|
||||||
def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str:
|
def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
|
||||||
"""
|
"""
|
||||||
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
|
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
|
||||||
given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
|
given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
|
||||||
@@ -213,7 +213,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
outputs.append([])
|
outputs.append([])
|
||||||
else:
|
else:
|
||||||
outputs[-1].append(token)
|
outputs[-1].append(token)
|
||||||
outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs]
|
outputs = [
|
||||||
|
s if isinstance(s, str) else self.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs
|
||||||
|
]
|
||||||
return "".join(outputs)
|
return "".join(outputs)
|
||||||
|
|
||||||
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets
|
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets
|
||||||
@@ -303,7 +305,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
if decode_with_timestamps:
|
if decode_with_timestamps:
|
||||||
text = self._decode_with_timestamps(token_ids, time_precision=time_precision)
|
text = self._decode_with_timestamps(
|
||||||
|
token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
|
||||||
|
)
|
||||||
# retrieve offsets
|
# retrieve offsets
|
||||||
if output_offsets:
|
if output_offsets:
|
||||||
offsets = None
|
offsets = None
|
||||||
|
|||||||
@@ -213,6 +213,38 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
|
rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_skip_special_tokens_with_timestamps(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
rust_tokenizer = self.get_rust_tokenizer()
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
encoded_input = [
|
||||||
|
50258, 50363, 50364, 634, 575, 12525, 22618, 1968, 6144,
|
||||||
|
35617, 20084, 1756, 311, 589, 307, 534, 10281, 934,
|
||||||
|
439, 293, 50676, 50676, 393, 4411, 294, 309, 457,
|
||||||
|
707, 295, 33301, 286, 392, 6628, 13, 50836, 50257,
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
expected_with_special_tokens = "<|startoftranscript|><|notimestamps|><|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and<|6.24|><|6.24|> can discover in it but little of rocky Ithaca.<|9.44|><|endoftext|>"
|
||||||
|
expected_without_special_tokens = "<|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and<|6.24|><|6.24|> can discover in it but little of rocky Ithaca.<|9.44|>"
|
||||||
|
self.assertEqual(
|
||||||
|
tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=False),
|
||||||
|
expected_with_special_tokens,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=True),
|
||||||
|
expected_without_special_tokens,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
rust_tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=False),
|
||||||
|
expected_with_special_tokens,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
rust_tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=True),
|
||||||
|
expected_without_special_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
def test_fast_tokenizer_get_prompt_ids(self):
|
def test_fast_tokenizer_get_prompt_ids(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
rust_tokenizer = self.get_rust_tokenizer()
|
rust_tokenizer = self.get_rust_tokenizer()
|
||||||
|
|||||||
Reference in New Issue
Block a user