Whisper Timestamp processor and prediction (#20620)

* add draft logit processor

* add template functions

* update timesapmt processor parameters

* draft script

* simplify code

* cleanup

* fixup and clean

* update pipeline

* style

* clean up previous idea

* add tokenization utils

* update tokenizer and asr output

* fit whisper type

* style and update test

* clean test

* style test

* update tests

* update error test

* udpate code (not based on review yet)

* update tokenization

* update asr pipeline

* update code

* cleanup and update test

* fmt

* remove text verificatino

* cleanup

* cleanup

* add model test

* update tests

* update code add docstring

* update code and add docstring

* fix pipeline tests

* add draft logit processor

add template functions

update timesapmt processor parameters

draft script

simplify code

cleanup

fixup and clean

update pipeline

style

clean up previous idea

add tokenization utils

update tokenizer and asr output

fit whisper type

style and update test

clean test

style test

update tests

update error test

udpate code (not based on review yet)

update tokenization

update asr pipeline

update code

cleanup and update test

fmt

remove text verificatino

cleanup

cleanup

add model test

update tests

update code add docstring

update code and add docstring

fix pipeline tests

* Small update.

* Fixup.

* Tmp.

* More support.

* Making `forced_decoder_ids` non mandatory for users to set.

* update and fix first bug

* properly process sequence right after merge if last

* tofo

* allow list inputs + compute begin index better

* start adding tests

* add the 3 edge cases

* style

* format sequences

* fixup

* update

* update

* style

* test passes, edge cases should be good

* update last value

* remove Trie

* update tests and expec ted values

* handle bigger chunk_length

* clean tests a bit

* refactor chunk iter and clean pipeline

* update tests

* style

* refactor chunk iter and clean pipeline

* upade

* resolve comments

* Apply suggestions from code review

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* take stride right into account

* update test expected values

* Update code based on review

Co-authored-by: sgugger <sylvain.gugger@gmail.com>

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
This commit is contained in:
Arthur
2023-01-17 15:50:09 +01:00
committed by GitHub
parent 25ddd91b24
commit bb300ac686
6 changed files with 737 additions and 40 deletions

View File

@@ -20,6 +20,8 @@ import os
import tempfile
import unittest
import numpy as np
from transformers import WhisperConfig
from transformers.testing_utils import is_torch_available, require_torch, require_torchaudio, slow, torch_device
from transformers.utils import cached_property
@@ -44,6 +46,7 @@ if is_torch_available():
WhisperProcessor,
set_seed,
)
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
@@ -1030,7 +1033,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
@slow
def test_tiny_en_batched_generation(self):
torch_device = "cuda"
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
@@ -1067,3 +1069,43 @@ class WhisperModelIntegrationTests(unittest.TestCase):
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_tiny_timestamp_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
input_speech = np.concatenate(self._load_datasamples(4))
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
torch_device
)
model.config.forced_decoder_ids = [(1, 50259), (2, 50359), (3, 50364)]
timestamp_processor = [WhisperTimeStampLogitsProcessor(len(model.config.forced_decoder_ids))]
generated_ids = model.generate(input_features, max_length=448, logits_processor=timestamp_processor).to("cpu")
# fmt: off
EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404])
# fmt: on
self.assertTrue(torch.allclose(generated_ids, EXPECTED_OUTPUT))
# fmt: off
EXPECTED_TRANSCRIPT = [
{
'text': " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
'offsets': [
{'text': ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.', 'timestamp': (0.0, 5.62)},
{'text': " Nor is Mr. Quilter's manner less interesting than his matter.", 'timestamp': (5.62, 10.36)},
{'text': ' He tells us that at this festive season of the year,', 'timestamp': (10.36, 14.46)},
{'text': ' with Christmas and roast beef looming before us,', 'timestamp': (14.46, 17.76)},
{'text': ' similes drawn from eating and its results occur most readily to the mind.', 'timestamp': (17.76, 22.8)},
{'text': " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,", 'timestamp': (22.8, 28.82)}
]
}
]
# fmt: on
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

View File

@@ -227,3 +227,71 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
batch_encoding = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True)
self.assertListEqual(batch, transcription)
def test_offset_decoding(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
# fmt: off
INPUT_TOKENS = [
50258, 50259, 50359, 50364, 441, 1857, 4174, 11, 5242, 366,
257, 1333, 295, 493, 2794, 2287, 293, 12018, 14880, 11,
293, 25730, 311, 454, 34152, 4496, 904, 50724, 50724, 366,
382, 4048, 382, 257, 361, 18459, 13065, 13, 2221, 13,
7145, 74, 325, 38756, 311, 29822, 7563, 412, 472, 709,
294, 264, 51122, 51122, 912, 636, 300, 2221, 13, 2741,
5767, 1143, 281, 7319, 702, 7798, 13, 400, 2221, 13,
2619, 4004, 811, 2709, 702, 51449, 51449, 50257
]
# fmt: on
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(
output,
[
{
"text": (
" Lennils, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles"
),
"timestamp": (0.0, 7.2),
},
{
"text": (
" are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the"
),
"timestamp": (7.2, 15.16),
},
{
"text": " same way that Mr. Carker used to flash his teeth. And Mr. John Colier gives his",
"timestamp": (15.16, 21.7),
},
],
)
# test a single sequence with timestamps
# fmt: off
INPUT_TOKENS = [
50364, 441, 1857, 4174, 11, 5242, 366,
257, 1333, 295, 493, 2794, 2287, 293, 12018, 14880, 11,
293, 25730, 311, 454, 34152, 4496, 904, 50724
]
# fmt: on
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(
output[0],
{
"text": " Lennils, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles",
"timestamp": (0.0, 7.2),
},
)
# test a sequence without a single timestamps
# fmt: off
INPUT_TOKENS = [
441, 1857, 4174, 11, 5242, 366,
257, 1333, 295, 493, 2794, 2287, 293, 12018, 14880, 11,
293, 25730, 311, 454, 34152, 4496, 904, 50724
]
# fmt: on
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(output, [])

View File

@@ -23,6 +23,7 @@ from transformers import (
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
AutoFeatureExtractor,
AutoProcessor,
AutoTokenizer,
Speech2TextForConditionalGeneration,
Wav2Vec2ForCTC,
@@ -31,7 +32,7 @@ from transformers import (
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter
from transformers.pipelines.automatic_speech_recognition import chunk_iter
from transformers.pipelines.automatic_speech_recognition import _find_timestamp_sequence, chunk_iter
from transformers.testing_utils import (
is_torch_available,
nested_simplify,
@@ -87,7 +88,9 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
if speech_recognizer.type == "ctc":
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
elif "Whisper" in speech_recognizer.model.__class__.__name__:
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
else:
# Non CTC models cannot use striding.
with self.assertRaises(ValueError):
@@ -117,6 +120,18 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
"chunks": [{"text": ANY(str), "timestamp": (ANY(float), ANY(float))} for i in range(n)],
},
)
elif "Whisper" in speech_recognizer.model.__class__.__name__:
outputs = speech_recognizer(audio, return_timestamps=True)
self.assertIsInstance(outputs["chunks"], list)
nb_chunks = len(outputs["chunks"])
self.assertGreaterThan(nb_chunks, 0)
self.assertEqual(
outputs,
{
"text": ANY(str),
"chunks": [{"text": ANY(str), "timestamp": (ANY(float), ANY(float))} for i in range(nb_chunks)],
},
)
else:
# Non CTC models cannot use return_timestamps
with self.assertRaisesRegex(ValueError, "^We cannot return_timestamps yet on non-ctc models !$"):
@@ -142,7 +157,9 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
self.assertEqual(output, {"text": "(Applaudissements)"})
# Non CTC models cannot use return_timestamps
with self.assertRaisesRegex(ValueError, "^We cannot return_timestamps yet on non-ctc models !$"):
with self.assertRaisesRegex(
ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$"
):
_ = speech_recognizer(waveform, return_timestamps="char")
@slow
@@ -290,6 +307,280 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])
@slow
def test_find_longest_common_subsequence(self):
max_source_positions = 1500
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
previous_sequence = [[51492, 406, 3163, 1953, 466, 13, 51612, 51612]]
self.assertEqual(
processor.decode(previous_sequence[0], output_offsets=True),
{
"text": " not worth thinking about.",
"offsets": [{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)}],
},
)
# Merge when the previous sequence is a suffix of the next sequence
# fmt: off
next_sequences_1 = [
[50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 50614, 50614, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257]
]
# fmt: on
self.assertEqual(
processor.decode(next_sequences_1[0], output_offsets=True),
{
"text": (
" of spectators, retrievality is not worth thinking about. His instant panic was followed by a"
" small, sharp blow high on his chest.<|endoftext|>"
),
"offsets": [
{"text": " of spectators, retrievality is not worth thinking about.", "timestamp": (0.0, 5.0)},
{
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (5.0, 9.4),
},
],
},
)
merge = _find_timestamp_sequence(
[[previous_sequence, (3000, 0, 0)], [next_sequences_1, (3000, 750, 0)]],
processor.tokenizer,
processor.feature_extractor,
max_source_positions,
)
# fmt: off
self.assertEqual(
merge,
[51492, 406, 3163, 1953, 466, 13, 51739, 51739, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51959],
)
# fmt: on
self.assertEqual(
processor.decode(merge, output_offsets=True),
{
"text": (
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
" chest."
),
"offsets": [
{"text": " not worth thinking about.", "timestamp": (22.56, 27.5)},
{
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (27.5, 31.900000000000002),
},
],
},
)
# Merge when the sequence is in the middle of the 1st next sequence
# fmt: off
next_sequences_2 = [
[50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257]
]
# fmt: on
# {'text': ' of spectators, retrievality is not worth thinking about. His instant panic was followed by a small, sharp blow high on his chest.','timestamp': (0.0, 9.4)}
merge = _find_timestamp_sequence(
[[previous_sequence, (3000, 0, 0)], [next_sequences_2, (3000, 750, 0)]],
processor.tokenizer,
processor.feature_extractor,
max_source_positions,
)
# fmt: off
self.assertEqual(
merge,
[51492, 406, 3163, 1953, 466, 13, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51959],
)
# fmt: on
self.assertEqual(
processor.decode(merge, output_offsets=True),
{
"text": (
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
" chest."
),
"offsets": [
{
"text": (
" not worth thinking about. His instant panic was followed by a small, sharp blow high on"
" his chest."
),
"timestamp": (22.56, 31.900000000000002),
},
],
},
)
# Merge when the previous sequence is not included in the current sequence
# fmt: off
next_sequences_3 = [[50364, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50584, 50257]]
# fmt: on
# {'text': ' His instant panic was followed by a small, sharp blow high on his chest.','timestamp': (0.0, 9.4)}
merge = _find_timestamp_sequence(
[[previous_sequence, (3000, 0, 0)], [next_sequences_3, (3000, 750, 0)]],
processor.tokenizer,
processor.feature_extractor,
max_source_positions,
)
# fmt: off
self.assertEqual(
merge,
[51492, 406, 3163, 1953, 466, 13, 51612, 51612, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51832],
)
# fmt: on
self.assertEqual(
processor.decode(merge, output_offsets=True),
{
"text": (
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
" chest."
),
"offsets": [
{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
{
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (24.96, 29.36),
},
],
},
)
# last case is when the sequence is not in the first next predicted start and end of timestamp
# fmt: off
next_sequences_3 = [
[50364, 2812, 9836, 14783, 390, 51492, 406, 3163, 1953, 466, 13, 50634, 50634, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50934]
]
# fmt: on
merge = _find_timestamp_sequence(
[[previous_sequence, (3000, 0, 0)], [next_sequences_3, (3000, 750, 0)]],
processor.tokenizer,
processor.feature_extractor,
max_source_positions,
)
# fmt: off
self.assertEqual(
merge,
[51492, 406, 3163, 1953, 466, 13, 53112, 53112, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 53332],
)
# fmt: on
self.assertEqual(
processor.decode(merge, output_offsets=True),
{
"text": (
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
" chest."
),
"offsets": [
{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
{
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (24.96, 29.36),
},
],
},
)
@slow
@require_torch
def test_whisper_timestamp_prediction(self):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
array = np.concatenate(
[ds[40]["audio"]["array"], ds[41]["audio"]["array"], ds[42]["audio"]["array"], ds[43]["audio"]["array"]]
)
pipe = pipeline(
model="openai/whisper-small",
return_timestamps=True,
)
output = pipe(ds[40]["audio"])
self.assertDictEqual(
output,
{
"text": " A man said to the universe, Sir, I exist.",
"chunks": [{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 4.26)}],
},
)
pipe = pipeline(
model="openai/whisper-small",
return_timestamps=True,
)
output = pipe(array, chunk_length_s=10)
self.assertDictEqual(
output,
{
"chunks": [
{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 5.5)},
{
"text": (
" Sweat covered Brion's body, trickling into the "
"tight-loan cloth that was the only garment he wore, the "
"cut"
),
"timestamp": (5.5, 11.94),
},
{
"text": (
" on his chest still dripping blood, the ache of his "
"overstrained eyes, even the soaring arena around him "
"with"
),
"timestamp": (11.94, 19.6),
},
{
"text": " the thousands of spectators, retrievality is not worth thinking about.",
"timestamp": (19.6, 24.98),
},
{
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (24.98, 30.98),
},
],
"text": (
" A man said to the universe, Sir, I exist. Sweat covered Brion's "
"body, trickling into the tight-loan cloth that was the only garment "
"he wore, the cut on his chest still dripping blood, the ache of his "
"overstrained eyes, even the soaring arena around him with the "
"thousands of spectators, retrievality is not worth thinking about. "
"His instant panic was followed by a small, sharp blow high on his "
"chest."
),
},
)
output = pipe(array)
self.assertDictEqual(
output,
{
"chunks": [
{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 5.5)},
{
"text": (
" Sweat covered Brion's body, trickling into the "
"tight-loan cloth that was the only garment"
),
"timestamp": (5.5, 10.18),
},
{"text": " he wore.", "timestamp": (10.18, 11.68)},
{"text": " The cut on his chest still dripping blood.", "timestamp": (11.68, 14.92)},
{"text": " The ache of his overstrained eyes.", "timestamp": (14.92, 17.6)},
{
"text": (
" Even the soaring arena around him with the thousands of spectators were trivialities"
),
"timestamp": (17.6, 22.56),
},
{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
],
"text": (
" A man said to the universe, Sir, I exist. Sweat covered Brion's "
"body, trickling into the tight-loan cloth that was the only garment "
"he wore. The cut on his chest still dripping blood. The ache of his "
"overstrained eyes. Even the soaring arena around him with the "
"thousands of spectators were trivialities not worth thinking about."
),
},
)
@require_torch
@slow
def test_torch_speech_encoder_decoder(self):
@@ -724,22 +1015,22 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
def test_chunk_iterator(self):
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
inputs = torch.arange(100).long()
outs = list(chunk_iter(inputs, feature_extractor, 100, 0, 0))
ratio = 1
outs = list(chunk_iter(inputs, feature_extractor, 100, 0, 0, ratio))
self.assertEqual(len(outs), 1)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
self.assertEqual([o["is_last"] for o in outs], [True])
# two chunks no stride
outs = list(chunk_iter(inputs, feature_extractor, 50, 0, 0))
outs = list(chunk_iter(inputs, feature_extractor, 50, 0, 0, ratio))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(50, 0, 0), (50, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 50), (1, 50)])
self.assertEqual([o["is_last"] for o in outs], [False, True])
# two chunks incomplete last
outs = list(chunk_iter(inputs, feature_extractor, 80, 0, 0))
outs = list(chunk_iter(inputs, feature_extractor, 80, 0, 0, ratio))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(80, 0, 0), (20, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 20)])
@@ -750,7 +1041,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
# This test is specifically crafted to trigger a bug if next chunk
# would be ignored by the fact that all the data would be
# contained in the strided left data.
outs = list(chunk_iter(inputs, feature_extractor, 105, 5, 5))
outs = list(chunk_iter(inputs, feature_extractor, 105, 5, 5, ratio))
self.assertEqual(len(outs), 1)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
@@ -763,20 +1054,20 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
"input_values"
]
outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10))
ratio = 1
outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10, ratio))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 10), (30, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100), (1, 30)])
self.assertEqual([o["is_last"] for o in outs], [False, True])
outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10))
outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10, ratio))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(80, 0, 10), (50, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 50)])
self.assertEqual([o["is_last"] for o in outs], [False, True])
outs = list(chunk_iter(inputs, feature_extractor, 90, 20, 0))
outs = list(chunk_iter(inputs, feature_extractor, 90, 20, 0, ratio))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(90, 0, 0), (30, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 90), (1, 30)])
@@ -785,7 +1076,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
"input_values"
]
outs = list(chunk_iter(inputs, feature_extractor, 30, 5, 5))
outs = list(chunk_iter(inputs, feature_extractor, 30, 5, 5, ratio))
self.assertEqual(len(outs), 5)
self.assertEqual([o["stride"] for o in outs], [(30, 0, 5), (30, 5, 5), (30, 5, 5), (30, 5, 5), (20, 5, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 30), (1, 30), (1, 30), (1, 30), (1, 20)])