Adding support for microphone streaming within pipeline. (#15046)
* Adding support for `microphone` streaming within pipeline.
- Uses `ffmpeg` to get microphone data.
- Makes sure alignment is made to `size_of_sample`.
- Works by sending `{"raw": ..data.., "stride": (n, left, right),
"partial": bool}`
directly to the pipeline, which makes it possible to stream partial
results and still get inference.
- Lets `partial` information flow through the pipeline so the caller
can get it back and choose whether to display the text.
- The striding reconstitution is bound to have errors since CTC does not
keep previous state. Currently most of the errors come from not knowing
whether there is a space between two chunks.
Since we have some left striding info, we could use that during decoding
to choose what to do with those spaces and even extra letters maybe (if
the stride is long enough, it's bound to cover at least a few symbols)
Fixing tests.
Protecting with `require_torch`.
`raw_ctc` support for nicer demo.
Post rebase fixes.
Revamp to split raw_mic_data from its live chunking.
- Requires a refactor to make everything a bit cleaner.
Automatic resampling.
Small fix.
Small fix.
* Post rebase fix (need to let super handle more logic, reorder args.)
* Update docstrings
* Docstring format.
* Remove print.
* Prevent flow of `input_values`.
* Fixing `stride` too.
* Fixing the PR by removing `raw_ctc`.
* Better docstrings.
* Fixing init.
* Update src/transformers/pipelines/audio_utils.py
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
* Update tests/test_pipelines_automatic_speech_recognition.py
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
* Quality.
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
This commit is contained in:
@@ -27,6 +27,7 @@ from transformers import (
|
||||
Wav2Vec2ForCTC,
|
||||
)
|
||||
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
|
||||
from transformers.pipelines.audio_utils import chunk_bytes_iter
|
||||
from transformers.pipelines.automatic_speech_recognition import apply_stride, chunk_iter
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
@@ -80,6 +81,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
outputs = speech_recognizer(audio)
|
||||
self.assertEqual(outputs, {"text": ANY(str)})
|
||||
|
||||
audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate}
|
||||
if speech_recognizer.type == "ctc":
|
||||
outputs = speech_recognizer(audio)
|
||||
self.assertEqual(outputs, {"text": ANY(str)})
|
||||
else:
|
||||
# Non CTC models cannot use striding.
|
||||
with self.assertRaises(ValueError):
|
||||
outputs = speech_recognizer(audio)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_pt_defaults(self):
|
||||
@@ -87,7 +97,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="facebook/s2t-small-mustc-en-fr-st",
|
||||
@@ -180,7 +189,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
@slow
|
||||
@require_torch
|
||||
def test_simple_wav2vec2(self):
|
||||
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
@@ -455,6 +463,28 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
# (85, 100)
|
||||
self.assertEqual(nested_simplify(input_values[:, 80:100]), nested_simplify(outs[4]["input_values"]))
|
||||
|
||||
@require_torch
def test_stride(self):
    """Check the text produced by a tiny random wav2vec2 pipeline for several `stride` values."""
    asr = pipeline(
        task="automatic-speech-recognition",
        model="hf-internal-testing/tiny-random-wav2vec2",
    )
    # Ten back-to-back copies of arange(1000), i.e. 10_000 float32 samples.
    samples = np.tile(np.arange(1000, dtype=np.float32), 10)

    # (stride, expected text) pairs, evaluated in this exact order.
    cases = (
        # No striding at all.
        ((0, 0), "OB XB B EB BB B EB B OB X"),
        # 0 effective ids: just take the middle one.
        ((5000, 5000), "B"),
        # Only 1 arange.
        ((0, 9000), "O"),
        # 2nd arange.
        ((1000, 8000), "B XB"),
    )
    for stride, expected_text in cases:
        result = asr({"raw": samples, "stride": stride, "sampling_rate": 16_000})
        self.assertEqual(result, {"text": expected_text})
|
||||
|
||||
|
||||
@require_torch
|
||||
class ApplyStrideTest(unittest.TestCase):
|
||||
@@ -488,3 +518,79 @@ class ApplyStrideTest(unittest.TestCase):
|
||||
tokens = torch.arange(10).long().reshape((2, 5))
|
||||
apply_stride(tokens, [(100, 20, 0), (60, 0, 20)])
|
||||
self.assertEqual([[1, 1, 2, 3, 4], [5, 6, 6, 6, 6]], tokens.tolist())
|
||||
|
||||
|
||||
def require_ffmpeg(test_case):
    """
    Decorator marking a test that requires FFmpeg.

    These tests are skipped when FFmpeg isn't installed.

    The check runs once, when the decorator is applied, by invoking `ffmpeg -h`.

    Args:
        test_case: The test function or test class to decorate.

    Returns:
        `test_case` unchanged when `ffmpeg -h` runs successfully, otherwise `test_case` wrapped by
        `unittest.skip("test requires ffmpeg")`.
    """
    import subprocess

    try:
        subprocess.check_output(["ffmpeg", "-h"], stderr=subprocess.DEVNULL)
    except (OSError, subprocess.SubprocessError):
        # OSError: the `ffmpeg` binary is missing or cannot be executed.
        # SubprocessError: the binary ran but exited with a non-zero status.
        return unittest.skip("test requires ffmpeg")(test_case)
    return test_case
|
||||
|
||||
|
||||
def bytes_iter(chunk_size, chunks):
    """
    Yield `chunks` consecutive byte strings of `chunk_size` bytes each.

    The byte values keep counting up across chunks: the first chunk holds `0 .. chunk_size - 1`, the second
    `chunk_size .. 2 * chunk_size - 1`, and so on.
    """
    start = 0
    for _ in range(chunks):
        stop = start + chunk_size
        yield bytes(range(start, stop))
        start = stop
|
||||
|
||||
|
||||
@require_ffmpeg
class AudioUtilsTest(unittest.TestCase):
    """Pins the exact items `chunk_bytes_iter` yields for small, hand-checkable byte streams."""

    def _assert_yields(self, chunks, expected):
        """Assert that `chunks` yields exactly the items of `expected`, in order, and is then exhausted."""
        stream = iter(chunks)
        for item in expected:
            self.assertEqual(next(stream), item)
        with self.assertRaises(StopIteration):
            next(stream)

    def test_chunk_bytes_iter_too_big(self):
        # Requested chunk length (10) exceeds the 6 bytes available: a single item is expected.
        self._assert_yields(
            chunk_bytes_iter(bytes_iter(chunk_size=3, chunks=2), 10, stride=(0, 0)),
            [
                {"raw": b"\x00\x01\x02\x03\x04\x05", "stride": (0, 0)},
            ],
        )

    def test_chunk_bytes_iter(self):
        self._assert_yields(
            chunk_bytes_iter(bytes_iter(chunk_size=3, chunks=2), 3, stride=(0, 0)),
            [
                {"raw": b"\x00\x01\x02", "stride": (0, 0)},
                {"raw": b"\x03\x04\x05", "stride": (0, 0)},
            ],
        )

    def test_chunk_bytes_iter_stride(self):
        self._assert_yields(
            chunk_bytes_iter(bytes_iter(chunk_size=3, chunks=2), 3, stride=(1, 1)),
            [
                {"raw": b"\x00\x01\x02", "stride": (0, 1)},
                {"raw": b"\x01\x02\x03", "stride": (1, 1)},
                {"raw": b"\x02\x03\x04", "stride": (1, 1)},
                # This is finished, but the chunk_bytes doesn't know it yet.
                {"raw": b"\x03\x04\x05", "stride": (1, 1)},
                {"raw": b"\x04\x05", "stride": (1, 0)},
            ],
        )

    def test_chunk_bytes_iter_stride_stream(self):
        # Two input chunks of 3 bytes, re-chunked to length 5.
        self._assert_yields(
            chunk_bytes_iter(bytes_iter(chunk_size=3, chunks=2), 5, stride=(1, 1), stream=True),
            [
                {"raw": b"\x00\x01\x02", "stride": (0, 0), "partial": True},
                {"raw": b"\x00\x01\x02\x03\x04", "stride": (0, 1), "partial": False},
                {"raw": b"\x03\x04\x05", "stride": (1, 0), "partial": False},
            ],
        )

        # Three input chunks of 3 bytes, re-chunked to length 5.
        self._assert_yields(
            chunk_bytes_iter(bytes_iter(chunk_size=3, chunks=3), 5, stride=(1, 1), stream=True),
            [
                {"raw": b"\x00\x01\x02", "stride": (0, 0), "partial": True},
                {"raw": b"\x00\x01\x02\x03\x04", "stride": (0, 1), "partial": False},
                {"raw": b"\x03\x04\x05\x06\x07", "stride": (1, 1), "partial": False},
                {"raw": b"\x06\x07\x08", "stride": (1, 0), "partial": False},
            ],
        )

        # Three input chunks of 3 bytes, requested length 10 (more than the 9 bytes available).
        self._assert_yields(
            chunk_bytes_iter(bytes_iter(chunk_size=3, chunks=3), 10, stride=(1, 1), stream=True),
            [
                {"raw": b"\x00\x01\x02", "stride": (0, 0), "partial": True},
                {"raw": b"\x00\x01\x02\x03\x04\x05", "stride": (0, 0), "partial": True},
                {"raw": b"\x00\x01\x02\x03\x04\x05\x06\x07\x08", "stride": (0, 0), "partial": True},
                {"raw": b"\x00\x01\x02\x03\x04\x05\x06\x07\x08", "stride": (0, 0), "partial": False},
            ],
        )
|
||||
|
||||
Reference in New Issue
Block a user