[Flax Examples] Seq2Seq ASR Fine-Tuning Script (#21764)
* from seq2seq speech * [Flax] Example script for speech seq2seq * tests and fixes * make style * fix: label padding tokens * fix: label padding tokens over list * update ln names for Whisper * try datasets iter loader * create readme and append results * style * make style * adjust lr * use pt dataloader * make fast * pin gen max len * finish * add pt to requirements for test * fix pt -> torch * add accelerate
This commit is contained in:
@@ -32,6 +32,7 @@ SRC_DIRS = [
|
||||
"summarization",
|
||||
"token-classification",
|
||||
"question-answering",
|
||||
"speech-recognition",
|
||||
]
|
||||
]
|
||||
sys.path.extend(SRC_DIRS)
|
||||
@@ -41,6 +42,7 @@ if SRC_DIRS is not None:
|
||||
import run_clm_flax
|
||||
import run_flax_glue
|
||||
import run_flax_ner
|
||||
import run_flax_speech_recognition_seq2seq
|
||||
import run_mlm_flax
|
||||
import run_qa
|
||||
import run_summarization_flax
|
||||
@@ -252,3 +254,32 @@ class ExamplesTests(TestCasePlus):
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_f1"], 30)
|
||||
self.assertGreaterEqual(result["eval_exact"], 30)
|
||||
|
||||
@slow
|
||||
def test_run_flax_speech_recognition_seq2seq(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_flax_speech_recognition_seq2seq.py
|
||||
--model_name_or_path openai/whisper-tiny.en
|
||||
--dataset_name hf-internal-testing/librispeech_asr_dummy
|
||||
--dataset_config clean
|
||||
--train_split_name validation
|
||||
--eval_split_name validation
|
||||
--output_dir {tmp_dir}
|
||||
--overwrite_output_dir
|
||||
--num_train_epochs=2
|
||||
--max_train_samples 10
|
||||
--max_eval_samples 10
|
||||
--warmup_steps=8
|
||||
--do_train
|
||||
--do_eval
|
||||
--learning_rate=2e-4
|
||||
--per_device_train_batch_size=2
|
||||
--per_device_eval_batch_size=1
|
||||
--predict_with_generate
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_flax_speech_recognition_seq2seq.main()
|
||||
result = get_results(tmp_dir, split="eval")
|
||||
self.assertLessEqual(result["eval_wer"], 0.05)
|
||||
|
||||
Reference in New Issue
Block a user