[Examples] Use Audio feature in speech classification (#14052)
* Update SEW integration test tolerance * Update audio classification * Update test * Remove torchaudio * Add dataset revision * Hub branch naming * Revert dataset revisions * Update datasets
This commit is contained in:
@@ -68,7 +68,7 @@ The following command shows how to fine-tune [wav2vec2-base](https://huggingface
|
|||||||
```bash
|
```bash
|
||||||
python run_audio_classification.py \
|
python run_audio_classification.py \
|
||||||
--model_name_or_path facebook/wav2vec2-base \
|
--model_name_or_path facebook/wav2vec2-base \
|
||||||
--dataset_name anton-l/common_language \
|
--dataset_name common_language \
|
||||||
--audio_column_name path \
|
--audio_column_name path \
|
||||||
--label_column_name language \
|
--label_column_name language \
|
||||||
--output_dir wav2vec2-base-lang-id \
|
--output_dir wav2vec2-base-lang-id \
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
datasets>=1.12.0
|
datasets>=1.14.0
|
||||||
|
librosa
|
||||||
torchaudio
|
torchaudio
|
||||||
torch>=1.6
|
torch>=1.6
|
||||||
@@ -22,7 +22,6 @@ from typing import Optional
|
|||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torchaudio
|
|
||||||
from datasets import DatasetDict, load_dataset
|
from datasets import DatasetDict, load_dataset
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
@@ -43,19 +42,9 @@ from transformers.utils.versions import require_version
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||||
check_min_version("4.11.0.dev0")
|
check_min_version("4.12.0.dev0")
|
||||||
|
|
||||||
require_version("datasets>=1.12.1", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
||||||
|
|
||||||
|
|
||||||
def load_audio(path: str, sample_rate: int = 16000):
|
|
||||||
wav, sr = torchaudio.load(path)
|
|
||||||
# convert multi-channel audio to mono
|
|
||||||
wav = wav.mean(0)
|
|
||||||
# standardize sample rate if it varies in the dataset
|
|
||||||
resampler = torchaudio.transforms.Resample(sr, sample_rate)
|
|
||||||
wav = resampler(wav)
|
|
||||||
return wav
|
|
||||||
|
|
||||||
|
|
||||||
def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000):
|
def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000):
|
||||||
@@ -100,8 +89,8 @@ class DataTrainingArguments:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
audio_column_name: Optional[str] = field(
|
audio_column_name: Optional[str] = field(
|
||||||
default="file",
|
default="audio",
|
||||||
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'file'"},
|
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
|
||||||
)
|
)
|
||||||
label_column_name: Optional[str] = field(
|
label_column_name: Optional[str] = field(
|
||||||
default="label", metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'"}
|
default="label", metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'"}
|
||||||
@@ -246,13 +235,18 @@ def main():
|
|||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# `datasets` takes care of automatically loading and resampling the audio,
|
||||||
|
# so we just need to set the correct target sampling rate.
|
||||||
|
raw_datasets = raw_datasets.cast_column(
|
||||||
|
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
||||||
|
)
|
||||||
|
|
||||||
def train_transforms(batch):
|
def train_transforms(batch):
|
||||||
"""Apply train_transforms across a batch."""
|
"""Apply train_transforms across a batch."""
|
||||||
output_batch = {"input_values": []}
|
output_batch = {"input_values": []}
|
||||||
for f in batch[data_args.audio_column_name]:
|
for audio in batch[data_args.audio_column_name]:
|
||||||
wav = load_audio(f, sample_rate=feature_extractor.sampling_rate)
|
|
||||||
wav = random_subsample(
|
wav = random_subsample(
|
||||||
wav, max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
|
audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
|
||||||
)
|
)
|
||||||
output_batch["input_values"].append(wav)
|
output_batch["input_values"].append(wav)
|
||||||
output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
|
output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
|
||||||
@@ -262,8 +256,8 @@ def main():
|
|||||||
def val_transforms(batch):
|
def val_transforms(batch):
|
||||||
"""Apply val_transforms across a batch."""
|
"""Apply val_transforms across a batch."""
|
||||||
output_batch = {"input_values": []}
|
output_batch = {"input_values": []}
|
||||||
for f in batch[data_args.audio_column_name]:
|
for audio in batch[data_args.audio_column_name]:
|
||||||
wav = load_audio(f, sample_rate=feature_extractor.sampling_rate)
|
wav = audio["array"]
|
||||||
output_batch["input_values"].append(wav)
|
output_batch["input_values"].append(wav)
|
||||||
output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
|
output_batch["labels"] = [label for label in batch[data_args.label_column_name]]
|
||||||
|
|
||||||
@@ -311,8 +305,6 @@ def main():
|
|||||||
model.freeze_feature_extractor()
|
model.freeze_feature_extractor()
|
||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if "train" not in raw_datasets:
|
|
||||||
raise ValueError("--do_train requires a train dataset")
|
|
||||||
if data_args.max_train_samples is not None:
|
if data_args.max_train_samples is not None:
|
||||||
raw_datasets["train"] = (
|
raw_datasets["train"] = (
|
||||||
raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
|
raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
|
||||||
@@ -321,8 +313,6 @@ def main():
|
|||||||
raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
|
raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
|
||||||
|
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
if "eval" not in raw_datasets:
|
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
|
||||||
if data_args.max_eval_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
raw_datasets["eval"] = (
|
raw_datasets["eval"] = (
|
||||||
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
|
raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ def parse_args():
|
|||||||
"--audio_column_name",
|
"--audio_column_name",
|
||||||
type=str,
|
type=str,
|
||||||
default="audio",
|
default="audio",
|
||||||
help="Column in the dataset that contains speech file path. Defaults to 'file'",
|
help="Column in the dataset that contains speech file path. Defaults to 'audio'",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name_or_path",
|
"--model_name_or_path",
|
||||||
@@ -431,9 +431,9 @@ def main():
|
|||||||
# via the `feature_extractor`
|
# via the `feature_extractor`
|
||||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path)
|
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path)
|
||||||
|
|
||||||
# make sure that dataset decodes audio with correct samlping rate
|
# make sure that dataset decodes audio with correct sampling rate
|
||||||
raw_datasets = raw_datasets.cast_column(
|
raw_datasets = raw_datasets.cast_column(
|
||||||
"audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
||||||
)
|
)
|
||||||
|
|
||||||
# only normalized-inputs-training is supported
|
# only normalized-inputs-training is supported
|
||||||
|
|||||||
@@ -454,9 +454,9 @@ def main():
|
|||||||
# so that we just need to set the correct target sampling rate and normalize the input
|
# so that we just need to set the correct target sampling rate and normalize the input
|
||||||
# via the `feature_extractor`
|
# via the `feature_extractor`
|
||||||
|
|
||||||
# make sure that dataset decodes audio with correct samlping rate
|
# make sure that dataset decodes audio with correct sampling rate
|
||||||
raw_datasets = raw_datasets.cast_column(
|
raw_datasets = raw_datasets.cast_column(
|
||||||
"audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
||||||
)
|
)
|
||||||
|
|
||||||
# derive max & min input length for sample rate & max duration
|
# derive max & min input length for sample rate & max duration
|
||||||
|
|||||||
@@ -428,7 +428,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--dataset_config_name ks
|
--dataset_config_name ks
|
||||||
--train_split_name test
|
--train_split_name test
|
||||||
--eval_split_name test
|
--eval_split_name test
|
||||||
--audio_column_name file
|
--audio_column_name audio
|
||||||
--label_column_name label
|
--label_column_name label
|
||||||
--do_train
|
--do_train
|
||||||
--do_eval
|
--do_eval
|
||||||
|
|||||||
Reference in New Issue
Block a user