Add Speech Seq2Seq Training script (#14792)
* start * add gradient checkpointing and feature extractor freezing * Apply suggestions from code review * up * up * up * correct * up * more changes * up * up * up * remove rst
This commit is contained in:
committed by
GitHub
parent
10fd4fa1a6
commit
1c121916f3
@@ -59,6 +59,7 @@ if SRC_DIRS is not None:
|
||||
import run_qa as run_squad
|
||||
import run_seq2seq_qa as run_squad_seq2seq
|
||||
import run_speech_recognition_ctc
|
||||
import run_speech_recognition_seq2seq
|
||||
import run_summarization
|
||||
import run_swag
|
||||
import run_translation
|
||||
@@ -473,6 +474,39 @@ class ExamplesTests(TestCasePlus):
|
||||
result = get_results(tmp_dir)
|
||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||
|
||||
def test_run_speech_recognition_seq2seq(self):
    """Run the speech seq2seq example script end to end on dummy data.

    Invokes ``run_speech_recognition_seq2seq.main()`` with a patched
    ``sys.argv`` (10 training steps, using the ``validation`` split for both
    training and evaluation), then asserts that the reported ``eval_loss`` is
    lower than the reported ``train_loss``.
    """
    # Echo the module logger to stdout while the example runs.
    stdout_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stdout_handler)

    # Temporary output directory provided by the test case helper.
    output_dir = self.get_auto_remove_tmp_dir()

    # Whitespace-separated command line; `.split()` turns it into argv tokens.
    cli_args = f"""
        run_speech_recognition_seq2seq.py
        --output_dir {output_dir}
        --model_name_or_path hf-internal-testing/tiny-random-speech-encoder-decoder
        --dataset_name hf-internal-testing/librispeech_asr_dummy
        --dataset_config_name clean
        --train_split_name validation
        --eval_split_name validation
        --do_train
        --do_eval
        --learning_rate 1e-4
        --per_device_train_batch_size 2
        --per_device_eval_batch_size 4
        --remove_unused_columns False
        --overwrite_output_dir True
        --preprocessing_num_workers 16
        --max_steps 10
        --seed 42
        """.split()

    # Only request fp16 when the CUDA + apex check passes.
    if is_cuda_and_apex_available():
        cli_args.append("--fp16")

    with patch.object(sys, "argv", cli_args):
        run_speech_recognition_seq2seq.main()
        # Read back the results written to the output directory and compare losses.
        metrics = get_results(output_dir)
        self.assertLess(metrics["eval_loss"], metrics["train_loss"])
|
||||
|
||||
def test_run_audio_classification(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
@@ -521,10 +555,10 @@ class ExamplesTests(TestCasePlus):
|
||||
--dataset_config_names clean
|
||||
--dataset_split_names validation
|
||||
--learning_rate 1e-4
|
||||
--per_device_train_batch_size 2
|
||||
--per_device_eval_batch_size 2
|
||||
--per_device_train_batch_size 4
|
||||
--per_device_eval_batch_size 4
|
||||
--preprocessing_num_workers 16
|
||||
--max_train_steps 5
|
||||
--max_train_steps 2
|
||||
--validation_split_percentage 5
|
||||
--seed 42
|
||||
""".split()
|
||||
|
||||
Reference in New Issue
Block a user