[ASR] Add official ASR CTC example to examples/pytorch/speech-recognition (#13620)

* up

* rename

* add asr example

* add auto feature extractor

* some more fixes

* correct layerdrop

* correct for multi-gpu dist

* clean up

* refactor

* refactor

* more fixes

* more fixes

* clean-up

* finish

* up

* Apply suggestions from code review

* fix isort

* update

* up

* add note

* apply Suraj's suggestions

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* isort

* small change

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* add hubert

* Update examples/pytorch/speech-recognition/run_speech_recognition_ctc.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
This commit is contained in:
Patrick von Platen
2021-09-24 07:01:11 +02:00
committed by GitHub
parent 41c186d2a4
commit 4a320f6c9a
6 changed files with 804 additions and 12 deletions

View File

@@ -39,6 +39,7 @@ SRC_DIRS = [
"summarization",
"translation",
"image-classification",
"speech-recognition",
]
]
sys.path.extend(SRC_DIRS)
@@ -52,6 +53,7 @@ if SRC_DIRS is not None:
import run_mlm
import run_ner
import run_qa as run_squad
import run_speech_recognition_ctc
import run_summarization
import run_swag
import run_translation
@@ -374,3 +376,37 @@ class ExamplesTests(TestCasePlus):
run_image_classification.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.8)
def test_run_speech_recognition_ctc(self):
# Calls run_speech_recognition_ctc.main() with a short command line
# (--max_steps 10, --do_train and --do_eval, train and eval split both set
# to "validation") and then compares the reported eval and train losses.
#
# Attach a stdout handler to this module's logger before the run.
# NOTE(review): the handler is not removed again anywhere in this method.
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
# Directory passed to the script as --output_dir and later to get_results.
# NOTE(review): going by the helper's name it is cleaned up automatically - confirm.
tmp_dir = self.get_auto_remove_tmp_dir()
# The argv list is built by whitespace-splitting the f-string below, so the
# first element is the script name and each flag/value becomes its own item.
# Only {tmp_dir} is interpolated; every other argument is a fixed literal.
testargs = f"""
run_speech_recognition_ctc.py
--output_dir {tmp_dir}
--model_name_or_path hf-internal-testing/tiny-random-wav2vec2
--dataset_name patrickvonplaten/librispeech_asr_dummy
--dataset_config_name clean
--train_split_name validation
--eval_split_name validation
--audio_column_name file
--do_train
--do_eval
--learning_rate 1e-4
--per_device_train_batch_size 2
--per_device_eval_batch_size 1
--remove_unused_columns False
--overwrite_output_dir True
--preprocessing_num_workers 16
--max_steps 10
--seed 42
""".split()
# Append --fp16 only when is_cuda_and_apex_available() returns a truthy value.
if is_cuda_and_apex_available():
testargs.append("--fp16")
# Temporarily replace sys.argv with testargs while main() runs
# (presumably main() parses its options from sys.argv - confirm).
with patch.object(sys, "argv", testargs):
run_speech_recognition_ctc.main()
# get_results is defined elsewhere; presumably it loads the metrics the run
# wrote under tmp_dir - verify against its definition.
result = get_results(tmp_dir)
# Passes only if eval_loss is strictly lower than train_loss.
self.assertLess(result["eval_loss"], result["train_loss"])