[ci-daily] Fix pipeline tests (#21257)
* use streaming dataset * fix whisper's test * add rescale argument to chunk_iter
This commit is contained in:
@@ -56,7 +56,7 @@ def rescale_stride(stride, ratio):
|
|||||||
return new_strides
|
return new_strides
|
||||||
|
|
||||||
|
|
||||||
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, ratio, dtype=None):
|
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None):
|
||||||
inputs_len = inputs.shape[0]
|
inputs_len = inputs.shape[0]
|
||||||
step = chunk_len - stride_left - stride_right
|
step = chunk_len - stride_left - stride_right
|
||||||
for i in range(0, inputs_len, step):
|
for i in range(0, inputs_len, step):
|
||||||
@@ -68,9 +68,15 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
|
|||||||
_stride_left = 0 if i == 0 else stride_left
|
_stride_left = 0 if i == 0 else stride_left
|
||||||
is_last = i + step + stride_left >= inputs_len
|
is_last = i + step + stride_left >= inputs_len
|
||||||
_stride_right = 0 if is_last else stride_right
|
_stride_right = 0 if is_last else stride_right
|
||||||
|
|
||||||
chunk_len = chunk.shape[0]
|
chunk_len = chunk.shape[0]
|
||||||
stride = (chunk_len, _stride_left, _stride_right)
|
stride = (chunk_len, _stride_left, _stride_right)
|
||||||
if ratio != 1:
|
if "input_features" in processed:
|
||||||
|
processed_len = processed["input_features"].shape[-1]
|
||||||
|
elif "input_values" in processed:
|
||||||
|
processed_len = processed["input_values"].shape[-1]
|
||||||
|
if processed_len != chunk.shape[-1] and rescale:
|
||||||
|
ratio = processed_len / chunk_len
|
||||||
stride = rescale_stride([stride], ratio)[0]
|
stride = rescale_stride([stride], ratio)[0]
|
||||||
if chunk.shape[0] > _stride_left:
|
if chunk.shape[0] > _stride_left:
|
||||||
yield {"is_last": is_last, "stride": stride, **processed}
|
yield {"is_last": is_last, "stride": stride, **processed}
|
||||||
@@ -101,10 +107,10 @@ def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source
|
|||||||
sequence = sequence[begin_idx:]
|
sequence = sequence[begin_idx:]
|
||||||
|
|
||||||
timestamp_tokens = sequence >= timestamp_begin
|
timestamp_tokens = sequence >= timestamp_begin
|
||||||
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
|
if seq_idx != 0 and sum(timestamp_tokens) > 0:
|
||||||
last_timestamp = np.where(timestamp_tokens)[0][-1]
|
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
|
||||||
consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
|
last_timestamp = np.where(timestamp_tokens)[0][-1]
|
||||||
if seq_idx != 0:
|
consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
|
||||||
time -= stride_left + stride_right
|
time -= stride_left + stride_right
|
||||||
offset = int((time / feature_extractor.sampling_rate) / time_precision)
|
offset = int((time / feature_extractor.sampling_rate) / time_precision)
|
||||||
overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
|
overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
|
||||||
@@ -400,13 +406,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
" only 1 version"
|
" only 1 version"
|
||||||
)
|
)
|
||||||
forward_params["generate_kwargs"].update(generate_kwargs)
|
forward_params["generate_kwargs"].update(generate_kwargs)
|
||||||
if return_timestamps is not None:
|
|
||||||
forward_params["generate_kwargs"]["return_timestamps"] = return_timestamps
|
|
||||||
|
|
||||||
postprocess_params = {}
|
postprocess_params = {}
|
||||||
if decoder_kwargs is not None:
|
if decoder_kwargs is not None:
|
||||||
postprocess_params["decoder_kwargs"] = decoder_kwargs
|
postprocess_params["decoder_kwargs"] = decoder_kwargs
|
||||||
if return_timestamps is not None:
|
if return_timestamps is not None:
|
||||||
|
forward_params["return_timestamps"] = return_timestamps
|
||||||
postprocess_params["return_timestamps"] = return_timestamps
|
postprocess_params["return_timestamps"] = return_timestamps
|
||||||
if self.model.config.model_type == "whisper":
|
if self.model.config.model_type == "whisper":
|
||||||
# Whisper is highly specific, if we want timestamps, we need to
|
# Whisper is highly specific, if we want timestamps, we need to
|
||||||
@@ -502,9 +507,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
if chunk_len < stride_left + stride_right:
|
if chunk_len < stride_left + stride_right:
|
||||||
raise ValueError("Chunk length must be superior to stride length")
|
raise ValueError("Chunk length must be superior to stride length")
|
||||||
|
|
||||||
|
rescale = self.type != "seq2seq_whisper"
|
||||||
# make sure that
|
# make sure that
|
||||||
for item in chunk_iter(
|
for item in chunk_iter(
|
||||||
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, align_to, self.torch_dtype
|
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype
|
||||||
):
|
):
|
||||||
yield item
|
yield item
|
||||||
else:
|
else:
|
||||||
@@ -520,12 +526,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
processed["stride"] = stride
|
processed["stride"] = stride
|
||||||
yield {"is_last": True, **processed, **extra}
|
yield {"is_last": True, **processed, **extra}
|
||||||
|
|
||||||
def _forward(self, model_inputs, generate_kwargs=None):
|
def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
|
||||||
if generate_kwargs is None:
|
if generate_kwargs is None:
|
||||||
generate_kwargs = {}
|
generate_kwargs = {}
|
||||||
|
|
||||||
is_last = model_inputs.pop("is_last")
|
is_last = model_inputs.pop("is_last")
|
||||||
return_timestamps = generate_kwargs.pop("return_timestamps", False)
|
|
||||||
|
|
||||||
if self.type == "seq2seq":
|
if self.type == "seq2seq":
|
||||||
encoder = self.model.get_encoder()
|
encoder = self.model.get_encoder()
|
||||||
@@ -635,9 +640,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
# Simply cast from pyctcdecode format to wav2vec2 format to leverage
|
# Simply cast from pyctcdecode format to wav2vec2 format to leverage
|
||||||
# pre-existing code later
|
# pre-existing code later
|
||||||
chunk_offset = beams[0][2]
|
chunk_offset = beams[0][2]
|
||||||
word_offsets = []
|
offsets = []
|
||||||
for word, (start_offset, end_offset) in chunk_offset:
|
for word, (start_offset, end_offset) in chunk_offset:
|
||||||
word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
|
offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
|
||||||
else:
|
else:
|
||||||
skip_special_tokens = self.type != "ctc"
|
skip_special_tokens = self.type != "ctc"
|
||||||
text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
|
text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
|
||||||
|
|||||||
@@ -201,8 +201,9 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
@require_torch
|
@require_torch
|
||||||
@require_pyctcdecode
|
@require_pyctcdecode
|
||||||
def test_large_model_pt_with_lm(self):
|
def test_large_model_pt_with_lm(self):
|
||||||
dataset = load_dataset("Narsil/asr_dummy")
|
dataset = load_dataset("Narsil/asr_dummy", streaming=True)
|
||||||
filename = dataset["test"][3]["file"]
|
third_item = next(iter(dataset["test"].skip(3)))
|
||||||
|
filename = third_item["file"]
|
||||||
|
|
||||||
speech_recognizer = pipeline(
|
speech_recognizer = pipeline(
|
||||||
task="automatic-speech-recognition",
|
task="automatic-speech-recognition",
|
||||||
|
|||||||
Reference in New Issue
Block a user