[Examples] Generalise run audio classification for log-mel models (#21756)
* [Examples] Generalise run audio classification for log-mel models * batch feature extractor * make style
This commit is contained in:
@@ -289,24 +289,27 @@ def main():
|
|||||||
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_input_name = feature_extractor.model_input_names[0]
|
||||||
|
|
||||||
def train_transforms(batch):
|
def train_transforms(batch):
|
||||||
"""Apply train_transforms across a batch."""
|
"""Apply train_transforms across a batch."""
|
||||||
output_batch = {"input_values": []}
|
subsampled_wavs = []
|
||||||
for audio in batch[data_args.audio_column_name]:
|
for audio in batch[data_args.audio_column_name]:
|
||||||
wav = random_subsample(
|
wav = random_subsample(
|
||||||
audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
|
audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
|
||||||
)
|
)
|
||||||
output_batch["input_values"].append(wav)
|
subsampled_wavs.append(wav)
|
||||||
|
inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate)
|
||||||
|
output_batch = {model_input_name: inputs.get(model_input_name)}
|
||||||
output_batch["labels"] = list(batch[data_args.label_column_name])
|
output_batch["labels"] = list(batch[data_args.label_column_name])
|
||||||
|
|
||||||
return output_batch
|
return output_batch
|
||||||
|
|
||||||
def val_transforms(batch):
|
def val_transforms(batch):
|
||||||
"""Apply val_transforms across a batch."""
|
"""Apply val_transforms across a batch."""
|
||||||
output_batch = {"input_values": []}
|
wavs = [audio["array"] for audio in batch[data_args.audio_column_name]]
|
||||||
for audio in batch[data_args.audio_column_name]:
|
inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate)
|
||||||
wav = audio["array"]
|
output_batch = {model_input_name: inputs.get(model_input_name)}
|
||||||
output_batch["input_values"].append(wav)
|
|
||||||
output_batch["labels"] = list(batch[data_args.label_column_name])
|
output_batch["labels"] = list(batch[data_args.label_column_name])
|
||||||
|
|
||||||
return output_batch
|
return output_batch
|
||||||
|
|||||||
Reference in New Issue
Block a user