Adding the option to return_timestamps on pure CTC ASR models. (#15792)

* Adding the option to return_timestamps on pure CTC ASR models.

* Remove `math.prod` which was introduced in Python 3.8

* int are not floats.

* Reworking the PR to support "char" vs "word" output.

* Fixup!

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Quality.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Nicolas Patry
2022-02-25 14:06:45 +01:00
committed by GitHub
parent 7566734d6f
commit ad0d7d1745
9 changed files with 218 additions and 23 deletions

View File

@@ -82,15 +82,46 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
# Striding
audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate}
if speech_recognizer.type == "ctc":
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
else:
# Non CTC models cannot use striding.
with self.assertRaises(ValueError):
outputs = speech_recognizer(audio)
# Timestamps
audio = np.zeros((34000,))
if speech_recognizer.type == "ctc":
outputs = speech_recognizer(audio, return_timestamps="char")
self.assertIsInstance(outputs["chunks"], list)
n = len(outputs["chunks"])
self.assertEqual(
outputs,
{
"text": ANY(str),
"chunks": [{"text": ANY(str), "timestamp": (ANY(float), ANY(float))} for i in range(n)],
},
)
outputs = speech_recognizer(audio, return_timestamps="word")
self.assertIsInstance(outputs["chunks"], list)
n = len(outputs["chunks"])
self.assertEqual(
outputs,
{
"text": ANY(str),
"chunks": [{"text": ANY(str), "timestamp": (ANY(float), ANY(float))} for i in range(n)],
},
)
else:
# Non CTC models cannot use return_timestamps
with self.assertRaises(ValueError):
outputs = speech_recognizer(audio, return_timestamps="char")
@require_torch
@slow
def test_pt_defaults(self):
@@ -302,6 +333,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
@@ -322,6 +354,49 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "ZBT ZC")
@require_torch
def test_return_timestamps_ctc_fast(self):
    """Check `return_timestamps` output of the ASR pipeline on a tiny wav2vec2 checkpoint.

    The pipeline is run twice on the same short audio clip, first with
    `return_timestamps="char"` and then with `return_timestamps="word"`, and
    the full output dict (`"text"` plus `"chunks"`) is compared against
    hard-coded expectations.
    """
    asr = pipeline(
        model="hf-internal-testing/tiny-random-wav2vec2",
        task="automatic-speech-recognition",
    )
    dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
    # Only the first 800 samples are kept so the expected chunk lists stay short and readable.
    waveform = dataset[40]["audio"]["array"][:800]

    transcript = "ZBT ZX G"

    # One (text, start, end) triple per expected character chunk.
    char_spans = (
        (" ", 0.0, 0.012),
        ("Z", 0.012, 0.016),
        ("B", 0.016, 0.02),
        ("T", 0.02, 0.024),
        (" ", 0.024, 0.028),
        ("Z", 0.028, 0.032),
        ("X", 0.032, 0.036),
        (" ", 0.036, 0.04),
        ("G", 0.04, 0.044),
    )
    # In these expectations each word runs from the start of its first
    # character to the end of its last one; no chunk is listed for spaces.
    word_spans = (
        ("ZBT", 0.012, 0.024),
        ("ZX", 0.028, 0.036),
        ("G", 0.04, 0.044),
    )

    for granularity, spans in (("char", char_spans), ("word", word_spans)):
        expected_chunks = [{"text": piece, "timestamp": (start, end)} for piece, start, end in spans]
        result = asr(waveform, return_timestamps=granularity)
        self.assertEqual(result, {"text": transcript, "chunks": expected_chunks})
