fix: Change is_last chunk calc and add conditional break in chunk_iter (#21612)

* fix: Change is_last chunk calc and add conditional break

* format fix

* account for 0 and full stride_rights, add comment

* add new test

* make style

* update slow whisper asr test timestamps

* use nested_simplify on output and round timestamp to hundreths place
This commit is contained in:
Connor Henderson
2023-02-24 02:30:32 -05:00
committed by GitHub
parent 4446b6b094
commit 279008adc3
2 changed files with 16 additions and 8 deletions

View File

@@ -56,14 +56,15 @@ def rescale_stride(stride, ratio):
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None):
inputs_len = inputs.shape[0]
step = chunk_len - stride_left - stride_right
for i in range(0, inputs_len, step):
# add start and end paddings to the chunk
chunk = inputs[i : i + chunk_len]
for chunk_start_idx in range(0, inputs_len, step):
chunk_end_idx = chunk_start_idx + chunk_len
chunk = inputs[chunk_start_idx:chunk_end_idx]
processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
if dtype is not None:
processed = processed.to(dtype=dtype)
_stride_left = 0 if i == 0 else stride_left
is_last = i + step + stride_left >= inputs_len
_stride_left = 0 if chunk_start_idx == 0 else stride_left
# all right strides must be full, otherwise it is the last item
is_last = chunk_end_idx > inputs_len if stride_right > 0 else chunk_end_idx >= inputs_len
_stride_right = 0 if is_last else stride_right
chunk_len = chunk.shape[0]
@@ -77,6 +78,8 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
stride = rescale_stride([stride], ratio)[0]
if chunk.shape[0] > _stride_left:
yield {"is_last": is_last, "stride": stride, **processed}
if is_last:
break
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):