From a192f61e0825150e54e15fdc451cf37e23532b3f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 12 Apr 2022 18:25:02 +0200 Subject: [PATCH] Change the chunk_iter function to handle (#16730) * Change the chunk_iter function to handle the subtle cases where the last chunk gets ignored since all the data is in the `left_strided` data. We need to remove the right striding on the previous item. * Remove commented line. --- .../pipelines/automatic_speech_recognition.py | 3 +-- .../test_pipelines_automatic_speech_recognition.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 1b92231568..f19e9e8c66 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -58,9 +58,8 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right): chunk = inputs[i : i + chunk_len] processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt") _stride_left = 0 if i == 0 else stride_left - is_last = i + step >= inputs_len + is_last = i + step + stride_left >= inputs_len _stride_right = 0 if is_last else stride_right - if chunk.shape[0] > _stride_left: yield {"is_last": is_last, "stride": (chunk.shape[0], _stride_left, _stride_right), **processed} diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index e3dab51aab..ec54055d7d 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -653,6 +653,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 20)]) self.assertEqual([o["is_last"] for o in outs], [False, True]) + # one chunk since first is also last, because it contains only data + # in the right strided part we just mark that part as non stride + # This test is specifically crafted to trigger a bug if next chunk + # would be ignored by the fact that all the data would be + # contained in the strided left data. + outs = list(chunk_iter(inputs, feature_extractor, 105, 5, 5)) + self.assertEqual(len(outs), 1) + self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)]) + self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)]) + self.assertEqual([o["is_last"] for o in outs], [True]) + @require_torch def test_chunk_iterator_stride(self): feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")