From 13489248fa8f2cda7503628204f8f43b108797a2 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Fri, 24 Feb 2023 09:19:07 +0100 Subject: [PATCH] [Examples] Generalise run audio classification for log-mel models (#21756) * [Examples] Generalise run audio classification for log-mel models * batch feature extractor * make style --- .../run_audio_classification.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/audio-classification/run_audio_classification.py b/examples/pytorch/audio-classification/run_audio_classification.py index fe213c4594..2231e96dc4 100644 --- a/examples/pytorch/audio-classification/run_audio_classification.py +++ b/examples/pytorch/audio-classification/run_audio_classification.py @@ -289,24 +289,27 @@ def main(): data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) ) + model_input_name = feature_extractor.model_input_names[0] + def train_transforms(batch): """Apply train_transforms across a batch.""" - output_batch = {"input_values": []} + subsampled_wavs = [] for audio in batch[data_args.audio_column_name]: wav = random_subsample( audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate ) - output_batch["input_values"].append(wav) + subsampled_wavs.append(wav) + inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate) + output_batch = {model_input_name: inputs.get(model_input_name)} output_batch["labels"] = list(batch[data_args.label_column_name]) return output_batch def val_transforms(batch): """Apply val_transforms across a batch.""" - output_batch = {"input_values": []} - for audio in batch[data_args.audio_column_name]: - wav = audio["array"] - output_batch["input_values"].append(wav) + wavs = [audio["array"] for audio in batch[data_args.audio_column_name]] + inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate) + output_batch = {model_input_name: inputs.get(model_input_name)} output_batch["labels"] = list(batch[data_args.label_column_name]) return output_batch