Whisper tokenizer word level timestamps (#32197)
* fix _fix_key in PreTrainedModel
* fix _find_longest_common_sequence
* add test
* remove result.json
* nit
* update test
This commit is contained in:
@@ -1174,7 +1174,22 @@ def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
|
|||||||
"There is a bug within whisper `decode_asr` function, please report it. Dropping to prevent bad inference."
|
"There is a bug within whisper `decode_asr` function, please report it. Dropping to prevent bad inference."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if token_timestamp_sequences:
|
||||||
|
# Get length of longest subsequence of tokens that match
|
||||||
|
# and have timestamps that are in order
|
||||||
|
matches = sum(
|
||||||
|
1
|
||||||
|
for idx, elem in enumerate(left)
|
||||||
|
if (
|
||||||
|
elem == right[idx]
|
||||||
|
and left_token_timestamp_sequence[left_start + idx]
|
||||||
|
<= token_timestamp_sequences[seq_idx + 1][right_start + idx]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
matches = np.sum(left == right)
|
matches = np.sum(left == right)
|
||||||
|
|
||||||
matching = matches / i + eps
|
matching = matches / i + eps
|
||||||
if matches > 1 and matching > max_:
|
if matches > 1 and matching > max_:
|
||||||
max_ = matching
|
max_ = matching
|
||||||
|
|||||||
@@ -338,6 +338,42 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(decoded_output_diacritics, expected_output_diacritics)
|
self.assertEqual(decoded_output_diacritics, expected_output_diacritics)
|
||||||
|
|
||||||
|
def test_decode_asr_with_word_level_timestamps(self):
|
||||||
|
# fmt: off
|
||||||
|
model_outputs = [
|
||||||
|
{
|
||||||
|
'stride': [10, 0, 5],
|
||||||
|
'tokens': np.array([[ 50257, 50362, 3363, 11, 345, 460, 0, 2329, 466, 340, 0, 50256 ]]),
|
||||||
|
'token_timestamps': np.array([[ 0, 0, 5.18, 5.56, 5.56, 5.84, 6.36, 7.12, 7.54, 7.82, 8.16, 9.48 ]])
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'stride': [10, 5, 0],
|
||||||
|
'tokens': np.array([[ 50257, 50362, 2329, 466, 340, 0, 3363, 345, 460, 0, 2329, 466, 340, 50256 ]]),
|
||||||
|
'token_timestamps': np.array([[ 0, 0, 0, 2.44, 4.3, 5.04, 5.06, 5.56, 5.8, 6.32, 7.12, 7.56, 7.8, 8.72 ]])
|
||||||
|
}
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
tokenizer = WhisperTokenizer.from_pretrained("onnx-community/whisper-tiny.en_timestamped")
|
||||||
|
result = tokenizer._decode_asr(
|
||||||
|
model_outputs, return_timestamps="word", return_language=False, time_precision=0.02
|
||||||
|
)
|
||||||
|
|
||||||
|
EXPECTED_OUTPUT = (
|
||||||
|
" Yes, you can! Just do it",
|
||||||
|
{
|
||||||
|
"chunks": [
|
||||||
|
{"text": " Yes,", "timestamp": (5.18, 5.56)},
|
||||||
|
{"text": " you", "timestamp": (5.56, 5.84)},
|
||||||
|
{"text": " can!", "timestamp": (5.84, 7.12)},
|
||||||
|
{"text": " Just", "timestamp": (7.12, 7.56)},
|
||||||
|
{"text": " do", "timestamp": (7.56, 7.8)},
|
||||||
|
{"text": " it", "timestamp": (7.8, 8.72)},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEqual(result, EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||||
checkpoint_name = "openai/whisper-small.en"
|
checkpoint_name = "openai/whisper-small.en"
|
||||||
|
|||||||
Reference in New Issue
Block a user