Add Audio Spectogram Transformer (#19981)

* First draft

* Make conversion script work

* Add id2label mapping, run code quality

* Fix copies

* Add first draft of feature extractor

* Update conversion script to use feature extractor

* Make more tests pass

* Add docs

* update input_features to input_values + pad by default to max length

* Fix doc tests

* Add feature extractor tests

* Add proper padding/truncation to feature extractor

* Add support for conversion of all audioset checkpoints

* Improve docs and extend conversion script

* Fix README

* Rename spectogram to spectrogram

* Fix copies

* Add integration test

* Remove dummy conv

* Update to ast

* Update organization

* Fix init

* Rename model to AST

* Add require_torchaudio annotator

* Move import of ASTFeatureExtractor under a is_speech_available

* Fix rebase

* Add pipeline config

* Update name of classifier head

* Rename time_dimension and frequency_dimension for clarity

* Remove print statement

* Fix pipeline test

* Fix pipeline test

* Fix index table

* Fix init

* Fix conversion script

* Rename to ForAudioClassification

* Fix index table

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-11-21 18:58:54 +01:00
committed by GitHub
parent 1e3f17b5ab
commit 4973d2a04c
28 changed files with 2014 additions and 147 deletions

View File

@@ -91,6 +91,7 @@ if is_torch_available():
from test_module.custom_modeling import CustomModel, NoSuperInitModel
from transformers import (
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
MODEL_FOR_AUDIO_XVECTOR_MAPPING,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -223,6 +224,7 @@ class ModelTesterMixin:
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
*get_values(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING),
*get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING),
]:
inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device