Change the chunk_iter function to handle (#16730)
* Change the chunk_iter function to handle the subtle cases where the last chunk gets ignored since all the data is in the `left_strided` data. We need to remove the right striding on the previous item. * Remove commented line.
This commit is contained in:
@@ -58,9 +58,8 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
|
||||
chunk = inputs[i : i + chunk_len]
|
||||
processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
|
||||
_stride_left = 0 if i == 0 else stride_left
|
||||
is_last = i + step >= inputs_len
|
||||
is_last = i + step + stride_left >= inputs_len
|
||||
_stride_right = 0 if is_last else stride_right
|
||||
|
||||
if chunk.shape[0] > _stride_left:
|
||||
yield {"is_last": is_last, "stride": (chunk.shape[0], _stride_left, _stride_right), **processed}
|
||||
|
||||
|
||||
@@ -653,6 +653,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 20)])
|
||||
self.assertEqual([o["is_last"] for o in outs], [False, True])
|
||||
|
||||
# one chunk since first is also last, because it contains only data
|
||||
# in the right strided part we just mark that part as non stride
|
||||
# This test is specifically crafted to trigger a bug if next chunk
|
||||
# would be ignored by the fact that all the data would be
|
||||
# contained in the strided left data.
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 105, 5, 5))
|
||||
self.assertEqual(len(outs), 1)
|
||||
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
|
||||
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
|
||||
self.assertEqual([o["is_last"] for o in outs], [True])
|
||||
|
||||
@require_torch
|
||||
def test_chunk_iterator_stride(self):
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
Reference in New Issue
Block a user