Change the chunk_iter function to handle (#16730)

* Change the chunk_iter function to handle

the subtle cases where the last chunk gets ignored since all the
data is in the `left_strided` data.

We need to remove the right striding on the previous item.

* Remove commented line.
This commit is contained in:
Nicolas Patry
2022-04-12 18:25:02 +02:00
committed by GitHub
parent cc034f72eb
commit a192f61e08
2 changed files with 12 additions and 2 deletions

View File

@@ -58,9 +58,8 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
chunk = inputs[i : i + chunk_len]
processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
_stride_left = 0 if i == 0 else stride_left
is_last = i + step >= inputs_len
is_last = i + step + stride_left >= inputs_len
_stride_right = 0 if is_last else stride_right
if chunk.shape[0] > _stride_left:
yield {"is_last": is_last, "stride": (chunk.shape[0], _stride_left, _stride_right), **processed}

View File

@@ -653,6 +653,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 20)])
self.assertEqual([o["is_last"] for o in outs], [False, True])
# one chunk since first is also last, because it contains only data
# in the right strided part we just mark that part as non stride
# This test is specifically crafted to trigger a bug if next chunk
# would be ignored by the fact that all the data would be
# contained in the strided left data.
outs = list(chunk_iter(inputs, feature_extractor, 105, 5, 5))
self.assertEqual(len(outs), 1)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
self.assertEqual([o["is_last"] for o in outs], [True])
@require_torch
def test_chunk_iterator_stride(self):
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")