Add Audio Spectogram Transformer (#19981)

* First draft

* Make conversion script work

* Add id2label mapping, run code quality

* Fix copies

* Add first draft of feature extractor

* Update conversion script to use feature extractor

* Make more tests pass

* Add docs

* update input_features to input_values + pad by default to max length

* Fix doc tests

* Add feature extractor tests

* Add proper padding/truncation to feature extractor

* Add support for conversion of all audioset checkpoints

* Improve docs and extend conversion script

* Fix README

* Rename spectogram to spectrogram

* Fix copies

* Add integration test

* Remove dummy conv

* Update to ast

* Update organization

* Fix init

* Rename model to AST

* Add require_torchaudio annotator

* Move import of ASTFeatureExtractor under a is_speech_available

* Fix rebase

* Add pipeline config

* Update name of classifier head

* Rename time_dimension and frequency_dimension for clarity

* Remove print statement

* Fix pipeline test

* Fix pipeline test

* Fix index table

* Fix init

* Fix conversion script

* Rename to ForAudioClassification

* Fix index table

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-11-21 18:58:54 +01:00
committed by GitHub
parent 1e3f17b5ab
commit 4973d2a04c
28 changed files with 2014 additions and 147 deletions

View File

@@ -155,6 +155,12 @@ def get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config, feature_
if hasattr(tiny_config, "image_size") and feature_extractor:
feature_extractor = feature_extractor.__class__(size=tiny_config.image_size, crop_size=tiny_config.image_size)
# Audio Spectogram Transformer specific.
if feature_extractor.__class__.__name__ == "ASTFeatureExtractor":
feature_extractor = feature_extractor.__class__(
max_length=tiny_config.max_length, num_mel_bins=tiny_config.num_mel_bins
)
# Speech2TextModel specific.
if hasattr(tiny_config, "input_feat_per_channel") and feature_extractor:
feature_extractor = feature_extractor.__class__(