@require_torch
@require_pyctcdecode
def test_chunking_fast_with_lm(self):
@@ -399,7 +474,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
@require_torch
@slow
def test_chunking(self):
def test_chunking_and_timestamps(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
@@ -416,11 +491,79 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
audio = ds[40]["audio"]["array"]
n_repeats = 10
audio = np.tile(audio, n_repeats)
output = speech_recognizer([audio], batch_size=2)
expected_text = "A MAN SAID TO THE UNIVERSE SIR I EXIST " * n_repeats
expected = [{"text": expected_text.strip()}]
self.assertEqual(output, expected)
audio_tiled = np.tile(audio, n_repeats)
output = speech_recognizer([audio_tiled], batch_size=2)
self.assertEqual(output, [{"text": ("A MAN SAID TO THE UNIVERSE SIR I EXIST " * n_repeats).strip()}])
output = speech_recognizer(audio, return_timestamps="char")
self.assertEqual(audio.shape, (74_400,))
self.assertEqual(speech_recognizer.feature_extractor.sampling_rate, 16_000)
# The audio is 74_400 / 16_000 = 4.65s long.
self.assertEqual(
output,
{
"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST",
"chunks": [
{"text": "A", "timestamp": (0.6, 0.62)},
{"text": " ", "timestamp": (0.62, 0.66)},
{"text": "M", "timestamp": (0.68, 0.7)},
{"text": "A", "timestamp": (0.78, 0.8)},
{"text": "N", "timestamp": (0.84, 0.86)},
{"text": " ", "timestamp": (0.92, 0.98)},
{"text": "S", "timestamp": (1.06, 1.08)},
{"text": "A", "timestamp": (1.14, 1.16)},
{"text": "I", "timestamp": (1.16, 1.18)},
{"text": "D", "timestamp": (1.2, 1.24)},
{"text": " ", "timestamp": (1.24, 1.28)},
{"text": "T", "timestamp": (1.28, 1.32)},
{"text": "O", "timestamp": (1.34, 1.36)},
{"text": " ", "timestamp": (1.38, 1.42)},
{"text": "T", "timestamp": (1.42, 1.44)},
{"text": "H", "timestamp": (1.44, 1.46)},
{"text": "E", "timestamp": (1.46, 1.5)},
{"text": " ", "timestamp": (1.5, 1.56)},
{"text": "U", "timestamp": (1.58, 1.62)},
{"text": "N", "timestamp": (1.64, 1.68)},
{"text": "I", "timestamp": (1.7, 1.72)},
{"text": "V", "timestamp": (1.76, 1.78)},
{"text": "E", "timestamp": (1.84, 1.86)},
{"text": "R", "timestamp": (1.86, 1.9)},
{"text": "S", "timestamp": (1.96, 1.98)},
{"text": "E", "timestamp": (1.98, 2.02)},
{"text": " ", "timestamp": (2.02, 2.06)},
{"text": "S", "timestamp": (2.82, 2.86)},
{"text": "I", "timestamp": (2.94, 2.96)},
{"text": "R", "timestamp": (2.98, 3.02)},
{"text": " ", "timestamp": (3.06, 3.12)},
{"text": "I", "timestamp": (3.5, 3.52)},
{"text": " ", "timestamp": (3.58, 3.6)},
{"text": "E", "timestamp": (3.66, 3.68)},
{"text": "X", "timestamp": (3.68, 3.7)},
{"text": "I", "timestamp": (3.9, 3.92)},
{"text": "S", "timestamp": (3.94, 3.96)},
{"text": "T", "timestamp": (4.0, 4.02)},
{"text": " ", "timestamp": (4.06, 4.1)},
],
},
)
output = speech_recognizer(audio, return_timestamps="word")
self.assertEqual(
output,
{
"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST",
"chunks": [
{"text": "A", "timestamp": (0.6, 0.62)},
{"text": "MAN", "timestamp": (0.68, 0.86)},
{"text": "SAID", "timestamp": (1.06, 1.24)},
{"text": "TO", "timestamp": (1.28, 1.36)},
{"text": "THE", "timestamp": (1.42, 1.5)},
{"text": "UNIVERSE", "timestamp": (1.58, 2.02)},
{"text": "SIR", "timestamp": (2.82, 3.02)},
{"text": "I", "timestamp": (3.5, 3.52)},
{"text": "EXIST", "timestamp": (3.66, 4.02)},
],
},
)
@require_torch
@slow