[whisper] alternative fix for long-form timestamps (#32131)
* [whisper] alternative fix for long-form timestamps
* update test
This commit is contained in:
@@ -587,11 +587,20 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
|
consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
|
||||||
|
|
||||||
last_slice = np.where(timestamp_tokens)[0][0]
|
last_slice = np.where(timestamp_tokens)[0][0]
|
||||||
|
cur_max_timestamp = 0
|
||||||
|
prev_segments_len = 0
|
||||||
for current_slice in consecutive:
|
for current_slice in consecutive:
|
||||||
sliced_tokens = token_ids[last_slice:current_slice]
|
sliced_tokens = token_ids[last_slice:current_slice]
|
||||||
if len(sliced_tokens) > 1:
|
if len(sliced_tokens) > 1:
|
||||||
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
|
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
|
||||||
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
|
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
|
||||||
|
|
||||||
|
if start_timestamp_position < cur_max_timestamp:
|
||||||
|
# next segment has started
|
||||||
|
prev_segments_len += cur_max_timestamp
|
||||||
|
|
||||||
|
cur_max_timestamp = end_timestamp_position
|
||||||
|
|
||||||
# strip timestamp tokens from the text output
|
# strip timestamp tokens from the text output
|
||||||
sliced_tokens = self._preprocess_token_ids(sliced_tokens)
|
sliced_tokens = self._preprocess_token_ids(sliced_tokens)
|
||||||
text = self._decode(sliced_tokens)
|
text = self._decode(sliced_tokens)
|
||||||
@@ -600,8 +609,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
{
|
{
|
||||||
"text": text,
|
"text": text,
|
||||||
"timestamp": (
|
"timestamp": (
|
||||||
start_timestamp_position * time_precision,
|
(start_timestamp_position + prev_segments_len) * time_precision,
|
||||||
end_timestamp_position * time_precision,
|
(end_timestamp_position + prev_segments_len) * time_precision,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -229,11 +229,20 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
|
consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
|
||||||
|
|
||||||
last_slice = np.where(timestamp_tokens)[0][0]
|
last_slice = np.where(timestamp_tokens)[0][0]
|
||||||
|
cur_max_timestamp = 0
|
||||||
|
prev_segments_len = 0
|
||||||
for current_slice in consecutive:
|
for current_slice in consecutive:
|
||||||
sliced_tokens = token_ids[last_slice:current_slice]
|
sliced_tokens = token_ids[last_slice:current_slice]
|
||||||
if len(sliced_tokens) > 1:
|
if len(sliced_tokens) > 1:
|
||||||
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
|
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
|
||||||
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
|
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
|
||||||
|
|
||||||
|
if start_timestamp_position < cur_max_timestamp:
|
||||||
|
# next segment has started
|
||||||
|
prev_segments_len += cur_max_timestamp
|
||||||
|
|
||||||
|
cur_max_timestamp = end_timestamp_position
|
||||||
|
|
||||||
# strip timestamp tokens from the text output
|
# strip timestamp tokens from the text output
|
||||||
sliced_tokens = self._preprocess_token_ids(sliced_tokens)
|
sliced_tokens = self._preprocess_token_ids(sliced_tokens)
|
||||||
text = self._decode(sliced_tokens)
|
text = self._decode(sliced_tokens)
|
||||||
@@ -242,8 +251,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
{
|
{
|
||||||
"text": text,
|
"text": text,
|
||||||
"timestamp": (
|
"timestamp": (
|
||||||
start_timestamp_position * time_precision,
|
(start_timestamp_position + prev_segments_len) * time_precision,
|
||||||
end_timestamp_position * time_precision,
|
(end_timestamp_position + prev_segments_len) * time_precision,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2099,6 +2099,65 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
||||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||||
|
|
||||||
|
@slow
def test_tiny_longform_timestamps_generation(self):
    """Check the decoded segment offsets for a long-form sample with whisper-tiny.

    Loads the ``openai/whisper-tiny`` processor and model, extracts features
    (with truncation disabled) from the first sample of the ``clean``
    validation split of ``distil-whisper/librispeech_long``, generates with
    ``return_timestamps=True`` and ``return_segments=True``, and decodes the
    returned sequences with ``output_offsets=True``. The offsets of the first
    decoded transcript must equal the expected segments exactly.
    """
    set_seed(0)
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model.to(torch_device)

    audio = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")[0]["audio"]

    inputs = processor(
        audio["array"], return_tensors="pt", truncation=False, sampling_rate=audio["sampling_rate"]
    ).to(torch_device)

    outputs = model.generate(**inputs, return_timestamps=True, return_segments=True)

    # One (text, start, end) row per expected segment. Apart from the first
    # row, each start value repeats the previous row's end value, and the
    # values run past 30.0 up to 61.96.
    expected_segments = [
        (" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", 0.0, 6.5600000000000005),
        (" Nor is Mr. Quilter's manner less interesting than his matter.", 6.5600000000000005, 11.24),
        (" He tells us that at this festive season of the year, with Christmas and roast beef looming", 11.24, 16.88),
        (" before us, similarly drawn from eating and its results occur most readily to the mind.", 16.88, 23.76),
        (" He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and", 23.76, 29.44),
        (" can discover in it but little of rocky ithaka.", 29.44, 33.72),
        (" Lennils, pictures, are a sort of upguards and atom paintings, and Mason's exquisite itals", 33.72, 40.32),
        (" are as national as a jingo poem.", 40.32, 44.72),
        (" Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used", 44.72, 50.4),
        (" to flash his teeth.", 50.4, 52.96),
        (" And Mr. John Collier gives his sitter a cheerful slap on the back before he says, like", 52.96, 58.68),
        (" a shampoo and a Turkish bath next man.", 58.68, 61.96),
    ]
    expected_offsets = [
        {"text": text, "timestamp": (start, end)} for text, start, end in expected_segments
    ]

    decoded = processor.batch_decode(outputs["sequences"], skip_special_tokens=True, output_offsets=True)
    self.assertEqual(decoded[0]["offsets"], expected_offsets)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_large_timestamp_generation(self):
|
def test_large_timestamp_generation(self):
|
||||||
set_seed(0)
|
set_seed(0)
|
||||||
|
|||||||
Reference in New Issue
Block a user