Add the AudioClassificationPipeline (#13342)
* Add the audio classification pipeline * Remove autoconfig exception * Mark ffmpeg test as slow * Rearrange pipeline tests * Add small test * Replace asserts with ValueError
This commit is contained in:
@@ -23,6 +23,7 @@ There are two categories of pipeline abstractions to be aware about:
|
||||
- The :func:`~transformers.pipeline` which is the most powerful object encapsulating all other pipelines.
|
||||
- The other task-specific pipelines:
|
||||
|
||||
- :class:`~transformers.AudioClassificationPipeline`
|
||||
- :class:`~transformers.AutomaticSpeechRecognitionPipeline`
|
||||
- :class:`~transformers.ConversationalPipeline`
|
||||
- :class:`~transformers.FeatureExtractionPipeline`
|
||||
@@ -30,13 +31,13 @@ There are two categories of pipeline abstractions to be aware about:
|
||||
- :class:`~transformers.ImageClassificationPipeline`
|
||||
- :class:`~transformers.QuestionAnsweringPipeline`
|
||||
- :class:`~transformers.SummarizationPipeline`
|
||||
- :class:`~transformers.TableQuestionAnsweringPipeline`
|
||||
- :class:`~transformers.TextClassificationPipeline`
|
||||
- :class:`~transformers.TextGenerationPipeline`
|
||||
- :class:`~transformers.Text2TextGenerationPipeline`
|
||||
- :class:`~transformers.TokenClassificationPipeline`
|
||||
- :class:`~transformers.TranslationPipeline`
|
||||
- :class:`~transformers.ZeroShotClassificationPipeline`
|
||||
- :class:`~transformers.Text2TextGenerationPipeline`
|
||||
- :class:`~transformers.TableQuestionAnsweringPipeline`
|
||||
|
||||
The pipeline abstraction
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@@ -50,6 +51,13 @@ pipeline but requires an additional argument which is the `task`.
|
||||
The task specific pipelines
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
AudioClassificationPipeline
|
||||
=======================================================================================================================
|
||||
|
||||
.. autoclass:: transformers.AudioClassificationPipeline
|
||||
:special-members: __call__
|
||||
:members:
|
||||
|
||||
AutomaticSpeechRecognitionPipeline
|
||||
=======================================================================================================================
|
||||
|
||||
|
||||
@@ -135,6 +135,13 @@ AutoModelForImageClassification
|
||||
:members:
|
||||
|
||||
|
||||
AutoModelForAudioClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.AutoModelForAudioClassification
|
||||
:members:
|
||||
|
||||
|
||||
TFAutoModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -277,6 +277,7 @@ _import_structure = {
|
||||
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
|
||||
"models.xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"],
|
||||
"pipelines": [
|
||||
"AudioClassificationPipeline",
|
||||
"AutomaticSpeechRecognitionPipeline",
|
||||
"Conversation",
|
||||
"ConversationalPipeline",
|
||||
@@ -527,6 +528,7 @@ if is_torch_available():
|
||||
)
|
||||
_import_structure["models.auto"].extend(
|
||||
[
|
||||
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||
@@ -542,6 +544,7 @@ if is_torch_available():
|
||||
"MODEL_MAPPING",
|
||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||
"AutoModel",
|
||||
"AutoModelForAudioClassification",
|
||||
"AutoModelForCausalLM",
|
||||
"AutoModelForImageClassification",
|
||||
"AutoModelForMaskedLM",
|
||||
@@ -2040,6 +2043,7 @@ if TYPE_CHECKING:
|
||||
|
||||
# Pipelines
|
||||
from .pipelines import (
|
||||
AudioClassificationPipeline,
|
||||
AutomaticSpeechRecognitionPipeline,
|
||||
Conversation,
|
||||
ConversationalPipeline,
|
||||
@@ -2248,6 +2252,7 @@ if TYPE_CHECKING:
|
||||
load_tf_weights_in_albert,
|
||||
)
|
||||
from .models.auto import (
|
||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
@@ -2263,6 +2268,7 @@ if TYPE_CHECKING:
|
||||
MODEL_MAPPING,
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
AutoModel,
|
||||
AutoModelForAudioClassification,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForMaskedLM,
|
||||
|
||||
@@ -42,6 +42,7 @@ from .file_utils import (
|
||||
is_torch_available,
|
||||
)
|
||||
from .models.auto.modeling_auto import (
|
||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
||||
@@ -66,6 +67,7 @@ TASK_MAPPING = {
|
||||
"text-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
||||
"table-question-answering": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
|
||||
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||
}
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -30,6 +30,7 @@ _import_structure = {
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_auto"] = [
|
||||
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||
@@ -45,6 +46,7 @@ if is_torch_available():
|
||||
"MODEL_MAPPING",
|
||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||
"AutoModel",
|
||||
"AutoModelForAudioClassification",
|
||||
"AutoModelForCausalLM",
|
||||
"AutoModelForImageClassification",
|
||||
"AutoModelForMaskedLM",
|
||||
@@ -119,6 +121,7 @@ if TYPE_CHECKING:
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_auto import (
|
||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
@@ -134,6 +137,7 @@ if TYPE_CHECKING:
|
||||
MODEL_MAPPING,
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
AutoModel,
|
||||
AutoModelForAudioClassification,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForMaskedLM,
|
||||
|
||||
@@ -444,6 +444,14 @@ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Audio Classification mapping
|
||||
("wav2vec2", "Wav2Vec2ForSequenceClassification"),
|
||||
("hubert", "HubertForSequenceClassification"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
||||
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
|
||||
@@ -472,6 +480,9 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL
|
||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
|
||||
)
|
||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
|
||||
)
|
||||
|
||||
|
||||
class AutoModel(_BaseAutoModelClass):
|
||||
@@ -576,6 +587,13 @@ class AutoModelForImageClassification(_BaseAutoModelClass):
|
||||
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
|
||||
|
||||
|
||||
class AutoModelForAudioClassification(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
|
||||
|
||||
|
||||
AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")
|
||||
|
||||
|
||||
class AutoModelWithLMHead(_AutoModelWithLMHead):
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
|
||||
@@ -1616,7 +1616,6 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
# End copy
|
||||
|
||||
if self.config.use_weighted_layer_sum:
|
||||
hidden_states = outputs[2]
|
||||
|
||||
@@ -27,6 +27,7 @@ from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, Aut
|
||||
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
from ..utils import logging
|
||||
from .audio_classification import AudioClassificationPipeline
|
||||
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
|
||||
from .base import (
|
||||
ArgumentHandler,
|
||||
@@ -86,6 +87,7 @@ if is_torch_available():
|
||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
AutoModel,
|
||||
AutoModelForAudioClassification,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForMaskedLM,
|
||||
@@ -108,6 +110,12 @@ TASK_ALIASES = {
|
||||
"ner": "token-classification",
|
||||
}
|
||||
SUPPORTED_TASKS = {
|
||||
"audio-classification": {
|
||||
"impl": AudioClassificationPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "superb/wav2vec2-base-superb-ks"}},
|
||||
},
|
||||
"automatic-speech-recognition": {
|
||||
"impl": AutomaticSpeechRecognitionPipeline,
|
||||
"tf": (),
|
||||
|
||||
160
src/transformers/pipelines/audio_classification.py
Normal file
160
src/transformers/pipelines/audio_classification.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import subprocess
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
||||
from ..file_utils import add_end_docstrings, is_torch_available
|
||||
from ..utils import logging
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modeling_tf_utils import TFPreTrainedModel
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
|
||||
"""
|
||||
Helper function to read an audio file through ffmpeg.
|
||||
"""
|
||||
ar = f"{sampling_rate}"
|
||||
ac = "1"
|
||||
format_for_conversion = "f32le"
|
||||
ffmpeg_command = [
|
||||
"ffmpeg",
|
||||
"-i",
|
||||
"pipe:0",
|
||||
"-ac",
|
||||
ac,
|
||||
"-ar",
|
||||
ar,
|
||||
"-f",
|
||||
format_for_conversion,
|
||||
"-hide_banner",
|
||||
"-loglevel",
|
||||
"quiet",
|
||||
"pipe:1",
|
||||
]
|
||||
|
||||
try:
|
||||
ffmpeg_process = subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
|
||||
except FileNotFoundError:
|
||||
raise ValueError("ffmpeg was not found but is required to load audio files from filename")
|
||||
output_stream = ffmpeg_process.communicate(bpayload)
|
||||
out_bytes = output_stream[0]
|
||||
|
||||
audio = np.frombuffer(out_bytes, np.float32)
|
||||
if audio.shape[0] == 0:
|
||||
raise ValueError("Malformed soundfile")
|
||||
return audio
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class AudioClassificationPipeline(Pipeline):
|
||||
"""
|
||||
Audio classification pipeline using any :obj:`AutoModelForAudioClassification`. This pipeline predicts the class of
|
||||
a raw waveform or an audio file. In case of an audio file, ffmpeg should be installed to support multiple audio
|
||||
formats.
|
||||
|
||||
This pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task identifier:
|
||||
:obj:`"audio-classification"`.
|
||||
|
||||
See the list of available models on `huggingface.co/models
|
||||
<https://huggingface.co/models?filter=audio-classification>`__.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union["PreTrainedModel", "TFPreTrainedModel"],
|
||||
feature_extractor: PreTrainedFeatureExtractor,
|
||||
framework: Optional[str] = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(model, feature_extractor=feature_extractor, framework=framework, **kwargs)
|
||||
|
||||
if self.framework != "pt":
|
||||
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
|
||||
|
||||
self.check_model_type(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[np.ndarray, bytes, str],
|
||||
top_k: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Classify the sequence(s) given as inputs. See the :obj:`~transformers.AutomaticSpeechRecognitionPipeline`
|
||||
documentation for more information.
|
||||
|
||||
Args:
|
||||
inputs (:obj:`np.ndarray` or :obj:`bytes` or :obj:`str`):
|
||||
The inputs is either a raw waveform (:obj:`np.ndarray` of shape (n, ) of type :obj:`np.float32` or
|
||||
:obj:`np.float64`) at the correct sampling rate (no further check will be done) or a :obj:`str` that is
|
||||
the filename of the audio file, the file will be read at the correct sampling rate to get the waveform
|
||||
using `ffmpeg`. This requires `ffmpeg` to be installed on the system. If `inputs` is :obj:`bytes` it is
|
||||
supposed to be the content of an audio file and is interpreted by `ffmpeg` in the same way.
|
||||
top_k (:obj:`int`, `optional`, defaults to None):
|
||||
The number of top labels that will be returned by the pipeline. If the provided number is `None` or
|
||||
higher than the number of labels available in the model configuration, it will default to the number of
|
||||
labels.
|
||||
|
||||
Return:
|
||||
A list of :obj:`dict` with the following keys:
|
||||
|
||||
- **label** (:obj:`str`) -- The label predicted.
|
||||
- **score** (:obj:`float`) -- The corresponding probability.
|
||||
"""
|
||||
if isinstance(inputs, str):
|
||||
with open(inputs, "rb") as f:
|
||||
inputs = f.read()
|
||||
|
||||
if isinstance(inputs, bytes):
|
||||
inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
|
||||
|
||||
if not isinstance(inputs, np.ndarray):
|
||||
raise ValueError("We expect a numpy ndarray as input")
|
||||
if len(inputs.shape) != 1:
|
||||
raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")
|
||||
|
||||
if top_k is None or top_k > self.model.config.num_labels:
|
||||
top_k = self.model.config.num_labels
|
||||
|
||||
processed = self.feature_extractor(
|
||||
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
|
||||
)
|
||||
processed = self.ensure_tensor_on_device(**processed)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**processed)
|
||||
|
||||
probs = outputs.logits[0].softmax(-1)
|
||||
scores, ids = probs.topk(top_k)
|
||||
|
||||
scores = scores.tolist()
|
||||
ids = ids.tolist()
|
||||
|
||||
labels = [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
|
||||
|
||||
return labels
|
||||
@@ -307,6 +307,9 @@ def load_tf_weights_in_albert(*args, **kwargs):
|
||||
requires_backends(load_tf_weights_in_albert, ["torch"])
|
||||
|
||||
|
||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING = None
|
||||
|
||||
|
||||
@@ -358,6 +361,15 @@ class AutoModel:
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForAudioClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForCausalLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
120
tests/test_pipelines_audio_classification.py
Normal file
120
tests/test_pipelines_audio_classification.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, PreTrainedTokenizer
|
||||
from transformers.pipelines import AudioClassificationPipeline, pipeline
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
nested_simplify,
|
||||
require_datasets,
|
||||
require_tf,
|
||||
require_torch,
|
||||
slow,
|
||||
)
|
||||
|
||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
@require_torch
|
||||
class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||
model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
|
||||
|
||||
@require_datasets
|
||||
@slow
|
||||
def run_pipeline_test(self, model, tokenizer, feature_extractor):
|
||||
import datasets
|
||||
|
||||
audio_classifier = AudioClassificationPipeline(model=model, feature_extractor=feature_extractor)
|
||||
|
||||
# test with a raw waveform
|
||||
audio = np.zeros((34000,))
|
||||
output = audio_classifier(audio)
|
||||
# by default a model is initialized with num_labels=2
|
||||
self.assertEqual(
|
||||
output,
|
||||
[
|
||||
{"score": ANY(float), "label": ANY(str)},
|
||||
{"score": ANY(float), "label": ANY(str)},
|
||||
],
|
||||
)
|
||||
output = audio_classifier(audio, top_k=1)
|
||||
self.assertEqual(
|
||||
output,
|
||||
[
|
||||
{"score": ANY(float), "label": ANY(str)},
|
||||
],
|
||||
)
|
||||
|
||||
# test with a local file
|
||||
dataset = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||
filename = dataset[0]["file"]
|
||||
output = audio_classifier(filename)
|
||||
self.assertEqual(
|
||||
output,
|
||||
[
|
||||
{"score": ANY(float), "label": ANY(str)},
|
||||
{"score": ANY(float), "label": ANY(str)},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
model = "anton-l/wav2vec2-random-tiny-classifier"
|
||||
tokenizer = PreTrainedTokenizer()
|
||||
audio_classifier = pipeline("audio-classification", model=model, tokenizer=tokenizer)
|
||||
|
||||
audio = np.ones((8000,))
|
||||
output = audio_classifier(audio, top_k=4)
|
||||
self.assertEqual(
|
||||
nested_simplify(output, decimals=4),
|
||||
[
|
||||
{"score": 0.0843, "label": "on"},
|
||||
{"score": 0.0840, "label": "left"},
|
||||
{"score": 0.0837, "label": "off"},
|
||||
{"score": 0.0835, "label": "yes"},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@require_datasets
|
||||
@slow
|
||||
def test_large_model_pt(self):
|
||||
import datasets
|
||||
|
||||
model = "superb/wav2vec2-base-superb-ks"
|
||||
tokenizer = PreTrainedTokenizer()
|
||||
audio_classifier = pipeline("audio-classification", model=model, tokenizer=tokenizer)
|
||||
dataset = datasets.load_dataset("anton-l/superb_dummy", "ks", split="test")
|
||||
|
||||
audio = np.array(dataset[3]["speech"], dtype=np.float32)
|
||||
output = audio_classifier(audio, top_k=4)
|
||||
self.assertEqual(
|
||||
nested_simplify(output, decimals=4),
|
||||
[
|
||||
{"score": 0.9809, "label": "go"},
|
||||
{"score": 0.0073, "label": "up"},
|
||||
{"score": 0.0064, "label": "_unknown_"},
|
||||
{"score": 0.0015, "label": "down"},
|
||||
],
|
||||
)
|
||||
|
||||
@require_tf
|
||||
@unittest.skip("Audio classification is not implemented for TF")
|
||||
def test_small_model_tf(self):
|
||||
pass
|
||||
@@ -122,8 +122,6 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
"TFRagTokenForGeneration",
|
||||
"Wav2Vec2ForCTC",
|
||||
"HubertForCTC",
|
||||
"Wav2Vec2ForSequenceClassification",
|
||||
"HubertForSequenceClassification",
|
||||
"XLMForQuestionAnswering",
|
||||
"XLNetForQuestionAnswering",
|
||||
"SeparableConv1D",
|
||||
|
||||
Reference in New Issue
Block a user