fix: Change is_last chunk calc and add conditional break in chunk_iter (#21612)

* fix: Change is_last chunk calc and add conditional break

* format fix

* account for 0 and full stride_rights, add comment

* add new test

* make style

* update slow whisper asr test timestamps

* use nested_simplify on output and round timestamp to hundreths place
This commit is contained in:
Connor Henderson
2023-02-24 02:30:32 -05:00
committed by GitHub
parent 4446b6b094
commit 279008adc3
2 changed files with 16 additions and 8 deletions

View File

@@ -526,7 +526,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output = pipe(array, chunk_length_s=10)
self.assertDictEqual(
output,
nested_simplify(output),
{
"chunks": [
{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 5.5)},
@@ -548,11 +548,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
},
{
"text": " the thousands of spectators, retrievality is not worth thinking about.",
"timestamp": (19.6, 24.98),
"timestamp": (19.6, 26.66),
},
{
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (24.98, 30.98),
"timestamp": (26.66, 31.06),
},
],
"text": (
@@ -1110,6 +1110,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
self.assertEqual([o["stride"] for o in outs], [(90, 0, 0), (30, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 90), (1, 30)])
outs = list(chunk_iter(inputs, feature_extractor, 36, 6, 6, ratio))
self.assertEqual(len(outs), 4)
self.assertEqual([o["stride"] for o in outs], [(36, 0, 6), (36, 6, 6), (36, 6, 6), (28, 6, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 36), (1, 36), (1, 36), (1, 28)])
inputs = torch.LongTensor([i % 2 for i in range(100)])
input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
"input_values"