From 3c15fd199090a3d28b3d0e935e63d4d2e5451dcc Mon Sep 17 00:00:00 2001
From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Date: Mon, 4 Dec 2023 16:34:13 +0000
Subject: [PATCH] [Seamless v2] Add FE to auto mapping (#27829)

---
Review notes (text in this section is not part of the commit message):

* feature_extraction_auto.py: adds a "seamless_m4t_v2" key to
  FEATURE_EXTRACTOR_MAPPING_NAMES that points at the same
  "SeamlessM4TFeatureExtractor" already registered for "seamless_m4t".
* test_pipelines_automatic_speech_recognition.py: adds test_seamless_v2,
  which builds an "automatic-speech-recognition" pipeline for
  "facebook/seamless-m4t-v2-large", runs the first validation sample of
  "hf-internal-testing/librispeech_asr_dummy" with tgt_lang="eng", and
  compares result["text"] against a fixed transcript.
* NOTE(review): the new test pins device="cuda:0" but is gated only by
  @require_torch and @slow -- confirm the slow runners always have a GPU.
* NOTE(review): this patch was re-wrapped from a copy whose line breaks
  and indentation had been lost; indentation and blank context lines were
  restored from the hunk line counts -- verify with `git apply --check`.

 .../models/auto/feature_extraction_auto.py      |  1 +
 ...st_pipelines_automatic_speech_recognition.py | 17 +++++++++++++++++
 2 files changed, 18 insertions(+)

diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index 395875dfa1..457217566e 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -78,6 +78,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
         ("regnet", "ConvNextFeatureExtractor"),
         ("resnet", "ConvNextFeatureExtractor"),
         ("seamless_m4t", "SeamlessM4TFeatureExtractor"),
+        ("seamless_m4t_v2", "SeamlessM4TFeatureExtractor"),
         ("segformer", "SegformerFeatureExtractor"),
         ("sew", "Wav2Vec2FeatureExtractor"),
         ("sew-d", "Wav2Vec2FeatureExtractor"),
diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index 3276042daf..2ccaf71255 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -1115,6 +1115,23 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
 
         assert result == EXPECTED_RESULT
 
+    @require_torch
+    @slow
+    def test_seamless_v2(self):
+        pipe = pipeline(
+            "automatic-speech-recognition",
+            model="facebook/seamless-m4t-v2-large",
+            device="cuda:0",
+        )
+
+        dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        sample = dataset[0]["audio"]
+
+        result = pipe(sample, generate_kwargs={"tgt_lang": "eng"})
+        EXPECTED_RESULT = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
+
+        assert result["text"] == EXPECTED_RESULT
+
     @require_torch
     @slow
     def test_chunking_and_timestamps(self):