[ASR] Add official ASR CTC example to examples/pytorch/speech-recognition (#13620)
* up * rename * add asr example * add auto feature extractor * some more fixes * correct layerdrop * correct for multi-gpu dist * clean up * refactor * refactor * more fixes * more fixes * clean-up * finish * up * Apply suggestions from code review * fix isort * update * up * add note * apply surajs suggestions * Apply suggestions from code review Co-authored-by: Suraj Patil <surajp815@gmail.com> * isort * small change * Apply suggestions from code review Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com> * Apply suggestions from code review Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com> * add hubert * Update examples/pytorch/speech-recognition/run_speech_recognition_ctc.py Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
This commit is contained in:
committed by
GitHub
parent
41c186d2a4
commit
4a320f6c9a
@@ -39,6 +39,7 @@ SRC_DIRS = [
|
||||
"summarization",
|
||||
"translation",
|
||||
"image-classification",
|
||||
"speech-recognition",
|
||||
]
|
||||
]
|
||||
sys.path.extend(SRC_DIRS)
|
||||
@@ -52,6 +53,7 @@ if SRC_DIRS is not None:
|
||||
import run_mlm
|
||||
import run_ner
|
||||
import run_qa as run_squad
|
||||
import run_speech_recognition_ctc
|
||||
import run_summarization
|
||||
import run_swag
|
||||
import run_translation
|
||||
@@ -374,3 +376,37 @@ class ExamplesTests(TestCasePlus):
|
||||
run_image_classification.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.8)
|
||||
|
||||
def test_run_speech_recognition_ctc(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_speech_recognition_ctc.py
|
||||
--output_dir {tmp_dir}
|
||||
--model_name_or_path hf-internal-testing/tiny-random-wav2vec2
|
||||
--dataset_name patrickvonplaten/librispeech_asr_dummy
|
||||
--dataset_config_name clean
|
||||
--train_split_name validation
|
||||
--eval_split_name validation
|
||||
--audio_column_name file
|
||||
--do_train
|
||||
--do_eval
|
||||
--learning_rate 1e-4
|
||||
--per_device_train_batch_size 2
|
||||
--per_device_eval_batch_size 1
|
||||
--remove_unused_columns False
|
||||
--overwrite_output_dir True
|
||||
--preprocessing_num_workers 16
|
||||
--max_steps 10
|
||||
--seed 42
|
||||
""".split()
|
||||
|
||||
if is_cuda_and_apex_available():
|
||||
testargs.append("--fp16")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_speech_recognition_ctc.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||
|
||||
Reference in New Issue
Block a user