Add MMS CTC Fine-Tuning (#24281)
* Add mms ctc fine tuning * make style * More fixes that are needed * make fix-copies * make draft for README * add new file * move to new file * make style * make style * add quick test * make style * make style
This commit is contained in:
committed by
GitHub
parent
0c3fdccf2f
commit
1609a436ec
@@ -63,6 +63,7 @@ if SRC_DIRS is not None:
|
||||
import run_semantic_segmentation
|
||||
import run_seq2seq_qa as run_squad_seq2seq
|
||||
import run_speech_recognition_ctc
|
||||
import run_speech_recognition_ctc_adapter
|
||||
import run_speech_recognition_seq2seq
|
||||
import run_summarization
|
||||
import run_swag
|
||||
@@ -446,6 +447,38 @@ class ExamplesTests(TestCasePlus):
|
||||
result = get_results(tmp_dir)
|
||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||
|
||||
def test_run_speech_recognition_ctc_adapter(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_speech_recognition_ctc_adapter.py
|
||||
--output_dir {tmp_dir}
|
||||
--model_name_or_path hf-internal-testing/tiny-random-wav2vec2
|
||||
--dataset_name hf-internal-testing/librispeech_asr_dummy
|
||||
--dataset_config_name clean
|
||||
--train_split_name validation
|
||||
--eval_split_name validation
|
||||
--do_train
|
||||
--do_eval
|
||||
--learning_rate 1e-4
|
||||
--per_device_train_batch_size 2
|
||||
--per_device_eval_batch_size 1
|
||||
--remove_unused_columns False
|
||||
--overwrite_output_dir True
|
||||
--preprocessing_num_workers 16
|
||||
--max_steps 10
|
||||
--target_language tur
|
||||
--seed 42
|
||||
""".split()
|
||||
|
||||
if is_cuda_and_apex_available():
|
||||
testargs.append("--fp16")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_speech_recognition_ctc_adapter.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "./adapter.tur.safetensors")))
|
||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||
|
||||
def test_run_speech_recognition_seq2seq(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
|
||||
Reference in New Issue
Block a user