[Whisper Tokenizer] Encode timestamps (#26054)
* [Whisper Tokenizer] Fix tests after adding timestamps
* fix s2t tokenizer tests
* fix vocab test
* backwards comp
* fix tests
* comment
* style
* fix last test
* fix fast
* make faster
* move logic to decode
* remove skip test
* fix decode with offsets
* fix special tokens
* empty commit to re-trigger ci
* use lru cache
This commit is contained in:
@@ -15,6 +15,7 @@
|
|||||||
"""Tokenization classes for Whisper."""
|
"""Tokenization classes for Whisper."""
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -546,6 +547,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
if len(sliced_tokens) > 1:
|
if len(sliced_tokens) > 1:
|
||||||
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
|
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
|
||||||
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
|
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
|
||||||
|
# strip timestamp tokens from the text output
|
||||||
|
sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
|
||||||
offsets.append(
|
offsets.append(
|
||||||
{
|
{
|
||||||
"text": self._decode(sliced_tokens),
|
"text": self._decode(sliced_tokens),
|
||||||
@@ -559,6 +562,47 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
return offsets
|
return offsets
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def timestamp_ids(self, time_precision=0.02):
|
||||||
|
"""
|
||||||
|
Compute the timestamp token ids for a given precision and save to least-recently used (LRU) cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time_precision (`float`, `optional`, defaults to 0.02):
|
||||||
|
The time ratio to convert from token to time.
|
||||||
|
"""
|
||||||
|
return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
|
||||||
|
|
||||||
|
def _preprocess_token_ids(
|
||||||
|
self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Pre-process the token ids for decoding by removing the prompt token ids and timestamp token ids.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
|
||||||
|
List of tokenized input ids. Typically, obtained using the `__call__` method of the tokenizer.
|
||||||
|
skip_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
|
||||||
|
removed.
|
||||||
|
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
|
||||||
|
filtered out from the token ids.
|
||||||
|
time_precision (`float`, `optional`, defaults to 0.02):
|
||||||
|
The time ratio to convert from token to time.
|
||||||
|
"""
|
||||||
|
if skip_special_tokens:
|
||||||
|
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
|
||||||
|
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
|
||||||
|
token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
|
||||||
|
|
||||||
|
if not decode_with_timestamps:
|
||||||
|
# filter timestamp tokens if they are contained in the vocab
|
||||||
|
timestamp_ids = self.timestamp_ids(time_precision=time_precision)
|
||||||
|
token_ids = [token for token in token_ids if token not in timestamp_ids]
|
||||||
|
|
||||||
|
return token_ids
|
||||||
|
|
||||||
def decode(
|
def decode(
|
||||||
self,
|
self,
|
||||||
token_ids,
|
token_ids,
|
||||||
@@ -593,33 +637,40 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
Returns:
|
Returns:
|
||||||
`str`: The decoded sentence.
|
`str`: The decoded sentence.
|
||||||
"""
|
"""
|
||||||
text = super().decode(
|
filtered_ids = self._preprocess_token_ids(
|
||||||
token_ids,
|
token_ids,
|
||||||
skip_special_tokens=skip_special_tokens,
|
skip_special_tokens=skip_special_tokens,
|
||||||
|
decode_with_timestamps=decode_with_timestamps,
|
||||||
|
time_precision=time_precision,
|
||||||
|
)
|
||||||
|
|
||||||
|
text = super().decode(
|
||||||
|
filtered_ids,
|
||||||
|
skip_special_tokens=skip_special_tokens,
|
||||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||||
|
decode_with_timestamps=decode_with_timestamps,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
if decode_with_timestamps:
|
if decode_with_timestamps:
|
||||||
|
# legacy method to decode timestamps when not included in the tokenizer vocabulary
|
||||||
text = self._decode_with_timestamps(
|
text = self._decode_with_timestamps(
|
||||||
token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
|
filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
|
||||||
)
|
)
|
||||||
# retrieve offsets
|
# retrieve offsets
|
||||||
if output_offsets:
|
if output_offsets:
|
||||||
offsets = None
|
|
||||||
offsets = self._compute_offsets(token_ids, time_precision=time_precision)
|
offsets = self._compute_offsets(token_ids, time_precision=time_precision)
|
||||||
return {"text": text, "offsets": offsets}
|
return {"text": text, "offsets": offsets}
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def _decode(
|
def _decode(
|
||||||
self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, normalize: bool = False, **kwargs
|
self,
|
||||||
|
token_ids: Union[int, List[int]],
|
||||||
|
skip_special_tokens: bool = False,
|
||||||
|
normalize: bool = False,
|
||||||
|
decode_with_timestamps: bool = False,
|
||||||
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
|
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
|
||||||
|
|
||||||
if skip_special_tokens:
|
|
||||||
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
|
|
||||||
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
|
|
||||||
token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
|
|
||||||
|
|
||||||
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
|
|
||||||
# To avoid mixing byte-level and unicode for byte-level BPE
|
# To avoid mixing byte-level and unicode for byte-level BPE
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
"""Tokenization classes for Whisper."""
|
"""Tokenization classes for Whisper."""
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -255,6 +256,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
if len(sliced_tokens) > 1:
|
if len(sliced_tokens) > 1:
|
||||||
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
|
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
|
||||||
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
|
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
|
||||||
|
# strip timestamp tokens from the text output
|
||||||
|
sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
|
||||||
offsets.append(
|
offsets.append(
|
||||||
{
|
{
|
||||||
"text": self._decode(sliced_tokens),
|
"text": self._decode(sliced_tokens),
|
||||||
@@ -268,6 +271,49 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
|
|
||||||
return offsets
|
return offsets
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.timestamp_ids
|
||||||
|
def timestamp_ids(self, time_precision=0.02):
|
||||||
|
"""
|
||||||
|
Compute the timestamp token ids for a given precision and save to least-recently used (LRU) cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time_precision (`float`, `optional`, defaults to 0.02):
|
||||||
|
The time ratio to convert from token to time.
|
||||||
|
"""
|
||||||
|
return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
|
||||||
|
|
||||||
|
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._preprocess_token_ids
|
||||||
|
def _preprocess_token_ids(
|
||||||
|
self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Pre-process the token ids for decoding by removing the prompt token ids and timestamp token ids.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
|
||||||
|
List of tokenized input ids. Typically, obtained using the `__call__` method of the tokenizer.
|
||||||
|
skip_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
|
||||||
|
removed.
|
||||||
|
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
|
||||||
|
filtered out from the token ids.
|
||||||
|
time_precision (`float`, `optional`, defaults to 0.02):
|
||||||
|
The time ratio to convert from token to time.
|
||||||
|
"""
|
||||||
|
if skip_special_tokens:
|
||||||
|
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
|
||||||
|
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
|
||||||
|
token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
|
||||||
|
|
||||||
|
if not decode_with_timestamps:
|
||||||
|
# filter timestamp tokens if they are contained in the vocab
|
||||||
|
timestamp_ids = self.timestamp_ids(time_precision=time_precision)
|
||||||
|
token_ids = [token for token in token_ids if token not in timestamp_ids]
|
||||||
|
|
||||||
|
return token_ids
|
||||||
|
|
||||||
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.decode
|
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.decode
|
||||||
def decode(
|
def decode(
|
||||||
self,
|
self,
|
||||||
@@ -303,29 +349,32 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
Returns:
|
Returns:
|
||||||
`str`: The decoded sentence.
|
`str`: The decoded sentence.
|
||||||
"""
|
"""
|
||||||
text = super().decode(
|
filtered_ids = self._preprocess_token_ids(
|
||||||
token_ids,
|
token_ids,
|
||||||
skip_special_tokens=skip_special_tokens,
|
skip_special_tokens=skip_special_tokens,
|
||||||
|
decode_with_timestamps=decode_with_timestamps,
|
||||||
|
time_precision=time_precision,
|
||||||
|
)
|
||||||
|
|
||||||
|
text = super().decode(
|
||||||
|
filtered_ids,
|
||||||
|
skip_special_tokens=skip_special_tokens,
|
||||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||||
|
decode_with_timestamps=decode_with_timestamps,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
if decode_with_timestamps:
|
if decode_with_timestamps:
|
||||||
|
# legacy method to decode timestamps when not included in the tokenizer vocabulary
|
||||||
text = self._decode_with_timestamps(
|
text = self._decode_with_timestamps(
|
||||||
token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
|
filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
|
||||||
)
|
)
|
||||||
# retrieve offsets
|
# retrieve offsets
|
||||||
if output_offsets:
|
if output_offsets:
|
||||||
offsets = None
|
|
||||||
offsets = self._compute_offsets(token_ids, time_precision=time_precision)
|
offsets = self._compute_offsets(token_ids, time_precision=time_precision)
|
||||||
return {"text": text, "offsets": offsets}
|
return {"text": text, "offsets": offsets}
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
|
def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
|
||||||
if kwargs["skip_special_tokens"]:
|
|
||||||
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
|
|
||||||
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
|
|
||||||
kwargs["token_ids"] = self._strip_prompt(kwargs["token_ids"], prompt_token_id, decoder_start_token_id)
|
|
||||||
|
|
||||||
text = super()._decode(*args, **kwargs)
|
text = super()._decode(*args, **kwargs)
|
||||||
|
|
||||||
if normalize:
|
if normalize:
|
||||||
|
|||||||
@@ -52,14 +52,13 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
|
self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
|
||||||
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
|
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
|
||||||
|
|
||||||
@unittest.skip("TODO @Sanchit. Let's make the CI green in the mean time")
|
|
||||||
def test_get_vocab(self):
|
def test_get_vocab(self):
|
||||||
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
|
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
|
||||||
|
|
||||||
self.assertEqual(vocab_keys[0], "!")
|
self.assertEqual(vocab_keys[0], "!")
|
||||||
self.assertEqual(vocab_keys[1], '"')
|
self.assertEqual(vocab_keys[1], '"')
|
||||||
self.assertEqual(vocab_keys[-1], "<|notimestamps|>")
|
self.assertEqual(vocab_keys[-1], "<|30.00|>")
|
||||||
self.assertEqual(len(vocab_keys), 50364)
|
self.assertEqual(len(vocab_keys), 51865)
|
||||||
|
|
||||||
def test_vocab_size(self):
|
def test_vocab_size(self):
|
||||||
self.assertEqual(self.get_tokenizer().vocab_size, 50258)
|
self.assertEqual(self.get_tokenizer().vocab_size, 50258)
|
||||||
@@ -117,7 +116,6 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
expected_encoding=expected_encoding, model_name="openai/whisper-tiny.en", padding=False
|
expected_encoding=expected_encoding, model_name="openai/whisper-tiny.en", padding=False
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skip("TODO @Sanchit. Let's make the CI green in the mean time")
|
|
||||||
def test_output_offsets(self):
|
def test_output_offsets(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
previous_sequence = [51492, 406, 3163, 1953, 466, 13, 51612, 51612]
|
previous_sequence = [51492, 406, 3163, 1953, 466, 13, 51612, 51612]
|
||||||
@@ -400,7 +398,6 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
|||||||
transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True)
|
transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True)
|
||||||
self.assertListEqual(batch, transcription)
|
self.assertListEqual(batch, transcription)
|
||||||
|
|
||||||
@unittest.skip("TODO @Sanchit. Let's make the CI green in the mean time")
|
|
||||||
def test_offset_decoding(self):
|
def test_offset_decoding(self):
|
||||||
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
|
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
|||||||
Reference in New Issue
Block a user