[Speech Examples] Add new audio feature (#14027)

* finish

* up

* finish all

* up
This commit is contained in:
Patrick von Platen
2021-10-17 23:01:03 +02:00
committed by GitHub
parent cde0c750af
commit 37c5759cbe
8 changed files with 75 additions and 58 deletions

View File

@@ -58,7 +58,6 @@ python run_speech_recognition_ctc.py \
--learning_rate="3e-4" \
--warmup_steps="500" \
--evaluation_strategy="steps" \
--audio_column_name="path" \
--text_column_name="sentence" \
--save_steps="400" \
--eval_steps="100" \
@@ -87,7 +86,6 @@ python -m torch.distributed.launch \
--model_name_or_path="facebook/wav2vec2-large-xlsr-53" \
--dataset_config_name="tr" \
--output_dir="./wav2vec2-common_voice-tr-demo-dist" \
--preprocessing_num_workers="16" \
--overwrite_output_dir \
--num_train_epochs="15" \
--per_device_train_batch_size="4" \

View File

@@ -1,3 +1,4 @@
datasets >= 1.12.0
datasets >= 1.13.3
torch >= 1.5
torchaudio
librosa

View File

@@ -24,9 +24,9 @@ import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
import datasets
import numpy as np
import torch
import torchaudio
from datasets import DatasetDict, load_dataset, load_metric
import transformers
@@ -49,8 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.12.0.dev0")
# TODO(Patrick) Bump up as soon as audio features are merged
require_version("datasets>=1.12.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
require_version("datasets>=1.13.3", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
logger = logging.getLogger(__name__)
@@ -179,12 +178,12 @@ class DataTrainingArguments:
min_duration_in_seconds: Optional[float] = field(
default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
)
only_data_preprocessing: Optional[bool] = field(
preprocessing_only: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to only do data preprocessing and skip training. "
"This is especially useful when data preprocessing errors out in distributed training due to timeout. "
"In this case, one should run the preprocessing in a non-distributed setup with `only_data_preprocessing=True` "
"In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
"so that the cached datasets can consequently be loaded in distributed training"
},
)
@@ -450,41 +449,30 @@ def main():
if model_args.freeze_feature_extractor:
model.freeze_feature_extractor()
# 5. Now we preprocess the datasets which includes loading the audio, resampling and padding
# 5. Now we preprocess the datasets including loading the audio, resampling and normalization
# Thankfully, `datasets` takes care of automatically loading and resampling the audio,
# so that we just need to set the correct target sampling rate and normalize the input
# via the `feature_extractor`
# The following code should be cleaned up as soon as
# https://github.com/huggingface/datasets/pull/2324 is merged
# Preprocessing the datasets.
# We need to read the audio files as arrays and tokenize the targets.
# make sure that dataset decodes audio with correct samlping rate
raw_datasets = raw_datasets.cast_column(
"audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
)
# derive max & min input length for sample rate & max duration
max_input_length = data_args.max_duration_in_seconds * processor.feature_extractor.sampling_rate
min_input_length = data_args.min_duration_in_seconds * processor.feature_extractor.sampling_rate
resampler = None
if raw_datasets["train"][data_args.audio_column_name][0].split(".")[-1] == "mp3":
# TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
resampler = torchaudio.transforms.Resample(48_000, processor.feature_extractor.sampling_rate)
# Preprocessing the datasets.
# We need to read the audio files as arrays and tokenize the targets.
def prepare_dataset(batch):
# load audio
speech_array, sampling_rate = torchaudio.load(batch[data_args.audio_column_name])
speech_array = speech_array.squeeze()
# if necessary resample audio
if resampler is not None:
# TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
speech_array = resampler(speech_array)
sampling_rate = resampler.new_freq
speech_array = speech_array.numpy()
sample = batch[data_args.audio_column_name]
batch["input_values"] = processor(
speech_array, sampling_rate=sampling_rate, truncate=True, max_length=max_input_length
sample["array"], sampling_rate=sample["sampling_rate"], truncate=True, max_length=max_input_length
).input_values[0]
batch["input_length"] = len(batch["input_values"])
# Setup the processor for targets
with processor.as_target_processor():
@@ -502,10 +490,13 @@ def main():
if min_input_length > 0.0:
# filter data that is shorter than min_input_length
vectorized_datasets = vectorized_datasets.filter(
lambda data: len(data["input_values"]) > min_input_length,
lambda x: x > min_input_length,
num_proc=data_args.preprocessing_num_workers,
input_columns=["input_length"],
)
vectorized_datasets = vectorized_datasets.remove_columns("input_length")
# 6. Next, we can prepare the training.
# Let's use word error rate (WER) as our evaluation metric,
# instantiate a data collator and the trainer
@@ -513,8 +504,13 @@ def main():
# Define Metric during training
wer_metric = load_metric("wer")
if data_args.only_data_preprocessing:
logger.info("Data preprocessing finished.")
# for large datasets it is advised to run the preprocessing on a
# single machine first with ``args.preprocessing_only`` since there will mostly likely
# be a timeout when running the script in distributed mode.
# In a second step ``args.preprocessing_only`` can then be set to `False` to load the
# cached dataset
if data_args.preprocessing_only:
logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
return
def compute_metrics(pred):