Add Speech Seq2Seq Training script (#14792)

* start

* add gradient checkpointing and feature extractor freezing

* Apply suggestions from code review

* up

* up

* up

* correct

* up

* more changes

* up

* up

* up

* remove rst
This commit is contained in:
Patrick von Platen
2021-12-28 10:20:51 +01:00
committed by GitHub
parent 10fd4fa1a6
commit 1c121916f3
13 changed files with 802 additions and 24 deletions

View File

@@ -59,6 +59,7 @@ if SRC_DIRS is not None:
import run_qa as run_squad
import run_seq2seq_qa as run_squad_seq2seq
import run_speech_recognition_ctc
import run_speech_recognition_seq2seq
import run_summarization
import run_swag
import run_translation
@@ -473,6 +474,39 @@ class ExamplesTests(TestCasePlus):
result = get_results(tmp_dir)
self.assertLess(result["eval_loss"], result["train_loss"])
def test_run_speech_recognition_seq2seq(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_speech_recognition_seq2seq.py
--output_dir {tmp_dir}
--model_name_or_path hf-internal-testing/tiny-random-speech-encoder-decoder
--dataset_name hf-internal-testing/librispeech_asr_dummy
--dataset_config_name clean
--train_split_name validation
--eval_split_name validation
--do_train
--do_eval
--learning_rate 1e-4
--per_device_train_batch_size 2
--per_device_eval_batch_size 4
--remove_unused_columns False
--overwrite_output_dir True
--preprocessing_num_workers 16
--max_steps 10
--seed 42
""".split()
if is_cuda_and_apex_available():
testargs.append("--fp16")
with patch.object(sys, "argv", testargs):
run_speech_recognition_seq2seq.main()
result = get_results(tmp_dir)
self.assertLess(result["eval_loss"], result["train_loss"])
def test_run_audio_classification(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
@@ -521,10 +555,10 @@ class ExamplesTests(TestCasePlus):
--dataset_config_names clean
--dataset_split_names validation
--learning_rate 1e-4
--per_device_train_batch_size 2
--per_device_eval_batch_size 2
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--preprocessing_num_workers 16
--max_train_steps 5
--max_train_steps 2
--validation_split_percentage 5
--seed 42
""".split()