Allow dict input for audio classification pipeline (#23445)
* Allow dict input for audio classification pipeline
* make style
* Empty commit to trigger CI
* Empty commit to trigger CI
* check for torchaudio
* add pip instructions

  Co-authored-by: Sylvain <sylvain.gugger@gmail.com>
* Update src/transformers/pipelines/audio_classification.py

  Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
* asr -> audio class
* asr -> audio class

---------

Co-authored-by: Sylvain <sylvain.gugger@gmail.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@@ -17,7 +17,7 @@ from typing import Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ..utils import add_end_docstrings, is_torch_available, logging
|
from ..utils import add_end_docstrings, is_torch_available, is_torchaudio_available, logging
|
||||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||||
|
|
||||||
|
|
||||||
@@ -110,12 +110,18 @@ class AudioClassificationPipeline(Pipeline):
|
|||||||
information.
|
information.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs (`np.ndarray` or `bytes` or `str`):
|
inputs (`np.ndarray` or `bytes` or `str` or `dict`):
|
||||||
The inputs is either a raw waveform (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
|
The inputs can be one of:
|
||||||
at the correct sampling rate (no further check will be done) or a `str` that is the filename of the
|
- `str` that is the filename of the audio file; the file will be read at the correct sampling rate
|
||||||
audio file, the file will be read at the correct sampling rate to get the waveform using *ffmpeg*. This
|
to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
|
||||||
requires *ffmpeg* to be installed on the system. If *inputs* is `bytes` it is supposed to be the
|
- `bytes`: it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
|
||||||
content of an audio file and is interpreted by *ffmpeg* in the same way.
|
same way.
|
||||||
|
- (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
|
||||||
|
Raw audio at the correct sampling rate (no further check will be done)
|
||||||
|
- `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
|
||||||
|
pipeline do the resampling. The dict must be in either the format `{"sampling_rate": int,
|
||||||
|
"raw": np.array}`, or `{"sampling_rate": int, "array": np.array}`, where the key `"raw"` or
|
||||||
|
`"array"` is used to denote the raw audio waveform.
|
||||||
top_k (`int`, *optional*, defaults to None):
|
top_k (`int`, *optional*, defaults to None):
|
||||||
The number of top labels that will be returned by the pipeline. If the provided number is `None` or
|
The number of top labels that will be returned by the pipeline. If the provided number is `None` or
|
||||||
higher than the number of labels available in the model configuration, it will default to the number of
|
higher than the number of labels available in the model configuration, it will default to the number of
|
||||||
@@ -151,10 +157,42 @@ class AudioClassificationPipeline(Pipeline):
|
|||||||
if isinstance(inputs, bytes):
|
if isinstance(inputs, bytes):
|
||||||
inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
|
inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
|
||||||
|
|
||||||
|
if isinstance(inputs, dict):
|
||||||
|
# Accepting `"array"` which is the key defined in `datasets` for
|
||||||
|
# better integration
|
||||||
|
if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
|
||||||
|
raise ValueError(
|
||||||
|
"When passing a dictionary to AudioClassificationPipeline, the dict needs to contain a "
|
||||||
|
'"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
|
||||||
|
"containing the sampling_rate associated with that array"
|
||||||
|
)
|
||||||
|
|
||||||
|
_inputs = inputs.pop("raw", None)
|
||||||
|
if _inputs is None:
|
||||||
|
# Remove path which will not be used from `datasets`.
|
||||||
|
inputs.pop("path", None)
|
||||||
|
_inputs = inputs.pop("array", None)
|
||||||
|
in_sampling_rate = inputs.pop("sampling_rate")
|
||||||
|
inputs = _inputs
|
||||||
|
if in_sampling_rate != self.feature_extractor.sampling_rate:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_torchaudio_available():
|
||||||
|
from torchaudio import functional as F
|
||||||
|
else:
|
||||||
|
raise ImportError(
|
||||||
|
"torchaudio is required to resample audio samples in AudioClassificationPipeline. "
|
||||||
|
"The torchaudio package can be installed through: `pip install torchaudio`."
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = F.resample(
|
||||||
|
torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
|
||||||
|
).numpy()
|
||||||
|
|
||||||
if not isinstance(inputs, np.ndarray):
|
if not isinstance(inputs, np.ndarray):
|
||||||
raise ValueError("We expect a numpy ndarray as input")
|
raise ValueError("We expect a numpy ndarray as input")
|
||||||
if len(inputs.shape) != 1:
|
if len(inputs.shape) != 1:
|
||||||
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
|
raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")
|
||||||
|
|
||||||
processed = self.feature_extractor(
|
processed = self.feature_extractor(
|
||||||
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
|
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
|
||||||
|
|||||||
@@ -103,6 +103,10 @@ class AudioClassificationPipelineTests(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
|
self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
|
||||||
|
|
||||||
|
audio_dict = {"array": np.ones((8000,)), "sampling_rate": audio_classifier.feature_extractor.sampling_rate}
|
||||||
|
output = audio_classifier(audio_dict, top_k=4)
|
||||||
|
self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_large_model_pt(self):
|
def test_large_model_pt(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user