[Examples] Use Audio feature in speech classification (#14052)
* Update SEW integration test tolerance * Update audio classification * Update test * Remove torchaudio * Add dataset revision * Hub branch naming * Revert dataset revisions * Update datasets
This commit is contained in:
@@ -22,7 +22,6 @@ from typing import Optional
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
from datasets import DatasetDict, load_dataset
|
||||
|
||||
import transformers
|
||||
@@ -43,19 +42,9 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.11.0.dev0")
|
||||
check_min_version("4.12.0.dev0")
|
||||
|
||||
require_version("datasets>=1.12.1", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
||||
|
||||
|
||||
def load_audio(path: str, sample_rate: int = 16000):
|
||||
wav, sr = torchaudio.load(path)
|
||||
# convert multi-channel audio to mono
|
||||
wav = wav.mean(0)
|
||||
# standardize sample rate if it varies in the dataset
|
||||
resampler = torchaudio.transforms.Resample(sr, sample_rate)
|
||||
wav = resampler(wav)
|
||||
return wav
|
||||
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
||||
|
||||
|
||||
def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000):
|
||||
@@ -100,8 +89,8 @@ class DataTrainingArguments:
|
||||
},
|
||||
)
|
||||
audio_column_name: Optional[str] = field(
|
||||
default="file",
|
||||
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'file'"},
|
||||
default="audio",
|
||||
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
|
||||
)
|
||||
label_column_name: Optional[str] = field(
|
||||
default="label", metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'"}
|
||||
@@ -246,13 +235,18 @@ def main():
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
# `datasets` takes care of automatically loading and resampling the audio,
|
||||
# so we just need to set the correct target sampling rate.
|
||||
raw_datasets = raw_datasets.cast_column(
|
||||
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
||||
)
|
||||
|
||||
def train_transforms(batch):
|
||||
"""Apply train_transforms across a batch."""
|
||||
output_batch = {"input_values": []}
|
||||
for f in batch[data_args.audio_column_name]:
|
||||
wav = load_audio(f, sample_rate=feature_extractor.sampling_rate)
|
||||
for audio in batch[data_args.audio_column_name]:
|
||||
wav = random_subsample(
|
||||
wav, max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
|
||||
audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
|
||||
)
|
||||
output_batch["input_values"].append(wav)
|
||||
output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
|
||||
@@ -262,8 +256,8 @@ def main():
|
||||
def val_transforms(batch):
|
||||
"""Apply val_transforms across a batch."""
|
||||
output_batch = {"input_values": []}
|
||||
for f in batch[data_args.audio_column_name]:
|
||||
wav = load_audio(f, sample_rate=feature_extractor.sampling_rate)
|
||||
for audio in batch[data_args.audio_column_name]:
|
||||
wav = audio["array"]
|
||||
output_batch["input_values"].append(wav)
|
||||
output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
|
||||
|
||||
@@ -311,8 +305,6 @@ def main():
|
||||
model.freeze_feature_extractor()
|
||||
|
||||
if training_args.do_train:
|
||||
if "train" not in raw_datasets:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
if data_args.max_train_samples is not None:
|
||||
raw_datasets["train"] = (
|
||||
raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
|
||||
@@ -321,8 +313,6 @@ def main():
|
||||
raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
|
||||
|
||||
if training_args.do_eval:
|
||||
if "eval" not in raw_datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
if data_args.max_eval_samples is not None:
|
||||
raw_datasets["eval"] = (
|
||||
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
|
||||
|
||||
Reference in New Issue
Block a user