From dbaf49203eb3e8be34e51e939c7d4884abdb2d6d Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Wed, 20 Oct 2021 12:22:43 +0300 Subject: [PATCH] [Examples] Use Audio feature in speech classification (#14052) * Update SEW integration test tolerance * Update audio classification * Update test * Remove torchaudio * Add dataset revision * Hub branch naming * Revert dataset revisions * Update datasets --- .../pytorch/audio-classification/README.md | 2 +- .../audio-classification/requirements.txt | 3 +- .../run_audio_classification.py | 38 +++++++------------ .../run_wav2vec2_pretraining_no_trainer.py | 6 +-- .../run_speech_recognition_ctc.py | 4 +- examples/pytorch/test_examples.py | 2 +- 6 files changed, 23 insertions(+), 32 deletions(-) diff --git a/examples/pytorch/audio-classification/README.md b/examples/pytorch/audio-classification/README.md index 4564d40ef0..d6fcd1a186 100644 --- a/examples/pytorch/audio-classification/README.md +++ b/examples/pytorch/audio-classification/README.md @@ -68,7 +68,7 @@ The following command shows how to fine-tune [wav2vec2-base](https://huggingface ```bash python run_audio_classification.py \ --model_name_or_path facebook/wav2vec2-base \ - --dataset_name anton-l/common_language \ + --dataset_name common_language \ --audio_column_name path \ --label_column_name language \ --output_dir wav2vec2-base-lang-id \ diff --git a/examples/pytorch/audio-classification/requirements.txt b/examples/pytorch/audio-classification/requirements.txt index 3c5a1fad9a..6ae3f11c5c 100644 --- a/examples/pytorch/audio-classification/requirements.txt +++ b/examples/pytorch/audio-classification/requirements.txt @@ -1,3 +1,4 @@ -datasets>=1.12.0 +datasets>=1.14.0 +librosa torchaudio torch>=1.6 \ No newline at end of file diff --git a/examples/pytorch/audio-classification/run_audio_classification.py b/examples/pytorch/audio-classification/run_audio_classification.py index 89b421fdf5..a4dd924408 100644 --- a/examples/pytorch/audio-classification/run_audio_classification.py +++ b/examples/pytorch/audio-classification/run_audio_classification.py @@ -22,7 +22,6 @@ from typing import Optional import datasets import numpy as np -import torchaudio from datasets import DatasetDict, load_dataset import transformers @@ -43,19 +42,9 @@ from transformers.utils.versions import require_version logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -check_min_version("4.11.0.dev0") +check_min_version("4.12.0.dev0") -require_version("datasets>=1.12.1", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt") - - -def load_audio(path: str, sample_rate: int = 16000): - wav, sr = torchaudio.load(path) - # convert multi-channel audio to mono - wav = wav.mean(0) - # standardize sample rate if it varies in the dataset - resampler = torchaudio.transforms.Resample(sr, sample_rate) - wav = resampler(wav) - return wav +require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt") def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000): @@ -100,8 +89,8 @@ class DataTrainingArguments: }, ) audio_column_name: Optional[str] = field( - default="file", - metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'file'"}, + default="audio", + metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, ) label_column_name: Optional[str] = field( default="label", metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'"} @@ -246,13 +235,18 @@ def main(): use_auth_token=True if model_args.use_auth_token else None, ) + # `datasets` takes care of automatically loading and resampling the audio, + # so we just need to set the correct target sampling rate. + raw_datasets = raw_datasets.cast_column( + data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) + ) + def train_transforms(batch): """Apply train_transforms across a batch.""" output_batch = {"input_values": []} - for f in batch[data_args.audio_column_name]: - wav = load_audio(f, sample_rate=feature_extractor.sampling_rate) + for audio in batch[data_args.audio_column_name]: wav = random_subsample( - wav, max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate + audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate ) output_batch["input_values"].append(wav) output_batch["labels"] = [label for label in batch[data_args.label_column_name]] @@ -262,8 +256,8 @@ def main(): def val_transforms(batch): """Apply val_transforms across a batch.""" output_batch = {"input_values": []} - for f in batch[data_args.audio_column_name]: - wav = load_audio(f, sample_rate=feature_extractor.sampling_rate) + for audio in batch[data_args.audio_column_name]: + wav = audio["array"] output_batch["input_values"].append(wav) output_batch["labels"] = [label for label in batch[data_args.label_column_name]] @@ -311,8 +305,6 @@ def main(): model.freeze_feature_extractor() if training_args.do_train: - if "train" not in raw_datasets: - raise ValueError("--do_train requires a train dataset") if data_args.max_train_samples is not None: raw_datasets["train"] = ( raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples)) @@ -321,8 +313,6 @@ def main(): raw_datasets["train"].set_transform(train_transforms, output_all_columns=False) if training_args.do_eval: - if "eval" not in raw_datasets: - raise ValueError("--do_eval requires a validation dataset") if data_args.max_eval_samples is not None: raw_datasets["eval"] = ( raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples)) diff --git a/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py b/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py index e56a3dcb3d..755581b42c 100755 --- a/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py +++ b/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py @@ -113,7 +113,7 @@ def parse_args(): "--audio_column_name", type=str, default="audio", - help="Column in the dataset that contains speech file path. Defaults to 'file'", + help="Column in the dataset that contains speech file path. Defaults to 'audio'", ) parser.add_argument( "--model_name_or_path", @@ -431,9 +431,9 @@ def main(): # via the `feature_extractor` feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path) - # make sure that dataset decodes audio with correct samlping rate + # make sure that dataset decodes audio with correct sampling rate raw_datasets = raw_datasets.cast_column( - "audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) + args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) ) # only normalized-inputs-training is supported diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py index 96881b09ce..e2c2d90957 100755 --- a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py +++ b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py @@ -454,9 +454,9 @@ def main(): # so that we just need to set the correct target sampling rate and normalize the input # via the `feature_extractor` - # make sure that dataset decodes audio with correct samlping rate + # make sure that dataset decodes audio with correct sampling rate raw_datasets = raw_datasets.cast_column( - "audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) + data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) ) # derive max & min input length for sample rate & max duration diff --git a/examples/pytorch/test_examples.py b/examples/pytorch/test_examples.py index 4ef574f90a..d072a89f09 100644 --- a/examples/pytorch/test_examples.py +++ b/examples/pytorch/test_examples.py @@ -428,7 +428,7 @@ class ExamplesTests(TestCasePlus): --dataset_config_name ks --train_split_name test --eval_split_name test - --audio_column_name file + --audio_column_name audio --label_column_name label --do_train --do_eval