[Speech Examples] Add new audio feature (#14027)
* finish * up * finish all * up
This commit is contained in:
committed by
GitHub
parent
cde0c750af
commit
37c5759cbe
@@ -94,7 +94,7 @@ To pre-train `"large-sized"` Wav2Vec2 model, *e.g.* [facebook/wav2vec2-large-lv6
|
||||
on [librispeech_asr](https://huggingface.co/datasets/librispeech_asr), the following command can be run:
|
||||
|
||||
```bash
|
||||
accelerate launch run_pretrain_no_trainer.py \
|
||||
accelerate launch run_wav2vec2_pretraining_no_trainer.py \
|
||||
--dataset_name=librispeech_asr \
|
||||
--dataset_config_names clean clean other \
|
||||
--dataset_split_names train.100 train.360 train.500 \
|
||||
|
||||
@@ -2,3 +2,4 @@ datasets >= 1.12.0
|
||||
torch >= 1.5
|
||||
torchaudio
|
||||
accelerate >= 0.5.0
|
||||
librosa
|
||||
|
||||
@@ -25,7 +25,6 @@ from typing import Dict, List, Optional, Union
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import torchaudio
|
||||
from datasets import DatasetDict, concatenate_datasets, load_dataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
@@ -113,7 +112,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--audio_column_name",
|
||||
type=str,
|
||||
default="file",
|
||||
default="audio",
|
||||
help="Column in the dataset that contains speech file path. Defaults to 'file'",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -128,6 +127,18 @@ def parse_args():
|
||||
default=None,
|
||||
help="Pretrained config name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_cache_file_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the train cached file name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_cache_file_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the validation cached file name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
type=int,
|
||||
@@ -414,9 +425,17 @@ def main():
|
||||
raw_datasets["validation"] = raw_datasets["train"].select(range(num_validation_samples))
|
||||
raw_datasets["train"] = raw_datasets["train"].select(range(num_validation_samples, raw_datasets["train"].num_rows))
|
||||
|
||||
# 2. Preprocess audio: load, resample, normalize and truncate
|
||||
# 2. Now we preprocess the datasets including loading the audio, resampling and normalization
|
||||
# Thankfully, `datasets` takes care of automatically loading and resampling the audio,
|
||||
# so that we just need to set the correct target sampling rate and normalize the input
|
||||
# via the `feature_extractor`
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path)
|
||||
|
||||
# make sure that dataset decodes audio with correct samlping rate
|
||||
raw_datasets = raw_datasets.cast_column(
|
||||
"audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
||||
)
|
||||
|
||||
# only normalized-inputs-training is supported
|
||||
if not feature_extractor.do_normalize:
|
||||
raise ValueError(
|
||||
@@ -427,38 +446,40 @@ def main():
|
||||
max_length = int(args.max_duration_in_seconds * feature_extractor.sampling_rate)
|
||||
min_length = int(args.min_duration_in_seconds * feature_extractor.sampling_rate)
|
||||
|
||||
resampler = None
|
||||
if raw_datasets["train"][args.audio_column_name][0].split(".")[-1] == "mp3":
|
||||
# TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
|
||||
resampler = torchaudio.transforms.Resample(48_000, feature_extractor.sampling_rate)
|
||||
|
||||
def prepare_dataset(batch):
|
||||
speech_array, sampling_rate = torchaudio.load(batch[args.audio_column_name])
|
||||
speech_array = speech_array.squeeze()
|
||||
sample = batch[args.audio_column_name]
|
||||
|
||||
# if necessary resample audio
|
||||
if resampler is not None:
|
||||
# TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
|
||||
speech_array = resampler(speech_array)
|
||||
sampling_rate = resampler.new_freq
|
||||
|
||||
speech_array = speech_array.numpy()
|
||||
inputs = feature_extractor(speech_array, sampling_rate=sampling_rate, max_length=max_length, truncation=True)
|
||||
inputs = feature_extractor(
|
||||
sample["array"], sampling_rate=sample["sampling_rate"], max_length=max_length, truncation=True
|
||||
)
|
||||
batch["input_values"] = inputs.input_values[0]
|
||||
batch["input_length"] = len(inputs.input_values[0])
|
||||
|
||||
return batch
|
||||
|
||||
# load via mapped files via path
|
||||
cache_file_names = None
|
||||
if args.train_cache_file_name is not None:
|
||||
cache_file_names = {"train": args.train_cache_file_name, "validation": args.validation_cache_file_name}
|
||||
|
||||
# load audio files into numpy arrays
|
||||
with accelerator.main_process_first():
|
||||
vectorized_datasets = raw_datasets.map(
|
||||
prepare_dataset,
|
||||
num_proc=args.preprocessing_num_workers,
|
||||
remove_columns=raw_datasets["train"].column_names,
|
||||
load_from_cache_file=not args.overwrite_cache,
|
||||
)
|
||||
vectorized_datasets = vectorized_datasets.filter(
|
||||
lambda x: len(x["input_values"]) > min_length, load_from_cache_file=not args.overwrite_cache
|
||||
cache_file_names=cache_file_names,
|
||||
)
|
||||
|
||||
if min_length > 0.0:
|
||||
vectorized_datasets = vectorized_datasets.filter(
|
||||
lambda x: x > min_length,
|
||||
num_proc=args.preprocessing_num_workers,
|
||||
input_columns=["input_length"],
|
||||
)
|
||||
|
||||
vectorized_datasets = vectorized_datasets.remove_columns("input_length")
|
||||
|
||||
# for large datasets it is advised to run the preprocessing on a
|
||||
# single machine first with ``args.preprocessing_only`` since there will mostly likely
|
||||
# be a timeout when running the script in distributed mode.
|
||||
|
||||
Reference in New Issue
Block a user