Adding the option to return_timestamps on pure CTC ASR models. (#15792)
* Adding the option to return_timestamps on pure CTC ASR models. * Remove `math.prod` which was introduced in Python 3.8 * int are not floats. * Reworking the PR to support "char" vs "word" output. * Fixup! * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Quality. Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -14,7 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" Hubert model configuration"""
|
||||
|
||||
import math
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
@@ -253,4 +254,4 @@ class HubertConfig(PretrainedConfig):
|
||||
|
||||
@property
|
||||
def inputs_to_logits_ratio(self):
|
||||
return math.prod(self.conv_stride)
|
||||
return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" SEW model configuration"""
|
||||
|
||||
import math
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
@@ -248,4 +249,4 @@ class SEWConfig(PretrainedConfig):
|
||||
|
||||
@property
|
||||
def inputs_to_logits_ratio(self):
|
||||
return math.prod(self.conv_stride)
|
||||
return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" SEW-D model configuration"""
|
||||
|
||||
import math
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
@@ -284,4 +285,4 @@ class SEWDConfig(PretrainedConfig):
|
||||
|
||||
@property
|
||||
def inputs_to_logits_ratio(self):
|
||||
return math.prod(self.conv_stride)
|
||||
return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" UniSpeech model configuration"""
|
||||
|
||||
import math
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
@@ -294,4 +295,4 @@ class UniSpeechConfig(PretrainedConfig):
|
||||
|
||||
@property
|
||||
def inputs_to_logits_ratio(self):
|
||||
return math.prod(self.conv_stride)
|
||||
return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" UniSpeechSat model configuration"""
|
||||
|
||||
import math
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
@@ -311,4 +312,4 @@ class UniSpeechSatConfig(PretrainedConfig):
|
||||
|
||||
@property
|
||||
def inputs_to_logits_ratio(self):
|
||||
return math.prod(self.conv_stride)
|
||||
return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" Wav2Vec2 model configuration"""
|
||||
|
||||
import math
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
@@ -334,4 +335,4 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
|
||||
@property
|
||||
def inputs_to_logits_ratio(self):
|
||||
return math.prod(self.conv_stride)
|
||||
return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" WavLM model configuration"""
|
||||
|
||||
import math
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
@@ -335,4 +336,4 @@ class WavLMConfig(PretrainedConfig):
|
||||
|
||||
@property
|
||||
def inputs_to_logits_ratio(self):
|
||||
return math.prod(self.conv_stride)
|
||||
return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
|
||||
@@ -165,10 +165,23 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to
|
||||
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
|
||||
inference to provide more context to the model). Only use `stride` with CTC models.
|
||||
return_timestamps (*optional*, `str`):
|
||||
Only available for pure CTC models. If set to `"char"`, the pipeline will return `timestamps` along the
|
||||
text for every character in the text. For instance if you get `[{"text": "h", "timestamps": (0.5,0.6),
|
||||
{"text": "i", "timestamps": (0.7, .9)}]`, then it means the model predicts that the letter "h" was
|
||||
pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return
|
||||
`timestamps` along the text for every word in the text. For instance if you get `[{"text": "hi ",
|
||||
"timestamps": (0.5,0.9), {"text": "there", "timestamps": (1.0, .1.5)}]`, then it means the model
|
||||
predicts that the word "hi" was pronounces before 0.5 and after 0.9 seconds.
|
||||
|
||||
Return:
|
||||
`Dict`: A dictionary with the following keys:
|
||||
- **text** (`str`) -- The recognized text.
|
||||
- **text** (`str` ) -- The recognized text.
|
||||
- **chunks** (*optional(, `List[Dict]`)
|
||||
When using `return_timestamps`, the `chunks` will become a list containing all the various text
|
||||
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5,0.9), {"text":
|
||||
"there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
|
||||
`"".join(chunk["text"] for chunk in output["chunks"])`.
|
||||
"""
|
||||
return super().__call__(inputs, **kwargs)
|
||||
|
||||
@@ -183,6 +196,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
postprocess_params = {}
|
||||
if "decoder_kwargs" in kwargs:
|
||||
postprocess_params["decoder_kwargs"] = kwargs["decoder_kwargs"]
|
||||
if "return_timestamps" in kwargs:
|
||||
postprocess_params["return_timestamps"] = kwargs["return_timestamps"]
|
||||
|
||||
return preprocess_params, {}, postprocess_params
|
||||
|
||||
@@ -323,7 +338,13 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
extra = model_inputs
|
||||
return {"is_last": is_last, **out, **extra}
|
||||
|
||||
def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None):
|
||||
def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None):
|
||||
# Optional return types
|
||||
optional = {}
|
||||
|
||||
if return_timestamps and self.type != "ctc":
|
||||
raise ValueError("We cannot return_timestamps yet on non-ctc models !")
|
||||
|
||||
if self.type == "ctc_with_lm":
|
||||
final_logits = []
|
||||
for outputs in model_outputs:
|
||||
@@ -349,6 +370,30 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
tokens = tokens.squeeze(0)
|
||||
text = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
if return_timestamps:
|
||||
if return_timestamps == "char":
|
||||
decoded = self.tokenizer.decode(
|
||||
tokens, skip_special_tokens=skip_special_tokens, output_char_offsets=True
|
||||
)
|
||||
elif return_timestamps == "word":
|
||||
decoded = self.tokenizer.decode(
|
||||
tokens, skip_special_tokens=skip_special_tokens, output_word_offsets=True
|
||||
)
|
||||
chunks = []
|
||||
for item in decoded[f"{return_timestamps}_offsets"]:
|
||||
start = (
|
||||
item["start_offset"]
|
||||
* self.model.config.inputs_to_logits_ratio
|
||||
/ self.feature_extractor.sampling_rate
|
||||
)
|
||||
stop = (
|
||||
item["end_offset"]
|
||||
* self.model.config.inputs_to_logits_ratio
|
||||
/ self.feature_extractor.sampling_rate
|
||||
)
|
||||
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
|
||||
optional["chunks"] = chunks
|
||||
|
||||
extra = defaultdict(list)
|
||||
for output in model_outputs:
|
||||
output.pop("tokens", None)
|
||||
@@ -357,4 +402,4 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
if k == "is_last":
|
||||
continue
|
||||
extra[k].append(v)
|
||||
return {"text": text, **extra}
|
||||
return {"text": text, **optional, **extra}
|
||||
|
||||
@@ -82,15 +82,46 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
outputs = speech_recognizer(audio)
|
||||
self.assertEqual(outputs, {"text": ANY(str)})
|
||||
|
||||
# Striding
|
||||
audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate}
|
||||
if speech_recognizer.type == "ctc":
|
||||
outputs = speech_recognizer(audio)
|
||||
self.assertEqual(outputs, {"text": ANY(str)})
|
||||
|
||||
else:
|
||||
# Non CTC models cannot use striding.
|
||||
with self.assertRaises(ValueError):
|
||||
outputs = speech_recognizer(audio)
|
||||
|
||||
# Timestamps
|
||||
audio = np.zeros((34000,))
|
||||
if speech_recognizer.type == "ctc":
|
||||
outputs = speech_recognizer(audio, return_timestamps="char")
|
||||
self.assertIsInstance(outputs["chunks"], list)
|
||||
n = len(outputs["chunks"])
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
{
|
||||
"text": ANY(str),
|
||||
"chunks": [{"text": ANY(str), "timestamp": (ANY(float), ANY(float))} for i in range(n)],
|
||||
},
|
||||
)
|
||||
|
||||
outputs = speech_recognizer(audio, return_timestamps="word")
|
||||
self.assertIsInstance(outputs["chunks"], list)
|
||||
n = len(outputs["chunks"])
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
{
|
||||
"text": ANY(str),
|
||||
"chunks": [{"text": ANY(str), "timestamp": (ANY(float), ANY(float))} for i in range(n)],
|
||||
},
|
||||
)
|
||||
else:
|
||||
# Non CTC models cannot use return_timestamps
|
||||
with self.assertRaises(ValueError):
|
||||
outputs = speech_recognizer(audio, return_timestamps="char")
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_pt_defaults(self):
|
||||
@@ -302,6 +333,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
|
||||
|
||||
@@ -322,6 +354,49 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
self.assertEqual(output, [{"text": ANY(str)}])
|
||||
self.assertEqual(output[0]["text"][:6], "ZBT ZC")
|
||||
|
||||
@require_torch
|
||||
def test_return_timestamps_ctc_fast(self):
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="hf-internal-testing/tiny-random-wav2vec2",
|
||||
)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
# Take short audio to keep the test readable
|
||||
audio = ds[40]["audio"]["array"][:800]
|
||||
|
||||
output = speech_recognizer(audio, return_timestamps="char")
|
||||
self.assertEqual(
|
||||
output,
|
||||
{
|
||||
"text": "ZBT ZX G",
|
||||
"chunks": [
|
||||
{"text": " ", "timestamp": (0.0, 0.012)},
|
||||
{"text": "Z", "timestamp": (0.012, 0.016)},
|
||||
{"text": "B", "timestamp": (0.016, 0.02)},
|
||||
{"text": "T", "timestamp": (0.02, 0.024)},
|
||||
{"text": " ", "timestamp": (0.024, 0.028)},
|
||||
{"text": "Z", "timestamp": (0.028, 0.032)},
|
||||
{"text": "X", "timestamp": (0.032, 0.036)},
|
||||
{"text": " ", "timestamp": (0.036, 0.04)},
|
||||
{"text": "G", "timestamp": (0.04, 0.044)},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
output = speech_recognizer(audio, return_timestamps="word")
|
||||
self.assertEqual(
|
||||
output,
|
||||
{
|
||||
"text": "ZBT ZX G",
|
||||
"chunks": [
|
||||
{"text": "ZBT", "timestamp": (0.012, 0.024)},
|
||||
{"text": "ZX", "timestamp": (0.028, 0.036)},
|
||||
{"text": "G", "timestamp": (0.04, 0.044)},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@require_pyctcdecode
|
||||
def test_chunking_fast_with_lm(self):
|
||||
@@ -399,7 +474,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_chunking(self):
|
||||
def test_chunking_and_timestamps(self):
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
@@ -416,11 +491,79 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
audio = ds[40]["audio"]["array"]
|
||||
|
||||
n_repeats = 10
|
||||
audio = np.tile(audio, n_repeats)
|
||||
output = speech_recognizer([audio], batch_size=2)
|
||||
expected_text = "A MAN SAID TO THE UNIVERSE SIR I EXIST " * n_repeats
|
||||
expected = [{"text": expected_text.strip()}]
|
||||
self.assertEqual(output, expected)
|
||||
audio_tiled = np.tile(audio, n_repeats)
|
||||
output = speech_recognizer([audio_tiled], batch_size=2)
|
||||
self.assertEqual(output, [{"text": ("A MAN SAID TO THE UNIVERSE SIR I EXIST " * n_repeats).strip()}])
|
||||
|
||||
output = speech_recognizer(audio, return_timestamps="char")
|
||||
self.assertEqual(audio.shape, (74_400,))
|
||||
self.assertEqual(speech_recognizer.feature_extractor.sampling_rate, 16_000)
|
||||
# The audio is 74_400 / 16_000 = 4.65s long.
|
||||
self.assertEqual(
|
||||
output,
|
||||
{
|
||||
"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST",
|
||||
"chunks": [
|
||||
{"text": "A", "timestamp": (0.6, 0.62)},
|
||||
{"text": " ", "timestamp": (0.62, 0.66)},
|
||||
{"text": "M", "timestamp": (0.68, 0.7)},
|
||||
{"text": "A", "timestamp": (0.78, 0.8)},
|
||||
{"text": "N", "timestamp": (0.84, 0.86)},
|
||||
{"text": " ", "timestamp": (0.92, 0.98)},
|
||||
{"text": "S", "timestamp": (1.06, 1.08)},
|
||||
{"text": "A", "timestamp": (1.14, 1.16)},
|
||||
{"text": "I", "timestamp": (1.16, 1.18)},
|
||||
{"text": "D", "timestamp": (1.2, 1.24)},
|
||||
{"text": " ", "timestamp": (1.24, 1.28)},
|
||||
{"text": "T", "timestamp": (1.28, 1.32)},
|
||||
{"text": "O", "timestamp": (1.34, 1.36)},
|
||||
{"text": " ", "timestamp": (1.38, 1.42)},
|
||||
{"text": "T", "timestamp": (1.42, 1.44)},
|
||||
{"text": "H", "timestamp": (1.44, 1.46)},
|
||||
{"text": "E", "timestamp": (1.46, 1.5)},
|
||||
{"text": " ", "timestamp": (1.5, 1.56)},
|
||||
{"text": "U", "timestamp": (1.58, 1.62)},
|
||||
{"text": "N", "timestamp": (1.64, 1.68)},
|
||||
{"text": "I", "timestamp": (1.7, 1.72)},
|
||||
{"text": "V", "timestamp": (1.76, 1.78)},
|
||||
{"text": "E", "timestamp": (1.84, 1.86)},
|
||||
{"text": "R", "timestamp": (1.86, 1.9)},
|
||||
{"text": "S", "timestamp": (1.96, 1.98)},
|
||||
{"text": "E", "timestamp": (1.98, 2.02)},
|
||||
{"text": " ", "timestamp": (2.02, 2.06)},
|
||||
{"text": "S", "timestamp": (2.82, 2.86)},
|
||||
{"text": "I", "timestamp": (2.94, 2.96)},
|
||||
{"text": "R", "timestamp": (2.98, 3.02)},
|
||||
{"text": " ", "timestamp": (3.06, 3.12)},
|
||||
{"text": "I", "timestamp": (3.5, 3.52)},
|
||||
{"text": " ", "timestamp": (3.58, 3.6)},
|
||||
{"text": "E", "timestamp": (3.66, 3.68)},
|
||||
{"text": "X", "timestamp": (3.68, 3.7)},
|
||||
{"text": "I", "timestamp": (3.9, 3.92)},
|
||||
{"text": "S", "timestamp": (3.94, 3.96)},
|
||||
{"text": "T", "timestamp": (4.0, 4.02)},
|
||||
{"text": " ", "timestamp": (4.06, 4.1)},
|
||||
],
|
||||
},
|
||||
)
|
||||
output = speech_recognizer(audio, return_timestamps="word")
|
||||
self.assertEqual(
|
||||
output,
|
||||
{
|
||||
"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST",
|
||||
"chunks": [
|
||||
{"text": "A", "timestamp": (0.6, 0.62)},
|
||||
{"text": "MAN", "timestamp": (0.68, 0.86)},
|
||||
{"text": "SAID", "timestamp": (1.06, 1.24)},
|
||||
{"text": "TO", "timestamp": (1.28, 1.36)},
|
||||
{"text": "THE", "timestamp": (1.42, 1.5)},
|
||||
{"text": "UNIVERSE", "timestamp": (1.58, 2.02)},
|
||||
{"text": "SIR", "timestamp": (2.82, 3.02)},
|
||||
{"text": "I", "timestamp": (3.5, 3.52)},
|
||||
{"text": "EXIST", "timestamp": (3.66, 4.02)},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user