[Speech Examples] Add new audio feature (#14027)
* finish * up * finish all * up
This commit is contained in:
committed by
GitHub
parent
cde0c750af
commit
37c5759cbe
@@ -13,7 +13,7 @@ streamlit
|
|||||||
elasticsearch
|
elasticsearch
|
||||||
nltk
|
nltk
|
||||||
pandas
|
pandas
|
||||||
datasets >= 1.1.3
|
datasets >= 1.13.3
|
||||||
fire
|
fire
|
||||||
pytest
|
pytest
|
||||||
conllu
|
conllu
|
||||||
@@ -21,3 +21,4 @@ sentencepiece != 0.1.92
|
|||||||
protobuf
|
protobuf
|
||||||
torchvision
|
torchvision
|
||||||
jiwer
|
jiwer
|
||||||
|
librosa
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ To pre-train `"large-sized"` Wav2Vec2 model, *e.g.* [facebook/wav2vec2-large-lv6
|
|||||||
on [librispeech_asr](https://huggingface.co/datasets/librispeech_asr), the following command can be run:
|
on [librispeech_asr](https://huggingface.co/datasets/librispeech_asr), the following command can be run:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
accelerate launch run_pretrain_no_trainer.py \
|
accelerate launch run_wav2vec2_pretraining_no_trainer.py \
|
||||||
--dataset_name=librispeech_asr \
|
--dataset_name=librispeech_asr \
|
||||||
--dataset_config_names clean clean other \
|
--dataset_config_names clean clean other \
|
||||||
--dataset_split_names train.100 train.360 train.500 \
|
--dataset_split_names train.100 train.360 train.500 \
|
||||||
|
|||||||
@@ -2,3 +2,4 @@ datasets >= 1.12.0
|
|||||||
torch >= 1.5
|
torch >= 1.5
|
||||||
torchaudio
|
torchaudio
|
||||||
accelerate >= 0.5.0
|
accelerate >= 0.5.0
|
||||||
|
librosa
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from typing import Dict, List, Optional, Union
|
|||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
|
||||||
from datasets import DatasetDict, concatenate_datasets, load_dataset
|
from datasets import DatasetDict, concatenate_datasets, load_dataset
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data.dataloader import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
@@ -113,7 +112,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--audio_column_name",
|
"--audio_column_name",
|
||||||
type=str,
|
type=str,
|
||||||
default="file",
|
default="audio",
|
||||||
help="Column in the dataset that contains speech file path. Defaults to 'file'",
|
help="Column in the dataset that contains speech file path. Defaults to 'file'",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -128,6 +127,18 @@ def parse_args():
|
|||||||
default=None,
|
default=None,
|
||||||
help="Pretrained config name or path if not the same as model_name",
|
help="Pretrained config name or path if not the same as model_name",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--train_cache_file_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to the train cached file name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--validation_cache_file_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to the validation cached file name",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--per_device_train_batch_size",
|
"--per_device_train_batch_size",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -414,9 +425,17 @@ def main():
|
|||||||
raw_datasets["validation"] = raw_datasets["train"].select(range(num_validation_samples))
|
raw_datasets["validation"] = raw_datasets["train"].select(range(num_validation_samples))
|
||||||
raw_datasets["train"] = raw_datasets["train"].select(range(num_validation_samples, raw_datasets["train"].num_rows))
|
raw_datasets["train"] = raw_datasets["train"].select(range(num_validation_samples, raw_datasets["train"].num_rows))
|
||||||
|
|
||||||
# 2. Preprocess audio: load, resample, normalize and truncate
|
# 2. Now we preprocess the datasets including loading the audio, resampling and normalization
|
||||||
|
# Thankfully, `datasets` takes care of automatically loading and resampling the audio,
|
||||||
|
# so that we just need to set the correct target sampling rate and normalize the input
|
||||||
|
# via the `feature_extractor`
|
||||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path)
|
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path)
|
||||||
|
|
||||||
|
# make sure that dataset decodes audio with correct samlping rate
|
||||||
|
raw_datasets = raw_datasets.cast_column(
|
||||||
|
"audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
||||||
|
)
|
||||||
|
|
||||||
# only normalized-inputs-training is supported
|
# only normalized-inputs-training is supported
|
||||||
if not feature_extractor.do_normalize:
|
if not feature_extractor.do_normalize:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -427,38 +446,40 @@ def main():
|
|||||||
max_length = int(args.max_duration_in_seconds * feature_extractor.sampling_rate)
|
max_length = int(args.max_duration_in_seconds * feature_extractor.sampling_rate)
|
||||||
min_length = int(args.min_duration_in_seconds * feature_extractor.sampling_rate)
|
min_length = int(args.min_duration_in_seconds * feature_extractor.sampling_rate)
|
||||||
|
|
||||||
resampler = None
|
|
||||||
if raw_datasets["train"][args.audio_column_name][0].split(".")[-1] == "mp3":
|
|
||||||
# TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
|
|
||||||
resampler = torchaudio.transforms.Resample(48_000, feature_extractor.sampling_rate)
|
|
||||||
|
|
||||||
def prepare_dataset(batch):
|
def prepare_dataset(batch):
|
||||||
speech_array, sampling_rate = torchaudio.load(batch[args.audio_column_name])
|
sample = batch[args.audio_column_name]
|
||||||
speech_array = speech_array.squeeze()
|
|
||||||
|
|
||||||
# if necessary resample audio
|
inputs = feature_extractor(
|
||||||
if resampler is not None:
|
sample["array"], sampling_rate=sample["sampling_rate"], max_length=max_length, truncation=True
|
||||||
# TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
|
)
|
||||||
speech_array = resampler(speech_array)
|
|
||||||
sampling_rate = resampler.new_freq
|
|
||||||
|
|
||||||
speech_array = speech_array.numpy()
|
|
||||||
inputs = feature_extractor(speech_array, sampling_rate=sampling_rate, max_length=max_length, truncation=True)
|
|
||||||
batch["input_values"] = inputs.input_values[0]
|
batch["input_values"] = inputs.input_values[0]
|
||||||
|
batch["input_length"] = len(inputs.input_values[0])
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
# load via mapped files via path
|
||||||
|
cache_file_names = None
|
||||||
|
if args.train_cache_file_name is not None:
|
||||||
|
cache_file_names = {"train": args.train_cache_file_name, "validation": args.validation_cache_file_name}
|
||||||
|
|
||||||
# load audio files into numpy arrays
|
# load audio files into numpy arrays
|
||||||
with accelerator.main_process_first():
|
with accelerator.main_process_first():
|
||||||
vectorized_datasets = raw_datasets.map(
|
vectorized_datasets = raw_datasets.map(
|
||||||
prepare_dataset,
|
prepare_dataset,
|
||||||
num_proc=args.preprocessing_num_workers,
|
num_proc=args.preprocessing_num_workers,
|
||||||
remove_columns=raw_datasets["train"].column_names,
|
remove_columns=raw_datasets["train"].column_names,
|
||||||
load_from_cache_file=not args.overwrite_cache,
|
cache_file_names=cache_file_names,
|
||||||
)
|
|
||||||
vectorized_datasets = vectorized_datasets.filter(
|
|
||||||
lambda x: len(x["input_values"]) > min_length, load_from_cache_file=not args.overwrite_cache
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if min_length > 0.0:
|
||||||
|
vectorized_datasets = vectorized_datasets.filter(
|
||||||
|
lambda x: x > min_length,
|
||||||
|
num_proc=args.preprocessing_num_workers,
|
||||||
|
input_columns=["input_length"],
|
||||||
|
)
|
||||||
|
|
||||||
|
vectorized_datasets = vectorized_datasets.remove_columns("input_length")
|
||||||
|
|
||||||
# for large datasets it is advised to run the preprocessing on a
|
# for large datasets it is advised to run the preprocessing on a
|
||||||
# single machine first with ``args.preprocessing_only`` since there will mostly likely
|
# single machine first with ``args.preprocessing_only`` since there will mostly likely
|
||||||
# be a timeout when running the script in distributed mode.
|
# be a timeout when running the script in distributed mode.
|
||||||
|
|||||||
@@ -58,7 +58,6 @@ python run_speech_recognition_ctc.py \
|
|||||||
--learning_rate="3e-4" \
|
--learning_rate="3e-4" \
|
||||||
--warmup_steps="500" \
|
--warmup_steps="500" \
|
||||||
--evaluation_strategy="steps" \
|
--evaluation_strategy="steps" \
|
||||||
--audio_column_name="path" \
|
|
||||||
--text_column_name="sentence" \
|
--text_column_name="sentence" \
|
||||||
--save_steps="400" \
|
--save_steps="400" \
|
||||||
--eval_steps="100" \
|
--eval_steps="100" \
|
||||||
@@ -87,7 +86,6 @@ python -m torch.distributed.launch \
|
|||||||
--model_name_or_path="facebook/wav2vec2-large-xlsr-53" \
|
--model_name_or_path="facebook/wav2vec2-large-xlsr-53" \
|
||||||
--dataset_config_name="tr" \
|
--dataset_config_name="tr" \
|
||||||
--output_dir="./wav2vec2-common_voice-tr-demo-dist" \
|
--output_dir="./wav2vec2-common_voice-tr-demo-dist" \
|
||||||
--preprocessing_num_workers="16" \
|
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--num_train_epochs="15" \
|
--num_train_epochs="15" \
|
||||||
--per_device_train_batch_size="4" \
|
--per_device_train_batch_size="4" \
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
datasets >= 1.12.0
|
datasets >= 1.13.3
|
||||||
torch >= 1.5
|
torch >= 1.5
|
||||||
torchaudio
|
torchaudio
|
||||||
|
librosa
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ import sys
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
|
||||||
from datasets import DatasetDict, load_dataset, load_metric
|
from datasets import DatasetDict, load_dataset, load_metric
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
@@ -49,8 +49,7 @@ from transformers.utils.versions import require_version
|
|||||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||||
check_min_version("4.12.0.dev0")
|
check_min_version("4.12.0.dev0")
|
||||||
|
|
||||||
# TODO(Patrick) Bump up as soon as audio features are merged
|
require_version("datasets>=1.13.3", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||||
require_version("datasets>=1.12.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -179,12 +178,12 @@ class DataTrainingArguments:
|
|||||||
min_duration_in_seconds: Optional[float] = field(
|
min_duration_in_seconds: Optional[float] = field(
|
||||||
default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
|
default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
|
||||||
)
|
)
|
||||||
only_data_preprocessing: Optional[bool] = field(
|
preprocessing_only: Optional[bool] = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to only do data preprocessing and skip training. "
|
"help": "Whether to only do data preprocessing and skip training. "
|
||||||
"This is especially useful when data preprocessing errors out in distributed training due to timeout. "
|
"This is especially useful when data preprocessing errors out in distributed training due to timeout. "
|
||||||
"In this case, one should run the preprocessing in a non-distributed setup with `only_data_preprocessing=True` "
|
"In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
|
||||||
"so that the cached datasets can consequently be loaded in distributed training"
|
"so that the cached datasets can consequently be loaded in distributed training"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -450,41 +449,30 @@ def main():
|
|||||||
if model_args.freeze_feature_extractor:
|
if model_args.freeze_feature_extractor:
|
||||||
model.freeze_feature_extractor()
|
model.freeze_feature_extractor()
|
||||||
|
|
||||||
# 5. Now we preprocess the datasets which includes loading the audio, resampling and padding
|
# 5. Now we preprocess the datasets including loading the audio, resampling and normalization
|
||||||
|
# Thankfully, `datasets` takes care of automatically loading and resampling the audio,
|
||||||
|
# so that we just need to set the correct target sampling rate and normalize the input
|
||||||
|
# via the `feature_extractor`
|
||||||
|
|
||||||
# The following code should be cleaned up as soon as
|
# make sure that dataset decodes audio with correct samlping rate
|
||||||
# https://github.com/huggingface/datasets/pull/2324 is merged
|
raw_datasets = raw_datasets.cast_column(
|
||||||
|
"audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
||||||
# Preprocessing the datasets.
|
)
|
||||||
# We need to read the audio files as arrays and tokenize the targets.
|
|
||||||
|
|
||||||
# derive max & min input length for sample rate & max duration
|
# derive max & min input length for sample rate & max duration
|
||||||
max_input_length = data_args.max_duration_in_seconds * processor.feature_extractor.sampling_rate
|
max_input_length = data_args.max_duration_in_seconds * processor.feature_extractor.sampling_rate
|
||||||
min_input_length = data_args.min_duration_in_seconds * processor.feature_extractor.sampling_rate
|
min_input_length = data_args.min_duration_in_seconds * processor.feature_extractor.sampling_rate
|
||||||
|
|
||||||
resampler = None
|
|
||||||
if raw_datasets["train"][data_args.audio_column_name][0].split(".")[-1] == "mp3":
|
|
||||||
# TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
|
|
||||||
resampler = torchaudio.transforms.Resample(48_000, processor.feature_extractor.sampling_rate)
|
|
||||||
|
|
||||||
# Preprocessing the datasets.
|
# Preprocessing the datasets.
|
||||||
# We need to read the audio files as arrays and tokenize the targets.
|
# We need to read the audio files as arrays and tokenize the targets.
|
||||||
def prepare_dataset(batch):
|
def prepare_dataset(batch):
|
||||||
# load audio
|
# load audio
|
||||||
speech_array, sampling_rate = torchaudio.load(batch[data_args.audio_column_name])
|
sample = batch[data_args.audio_column_name]
|
||||||
speech_array = speech_array.squeeze()
|
|
||||||
|
|
||||||
# if necessary resample audio
|
|
||||||
if resampler is not None:
|
|
||||||
# TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
|
|
||||||
speech_array = resampler(speech_array)
|
|
||||||
sampling_rate = resampler.new_freq
|
|
||||||
|
|
||||||
speech_array = speech_array.numpy()
|
|
||||||
|
|
||||||
batch["input_values"] = processor(
|
batch["input_values"] = processor(
|
||||||
speech_array, sampling_rate=sampling_rate, truncate=True, max_length=max_input_length
|
sample["array"], sampling_rate=sample["sampling_rate"], truncate=True, max_length=max_input_length
|
||||||
).input_values[0]
|
).input_values[0]
|
||||||
|
batch["input_length"] = len(batch["input_values"])
|
||||||
|
|
||||||
# Setup the processor for targets
|
# Setup the processor for targets
|
||||||
with processor.as_target_processor():
|
with processor.as_target_processor():
|
||||||
@@ -502,10 +490,13 @@ def main():
|
|||||||
if min_input_length > 0.0:
|
if min_input_length > 0.0:
|
||||||
# filter data that is shorter than min_input_length
|
# filter data that is shorter than min_input_length
|
||||||
vectorized_datasets = vectorized_datasets.filter(
|
vectorized_datasets = vectorized_datasets.filter(
|
||||||
lambda data: len(data["input_values"]) > min_input_length,
|
lambda x: x > min_input_length,
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
input_columns=["input_length"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
vectorized_datasets = vectorized_datasets.remove_columns("input_length")
|
||||||
|
|
||||||
# 6. Next, we can prepare the training.
|
# 6. Next, we can prepare the training.
|
||||||
# Let's use word error rate (WER) as our evaluation metric,
|
# Let's use word error rate (WER) as our evaluation metric,
|
||||||
# instantiate a data collator and the trainer
|
# instantiate a data collator and the trainer
|
||||||
@@ -513,8 +504,13 @@ def main():
|
|||||||
# Define Metric during training
|
# Define Metric during training
|
||||||
wer_metric = load_metric("wer")
|
wer_metric = load_metric("wer")
|
||||||
|
|
||||||
if data_args.only_data_preprocessing:
|
# for large datasets it is advised to run the preprocessing on a
|
||||||
logger.info("Data preprocessing finished.")
|
# single machine first with ``args.preprocessing_only`` since there will mostly likely
|
||||||
|
# be a timeout when running the script in distributed mode.
|
||||||
|
# In a second step ``args.preprocessing_only`` can then be set to `False` to load the
|
||||||
|
# cached dataset
|
||||||
|
if data_args.preprocessing_only:
|
||||||
|
logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
|
||||||
return
|
return
|
||||||
|
|
||||||
def compute_metrics(pred):
|
def compute_metrics(pred):
|
||||||
|
|||||||
@@ -395,7 +395,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--dataset_config_name clean
|
--dataset_config_name clean
|
||||||
--train_split_name validation
|
--train_split_name validation
|
||||||
--eval_split_name validation
|
--eval_split_name validation
|
||||||
--audio_column_name file
|
|
||||||
--do_train
|
--do_train
|
||||||
--do_eval
|
--do_eval
|
||||||
--learning_rate 1e-4
|
--learning_rate 1e-4
|
||||||
|
|||||||
Reference in New Issue
Block a user