[Whisper] Fix whisper tokenizer (#34537)
* handle single timestamp ending * include last timestamp token * handle single timestamp ending * avoid floating-point arithmetic limitations * ensure float64 operations * new test * make fixup * make copies * handle edge case double tokens ending with different tokens * handle single timestamp ending * make fixup * handle conditioning on prev segments * fix * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> * [run-slow] whisper * don't call item() to avoid unnecessary sync * fix --------- Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Co-authored-by: Eustache Le Bihan <eustlb@users.noreply.huggingface.co>
This commit is contained in:
@@ -308,6 +308,7 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
num_segment_frames: Optional[int] = None,
|
num_segment_frames: Optional[int] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
time_precision: float = 0.02,
|
time_precision: float = 0.02,
|
||||||
|
time_precision_features: float = 0.01,
|
||||||
return_token_timestamps: Optional[bool] = None,
|
return_token_timestamps: Optional[bool] = None,
|
||||||
return_segments: bool = False,
|
return_segments: bool = False,
|
||||||
return_dict_in_generate: Optional[bool] = None,
|
return_dict_in_generate: Optional[bool] = None,
|
||||||
@@ -417,6 +418,8 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
time_precision (`float`, *optional*, defaults to 0.02):
|
time_precision (`float`, *optional*, defaults to 0.02):
|
||||||
The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
|
The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
|
||||||
for 20 ms.
|
for 20 ms.
|
||||||
|
time_precision_features (`float`, *optional*, defaults to 0.01):
|
||||||
|
The duration represented by a feature frame in seconds.
|
||||||
return_token_timestamps (`bool`, *optional*):
|
return_token_timestamps (`bool`, *optional*):
|
||||||
Whether to return token-level timestamps with the text. This can be used with or without the
|
Whether to return token-level timestamps with the text. This can be used with or without the
|
||||||
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
|
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
|
||||||
@@ -629,7 +632,7 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
cur_bsz=cur_bsz,
|
cur_bsz=cur_bsz,
|
||||||
batch_idx_map=batch_idx_map,
|
batch_idx_map=batch_idx_map,
|
||||||
)
|
)
|
||||||
time_offset = seek * time_precision / input_stride
|
time_offset = seek.to(torch.float64) * time_precision / input_stride
|
||||||
seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
|
seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
|
||||||
|
|
||||||
# 6.2 cut out next 30s segment from input features
|
# 6.2 cut out next 30s segment from input features
|
||||||
@@ -658,6 +661,7 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
config=self.config,
|
config=self.config,
|
||||||
device=init_tokens.device,
|
device=init_tokens.device,
|
||||||
suppress_tokens=suppress_tokens,
|
suppress_tokens=suppress_tokens,
|
||||||
|
timestamp_begin=timestamp_begin,
|
||||||
kwargs=kwargs,
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -718,6 +722,7 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
timestamp_begin=timestamp_begin,
|
timestamp_begin=timestamp_begin,
|
||||||
seek_num_frames=seek_num_frames,
|
seek_num_frames=seek_num_frames,
|
||||||
time_precision=time_precision,
|
time_precision=time_precision,
|
||||||
|
time_precision_features=time_precision_features,
|
||||||
input_stride=input_stride,
|
input_stride=input_stride,
|
||||||
prev_idx=prev_i,
|
prev_idx=prev_i,
|
||||||
idx=i,
|
idx=i,
|
||||||
@@ -1665,6 +1670,7 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
config,
|
config,
|
||||||
device,
|
device,
|
||||||
suppress_tokens,
|
suppress_tokens,
|
||||||
|
timestamp_begin,
|
||||||
kwargs,
|
kwargs,
|
||||||
):
|
):
|
||||||
if "decoder_input_ids" in kwargs:
|
if "decoder_input_ids" in kwargs:
|
||||||
@@ -1684,6 +1690,14 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
# according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
|
# according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
|
||||||
active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]
|
active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]
|
||||||
|
|
||||||
|
for segments in active_segments:
|
||||||
|
for seg in segments:
|
||||||
|
if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin:
|
||||||
|
# the segment finishes with two timestamp tokens
|
||||||
|
# we need to ignore the last timestamp token
|
||||||
|
# see https://github.com/huggingface/transformers/pull/34537
|
||||||
|
seg["tokens"] = seg["tokens"][:-1]
|
||||||
|
|
||||||
if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
|
if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
|
||||||
prev_ids = prompt_ids
|
prev_ids = prompt_ids
|
||||||
else:
|
else:
|
||||||
@@ -1778,6 +1792,7 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
timestamp_begin,
|
timestamp_begin,
|
||||||
seek_num_frames,
|
seek_num_frames,
|
||||||
time_precision,
|
time_precision,
|
||||||
|
time_precision_features,
|
||||||
input_stride,
|
input_stride,
|
||||||
prev_idx,
|
prev_idx,
|
||||||
idx,
|
idx,
|
||||||
@@ -1799,17 +1814,22 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
segments = []
|
segments = []
|
||||||
if single_timestamp_ending:
|
if single_timestamp_ending:
|
||||||
slices.append(len(seek_sequence))
|
slices.append(len(seek_sequence))
|
||||||
|
else:
|
||||||
|
# include the last timestamp token in the last segment so that we can later tell it was not a single-timestamp ending
|
||||||
|
slices[-1] += 1
|
||||||
|
|
||||||
last_slice = 0
|
last_slice = 0
|
||||||
# Add each segment to list of all segments
|
# Add each segment to list of all segments
|
||||||
for current_slice in slices:
|
for i, current_slice in enumerate(slices):
|
||||||
|
is_last_slice = i == len(slices) - 1
|
||||||
sliced_tokens = seek_sequence[last_slice:current_slice]
|
sliced_tokens = seek_sequence[last_slice:current_slice]
|
||||||
start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
|
start_timestamp_pos = sliced_tokens[0] - timestamp_begin
|
||||||
end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
|
idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2
|
||||||
|
end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
|
||||||
segments.append(
|
segments.append(
|
||||||
{
|
{
|
||||||
"start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
|
"start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision,
|
||||||
"end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
|
"end": time_offset[prev_idx] + end_timestamp_pos.to(torch.float64) * time_precision,
|
||||||
"tokens": sliced_tokens,
|
"tokens": sliced_tokens,
|
||||||
"result": seek_outputs[idx],
|
"result": seek_outputs[idx],
|
||||||
}
|
}
|
||||||
@@ -1827,16 +1847,16 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
||||||
# here we throw away all predictions after the last predicted "end of segment"
|
# here we throw away all predictions after the last predicted "end of segment"
|
||||||
# since we are cutting right in the middle of an audio
|
# since we are cutting right in the middle of an audio
|
||||||
last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin
|
last_timestamp_pos = seek_sequence[last_slice - 2].item() - timestamp_begin
|
||||||
segment_offset = last_timestamp_pos * input_stride
|
segment_offset = last_timestamp_pos * input_stride
|
||||||
else:
|
else:
|
||||||
# If whisper does not predict any "end of segment" token, then
|
# If whisper does not predict any "end of segment" token, then
|
||||||
# the whole decoding is considered a segment and we add it to the list of segments
|
# the whole decoding is considered a segment and we add it to the list of segments
|
||||||
timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
|
timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
|
||||||
last_timestamp_pos = seek_num_frames[prev_idx]
|
last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
|
||||||
if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
|
if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin:
|
||||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||||
last_timestamp_pos = timestamps[-1].item() - timestamp_begin
|
last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(torch.float64)
|
||||||
segments = [
|
segments = [
|
||||||
{
|
{
|
||||||
"start": time_offset[prev_idx],
|
"start": time_offset[prev_idx],
|
||||||
|
|||||||
@@ -528,7 +528,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
|
normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
|
||||||
return normalizer(text)
|
return normalizer(text)
|
||||||
|
|
||||||
def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
|
def _decode_with_timestamps(
|
||||||
|
self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
|
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
|
||||||
given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||||
@@ -538,15 +540,25 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
cur_max_timestamp = 0.0
|
cur_max_timestamp = 0.0
|
||||||
prev_segments_len = 0.0
|
prev_segments_len = 0.0
|
||||||
|
penultimate_timestamp = 0.0
|
||||||
|
|
||||||
for token in token_ids:
|
for i, token in enumerate(token_ids):
|
||||||
if token >= timestamp_begin:
|
if token >= timestamp_begin:
|
||||||
timestamp = float((token - timestamp_begin) * time_precision)
|
timestamp = float((token - timestamp_begin) * time_precision)
|
||||||
|
|
||||||
if timestamp < cur_max_timestamp:
|
if timestamp < cur_max_timestamp:
|
||||||
# next segment has started
|
# next segment has started
|
||||||
prev_segments_len += cur_max_timestamp
|
last_was_single_ending = i >= 2 and not (
|
||||||
|
token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
|
||||||
|
)
|
||||||
|
if last_was_single_ending:
|
||||||
|
prev_segments_len += time_precision * segment_size
|
||||||
|
else:
|
||||||
|
cur_max_timestamp = penultimate_timestamp
|
||||||
|
prev_segments_len += penultimate_timestamp
|
||||||
|
outputs = outputs[:-2]
|
||||||
|
|
||||||
|
penultimate_timestamp = cur_max_timestamp
|
||||||
cur_max_timestamp = timestamp
|
cur_max_timestamp = timestamp
|
||||||
|
|
||||||
outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
|
outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
|
||||||
@@ -558,7 +570,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
]
|
]
|
||||||
return "".join(outputs)
|
return "".join(outputs)
|
||||||
|
|
||||||
def _compute_offsets(self, token_ids, time_precision=0.02):
|
def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500):
|
||||||
"""
|
"""
|
||||||
Compute offsets for a given tokenized input
|
Compute offsets for a given tokenized input
|
||||||
|
|
||||||
@@ -567,6 +579,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
List of tokenized input ids. Can be obtained using the `__call__` method.
|
List of tokenized input ids. Can be obtained using the `__call__` method.
|
||||||
time_precision (`float`, *optional*, defaults to 0.02):
|
time_precision (`float`, *optional*, defaults to 0.02):
|
||||||
The time ratio to convert from token to time.
|
The time ratio to convert from token to time.
|
||||||
|
segment_size (`int`, *optional*, defaults to 1500):
|
||||||
|
The number of features in the input mel spectrogram.
|
||||||
"""
|
"""
|
||||||
offsets = []
|
offsets = []
|
||||||
# ensure torch tensor of token ids is placed on cpu
|
# ensure torch tensor of token ids is placed on cpu
|
||||||
@@ -597,6 +611,12 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
if start_timestamp_position < cur_max_timestamp:
|
if start_timestamp_position < cur_max_timestamp:
|
||||||
# next segment has started
|
# next segment has started
|
||||||
|
is_single_ending = last_slice >= 2 and not (
|
||||||
|
token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin
|
||||||
|
)
|
||||||
|
if is_single_ending:
|
||||||
|
prev_segments_len += segment_size
|
||||||
|
else:
|
||||||
prev_segments_len += cur_max_timestamp
|
prev_segments_len += cur_max_timestamp
|
||||||
|
|
||||||
cur_max_timestamp = end_timestamp_position
|
cur_max_timestamp = end_timestamp_position
|
||||||
@@ -609,8 +629,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
{
|
{
|
||||||
"text": text,
|
"text": text,
|
||||||
"timestamp": (
|
"timestamp": (
|
||||||
(start_timestamp_position + prev_segments_len) * time_precision,
|
start_timestamp_position * time_precision + prev_segments_len * time_precision,
|
||||||
(end_timestamp_position + prev_segments_len) * time_precision,
|
end_timestamp_position * time_precision + prev_segments_len * time_precision,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -169,7 +169,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
return super()._encode_plus(*args, **kwargs)
|
return super()._encode_plus(*args, **kwargs)
|
||||||
|
|
||||||
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._decode_with_timestamps
|
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._decode_with_timestamps
|
||||||
def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
|
def _decode_with_timestamps(
|
||||||
|
self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
|
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
|
||||||
given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||||
@@ -179,15 +181,25 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
|
|
||||||
cur_max_timestamp = 0.0
|
cur_max_timestamp = 0.0
|
||||||
prev_segments_len = 0.0
|
prev_segments_len = 0.0
|
||||||
|
penultimate_timestamp = 0.0
|
||||||
|
|
||||||
for token in token_ids:
|
for i, token in enumerate(token_ids):
|
||||||
if token >= timestamp_begin:
|
if token >= timestamp_begin:
|
||||||
timestamp = float((token - timestamp_begin) * time_precision)
|
timestamp = float((token - timestamp_begin) * time_precision)
|
||||||
|
|
||||||
if timestamp < cur_max_timestamp:
|
if timestamp < cur_max_timestamp:
|
||||||
# next segment has started
|
# next segment has started
|
||||||
prev_segments_len += cur_max_timestamp
|
last_was_single_ending = i >= 2 and not (
|
||||||
|
token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
|
||||||
|
)
|
||||||
|
if last_was_single_ending:
|
||||||
|
prev_segments_len += time_precision * segment_size
|
||||||
|
else:
|
||||||
|
cur_max_timestamp = penultimate_timestamp
|
||||||
|
prev_segments_len += penultimate_timestamp
|
||||||
|
outputs = outputs[:-2]
|
||||||
|
|
||||||
|
penultimate_timestamp = cur_max_timestamp
|
||||||
cur_max_timestamp = timestamp
|
cur_max_timestamp = timestamp
|
||||||
|
|
||||||
outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
|
outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
|
||||||
@@ -200,7 +212,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
return "".join(outputs)
|
return "".join(outputs)
|
||||||
|
|
||||||
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets
|
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets
|
||||||
def _compute_offsets(self, token_ids, time_precision=0.02):
|
def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500):
|
||||||
"""
|
"""
|
||||||
Compute offsets for a given tokenized input
|
Compute offsets for a given tokenized input
|
||||||
|
|
||||||
@@ -209,6 +221,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
List of tokenized input ids. Can be obtained using the `__call__` method.
|
List of tokenized input ids. Can be obtained using the `__call__` method.
|
||||||
time_precision (`float`, *optional*, defaults to 0.02):
|
time_precision (`float`, *optional*, defaults to 0.02):
|
||||||
The time ratio to convert from token to time.
|
The time ratio to convert from token to time.
|
||||||
|
segment_size (`int`, *optional*, defaults to 1500):
|
||||||
|
The number of features in the input mel spectrogram.
|
||||||
"""
|
"""
|
||||||
offsets = []
|
offsets = []
|
||||||
# ensure torch tensor of token ids is placed on cpu
|
# ensure torch tensor of token ids is placed on cpu
|
||||||
@@ -239,6 +253,12 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
|
|
||||||
if start_timestamp_position < cur_max_timestamp:
|
if start_timestamp_position < cur_max_timestamp:
|
||||||
# next segment has started
|
# next segment has started
|
||||||
|
is_single_ending = last_slice >= 2 and not (
|
||||||
|
token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin
|
||||||
|
)
|
||||||
|
if is_single_ending:
|
||||||
|
prev_segments_len += segment_size
|
||||||
|
else:
|
||||||
prev_segments_len += cur_max_timestamp
|
prev_segments_len += cur_max_timestamp
|
||||||
|
|
||||||
cur_max_timestamp = end_timestamp_position
|
cur_max_timestamp = end_timestamp_position
|
||||||
@@ -251,8 +271,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
{
|
{
|
||||||
"text": text,
|
"text": text,
|
||||||
"timestamp": (
|
"timestamp": (
|
||||||
(start_timestamp_position + prev_segments_len) * time_precision,
|
start_timestamp_position * time_precision + prev_segments_len * time_precision,
|
||||||
(end_timestamp_position + prev_segments_len) * time_precision,
|
end_timestamp_position * time_precision + prev_segments_len * time_precision,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2096,6 +2096,94 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
|
transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
|
||||||
self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)
|
self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_small_longform_timestamps_generation(self):
|
||||||
|
processor = WhisperProcessor.from_pretrained("openai/whisper-small.en")
|
||||||
|
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en")
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
|
||||||
|
sample = dataset[0]["audio"]["array"]
|
||||||
|
sampling_rate = dataset[0]["audio"]["sampling_rate"]
|
||||||
|
|
||||||
|
sample = [*sample[: 15 * sampling_rate], *np.zeros(16 * sampling_rate).tolist(), *sample[15 * sampling_rate :]]
|
||||||
|
sample = np.array(sample)
|
||||||
|
|
||||||
|
input_features = processor(
|
||||||
|
sample,
|
||||||
|
sampling_rate=16_000,
|
||||||
|
padding="longest",
|
||||||
|
truncation=False,
|
||||||
|
return_attention_mask=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).input_features
|
||||||
|
|
||||||
|
input_features = input_features.to(torch_device)
|
||||||
|
generated_ids = model.generate(input_features, return_timestamps=True, return_segments=True)
|
||||||
|
|
||||||
|
EXPECTED_TRANSCRIPT = [
|
||||||
|
{
|
||||||
|
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
|
||||||
|
"timestamp": (0.0, 6.38),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": " Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||||
|
"timestamp": (6.38, 11.32),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": " He tells us that at this festive season of the year,",
|
||||||
|
"timestamp": (11.32, 15.0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": " With Christmas and roast beef looming before us, similes drawn from eating and its results",
|
||||||
|
"timestamp": (30.0, 36.76),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": " occur most readily to the mind.",
|
||||||
|
"timestamp": (36.76, 39.80),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and",
|
||||||
|
"timestamp": (39.80, 45.36),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": " can discover in it but little of rocky Ithaca.",
|
||||||
|
"timestamp": (45.36, 49.0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": " Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles",
|
||||||
|
"timestamp": (49.0, 56.28),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": " are as national as a jingo poem. Mr. Burkett fosters landscape's smile at one much in",
|
||||||
|
"timestamp": (56.28, 64.12),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": " the same way that Mr. Karker used to flash his teeth. And Mr. John Collier gives his",
|
||||||
|
"timestamp": (64.12, 70.76),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": " sitter a cheerful slap on the back before he says, like a shampoo or in a Turkish bath,",
|
||||||
|
"timestamp": (70.76, 77.16),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": " Next Man",
|
||||||
|
"timestamp": (77.16, 78.16),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
|
||||||
|
self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)
|
||||||
|
|
||||||
|
transcript_segments = [
|
||||||
|
{
|
||||||
|
"text": processor.decode(seg["tokens"], skip_special_tokens=True),
|
||||||
|
"timestamp": (seg["start"].item(), seg["end"].item()),
|
||||||
|
}
|
||||||
|
for seg in generated_ids["segments"][0]
|
||||||
|
]
|
||||||
|
self.assertEqual(transcript_segments, EXPECTED_TRANSCRIPT)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_large_timestamp_generation(self):
|
def test_large_timestamp_generation(self):
|
||||||
set_seed(0)
|
set_seed(0)
|
||||||
|
|||||||
Reference in New Issue
Block a user