[Flax Examples] Seq2Seq ASR Fine-Tuning Script (#21764)

* from seq2seq speech

* [Flax] Example script for speech seq2seq

* tests and fixes

* make style

* fix: label padding tokens

* fix: label padding tokens over list

* update ln names for Whisper

* try datasets iter loader

* create readme and append results

* style

* make style

* adjust lr

* use pt dataloader

* make fast

* pin gen max len

* finish

* add pt to requirements for test

* fix pt -> torch

* add accelerate
This commit is contained in:
Sanchit Gandhi
2023-09-29 16:42:58 +01:00
committed by GitHub
parent 391177441b
commit 68e85fc822
5 changed files with 967 additions and 1 deletions

View File

@@ -32,6 +32,7 @@ SRC_DIRS = [
"summarization",
"token-classification",
"question-answering",
"speech-recognition",
]
]
sys.path.extend(SRC_DIRS)
@@ -41,6 +42,7 @@ if SRC_DIRS is not None:
import run_clm_flax
import run_flax_glue
import run_flax_ner
import run_flax_speech_recognition_seq2seq
import run_mlm_flax
import run_qa
import run_summarization_flax
@@ -252,3 +254,32 @@ class ExamplesTests(TestCasePlus):
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_f1"], 30)
self.assertGreaterEqual(result["eval_exact"], 30)
@slow
def test_run_flax_speech_recognition_seq2seq(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_flax_speech_recognition_seq2seq.py
--model_name_or_path openai/whisper-tiny.en
--dataset_name hf-internal-testing/librispeech_asr_dummy
--dataset_config clean
--train_split_name validation
--eval_split_name validation
--output_dir {tmp_dir}
--overwrite_output_dir
--num_train_epochs=2
--max_train_samples 10
--max_eval_samples 10
--warmup_steps=8
--do_train
--do_eval
--learning_rate=2e-4
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--predict_with_generate
""".split()
with patch.object(sys, "argv", testargs):
run_flax_speech_recognition_seq2seq.main()
result = get_results(tmp_dir, split="eval")
self.assertLessEqual(result["eval_wer"], 0.05)