Adding chunking for whisper (all seq2seq actually). Very crude matching algorithm. (#20104)
* Very crude matching algorithm. * Fixing tests. * Removing comments * Adding warning + fix short matches. * Cleanup tests. * Quality. * Less noisy. * Fixup.
This commit is contained in:
@@ -30,7 +30,7 @@ if is_torch_available():
|
|||||||
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
|
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
|
||||||
|
|
||||||
|
|
||||||
def rescale_stride(tokens_or_logits, stride, ratio):
|
def rescale_stride(stride, ratio):
|
||||||
"""
|
"""
|
||||||
Rescales the stride values from audio space to tokens/logits space.
|
Rescales the stride values from audio space to tokens/logits space.
|
||||||
|
|
||||||
@@ -60,8 +60,43 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
|
|||||||
_stride_left = 0 if i == 0 else stride_left
|
_stride_left = 0 if i == 0 else stride_left
|
||||||
is_last = i + step + stride_left >= inputs_len
|
is_last = i + step + stride_left >= inputs_len
|
||||||
_stride_right = 0 if is_last else stride_right
|
_stride_right = 0 if is_last else stride_right
|
||||||
|
|
||||||
|
if "input_features" in processed:
|
||||||
|
processed_len = processed["input_features"].shape[-1]
|
||||||
|
elif "input_values" in processed:
|
||||||
|
processed_len = processed["input_values"].shape[-1]
|
||||||
|
chunk_len = chunk.shape[0]
|
||||||
|
stride = (chunk_len, _stride_left, _stride_right)
|
||||||
|
if processed_len != chunk.shape[-1]:
|
||||||
|
ratio = processed_len / chunk_len
|
||||||
|
stride = rescale_stride([stride], ratio)[0]
|
||||||
if chunk.shape[0] > _stride_left:
|
if chunk.shape[0] > _stride_left:
|
||||||
yield {"is_last": is_last, "stride": (chunk.shape[0], _stride_left, _stride_right), **processed}
|
yield {"is_last": is_last, "stride": stride, **processed}
|
||||||
|
|
||||||
|
|
||||||
|
def _find_longest_common_sequence(sequences, tokenizer):
|
||||||
|
# TODO Use a faster algorithm this can probably be done in O(n)
|
||||||
|
# using suffix array.
|
||||||
|
# It might be tedious to do because of fault tolerance.
|
||||||
|
# We actually have a really good property which is that the total sequence
|
||||||
|
# MUST be those subsequences in order.
|
||||||
|
# Also the algorithm should be more tolerant to errors.
|
||||||
|
sequence = [tok_id for tok_id in sequences[0][0].tolist() if tok_id not in tokenizer.all_special_ids]
|
||||||
|
for new_seq in sequences[1:]:
|
||||||
|
new_sequence = [tok_id for tok_id in new_seq[0].tolist() if tok_id not in tokenizer.all_special_ids]
|
||||||
|
|
||||||
|
index = 0
|
||||||
|
max_ = 0.0
|
||||||
|
for i in range(1, len(new_sequence) + 1):
|
||||||
|
# epsilon to favor long perfect matches
|
||||||
|
eps = i / 10000.0
|
||||||
|
matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i]))
|
||||||
|
matching = matches / i + eps
|
||||||
|
if matches > 1 and matching > max_:
|
||||||
|
index = i
|
||||||
|
max_ = matching
|
||||||
|
sequence.extend(new_sequence[index:])
|
||||||
|
return np.array(sequence)
|
||||||
|
|
||||||
|
|
||||||
class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||||
@@ -188,6 +223,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
preprocess_params["chunk_length_s"] = kwargs["chunk_length_s"]
|
preprocess_params["chunk_length_s"] = kwargs["chunk_length_s"]
|
||||||
if "stride_length_s" in kwargs:
|
if "stride_length_s" in kwargs:
|
||||||
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
|
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
|
||||||
|
if "ignore_warning" in kwargs:
|
||||||
|
preprocess_params["ignore_warning"] = kwargs["ignore_warning"]
|
||||||
|
|
||||||
postprocess_params = {}
|
postprocess_params = {}
|
||||||
if "decoder_kwargs" in kwargs:
|
if "decoder_kwargs" in kwargs:
|
||||||
@@ -197,7 +234,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
|
|
||||||
return preprocess_params, {}, postprocess_params
|
return preprocess_params, {}, postprocess_params
|
||||||
|
|
||||||
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
|
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
|
||||||
if isinstance(inputs, str):
|
if isinstance(inputs, str):
|
||||||
with open(inputs, "rb") as f:
|
with open(inputs, "rb") as f:
|
||||||
inputs = f.read()
|
inputs = f.read()
|
||||||
@@ -249,10 +286,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
|
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
|
||||||
|
|
||||||
if chunk_length_s:
|
if chunk_length_s:
|
||||||
if self.type not in {"ctc", "ctc_with_lm"}:
|
if self.type == "seq2seq" and not ignore_warning:
|
||||||
raise ValueError(
|
logger.warning(
|
||||||
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
|
"Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
|
||||||
|
" be entirely accurate and will have caveats. More information:"
|
||||||
|
" https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
|
||||||
|
" ignore_warning=True)"
|
||||||
)
|
)
|
||||||
|
self._preprocess_params["ignore_warning"] = True
|
||||||
if stride_length_s is None:
|
if stride_length_s is None:
|
||||||
stride_length_s = chunk_length_s / 6
|
stride_length_s = chunk_length_s / 6
|
||||||
|
|
||||||
@@ -262,7 +303,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
# XXX: Carefuly, this variable will not exist in `seq2seq` setting.
|
# XXX: Carefuly, this variable will not exist in `seq2seq` setting.
|
||||||
# Currently chunking is not possible at this level for `seq2seq` so
|
# Currently chunking is not possible at this level for `seq2seq` so
|
||||||
# it's ok.
|
# it's ok.
|
||||||
align_to = self.model.config.inputs_to_logits_ratio
|
align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
|
||||||
chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
|
chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
|
||||||
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
|
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
|
||||||
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
|
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
|
||||||
@@ -329,9 +370,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
# the pieces are to be concatenated.
|
# the pieces are to be concatenated.
|
||||||
ratio = 1 / self.model.config.inputs_to_logits_ratio
|
ratio = 1 / self.model.config.inputs_to_logits_ratio
|
||||||
if isinstance(stride, tuple):
|
if isinstance(stride, tuple):
|
||||||
out["stride"] = rescale_stride(logits, [stride], ratio)[0]
|
out["stride"] = rescale_stride([stride], ratio)[0]
|
||||||
else:
|
else:
|
||||||
out["stride"] = rescale_stride(logits, stride, ratio)
|
out["stride"] = rescale_stride(stride, ratio)
|
||||||
# Leftover
|
# Leftover
|
||||||
extra = model_inputs
|
extra = model_inputs
|
||||||
return {"is_last": is_last, **out, **extra}
|
return {"is_last": is_last, **out, **extra}
|
||||||
@@ -347,10 +388,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
|
|
||||||
final_items = []
|
final_items = []
|
||||||
key = "logits" if self.type == "ctc_with_lm" else "tokens"
|
key = "logits" if self.type == "ctc_with_lm" else "tokens"
|
||||||
|
stride = None
|
||||||
for outputs in model_outputs:
|
for outputs in model_outputs:
|
||||||
items = outputs[key].numpy()
|
items = outputs[key].numpy()
|
||||||
stride = outputs.pop("stride", None)
|
stride = outputs.pop("stride", None)
|
||||||
if stride is not None:
|
if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
|
||||||
total_n, left, right = stride
|
total_n, left, right = stride
|
||||||
# Total_n might be < logits.shape[1]
|
# Total_n might be < logits.shape[1]
|
||||||
# because of padding, that's why
|
# because of padding, that's why
|
||||||
@@ -359,6 +401,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
right_n = total_n - right
|
right_n = total_n - right
|
||||||
items = items[:, left:right_n]
|
items = items[:, left:right_n]
|
||||||
final_items.append(items)
|
final_items.append(items)
|
||||||
|
if stride and self.type == "seq2seq":
|
||||||
|
items = _find_longest_common_sequence(final_items, self.tokenizer)
|
||||||
|
else:
|
||||||
items = np.concatenate(final_items, axis=1)
|
items = np.concatenate(final_items, axis=1)
|
||||||
items = items.squeeze(0)
|
items = items.squeeze(0)
|
||||||
if self.type == "ctc_with_lm":
|
if self.type == "ctc_with_lm":
|
||||||
|
|||||||
@@ -144,12 +144,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||||
output = speech_recognizer(waveform)
|
output = speech_recognizer(waveform)
|
||||||
self.assertEqual(output, {"text": "(Applaudissements)"})
|
self.assertEqual(output, {"text": "(Applaudissements)"})
|
||||||
with self.assertRaises(ValueError) as v:
|
output = speech_recognizer(waveform, chunk_length_s=10)
|
||||||
_ = speech_recognizer(waveform, chunk_length_s=10)
|
self.assertEqual(output, {"text": "(Applaudissements)"})
|
||||||
self.assertEqual(
|
|
||||||
str(v.exception),
|
|
||||||
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Non CTC models cannot use return_timestamps
|
# Non CTC models cannot use return_timestamps
|
||||||
with self.assertRaises(ValueError) as v:
|
with self.assertRaises(ValueError) as v:
|
||||||
@@ -261,6 +257,22 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
output = speech_recognizer(filename)
|
output = speech_recognizer(filename)
|
||||||
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
|
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@slow
|
||||||
|
def test_torch_whisper(self):
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model="openai/whisper-tiny",
|
||||||
|
framework="pt",
|
||||||
|
)
|
||||||
|
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||||
|
filename = ds[40]["file"]
|
||||||
|
output = speech_recognizer(filename)
|
||||||
|
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
|
||||||
|
|
||||||
|
output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
|
||||||
|
self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_torch_speech_encoder_decoder(self):
|
def test_torch_speech_encoder_decoder(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user