[Speech Examples] Add pytorch speech pretraining (#13877)

* adapt wav2vec2

* add example

* add files

* adapt

* remove bogus file

* Apply suggestions from code review

* adapt files more

* upload changes

* del old files

* up

* up

* up

* up

* up

* correct gradient checkpointing

* add readme

* finish

* finish

* up

* more fixes

* up

* up

* add demo run to readme

* up
This commit is contained in:
Patrick von Platen
2021-10-12 00:46:32 +02:00
committed by GitHub
parent 3499728dc4
commit d45fc7da3d
9 changed files with 1196 additions and 183 deletions

View File

@@ -23,6 +23,7 @@ from unittest.mock import patch
import torch
from transformers import Wav2Vec2ForPreTraining
from transformers.file_utils import is_apex_available
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
@@ -41,6 +42,7 @@ SRC_DIRS = [
"image-classification",
"speech-recognition",
"audio-classification",
"speech-pretraining",
]
]
sys.path.extend(SRC_DIRS)
@@ -59,6 +61,7 @@ if SRC_DIRS is not None:
import run_summarization
import run_swag
import run_translation
import run_wav2vec2_pretraining_no_trainer
logging.basicConfig(level=logging.DEBUG)
@@ -447,3 +450,32 @@ class ExamplesTests(TestCasePlus):
run_audio_classification.main()
result = get_results(tmp_dir)
self.assertLess(result["eval_loss"], result["train_loss"])
def test_run_wav2vec2_pretraining(self):
    """Smoke-test the no-trainer Wav2Vec2 pretraining example script.

    Invokes ``run_wav2vec2_pretraining_no_trainer.main()`` with a patched
    ``sys.argv`` (``--max_train_steps 5``, output written to a temporary
    directory) and then checks that the output directory can be loaded
    back with ``Wav2Vec2ForPreTraining.from_pretrained``.
    """
    # Echo log records to stdout while the test runs. The cleanup detaches the
    # handler again so it does not stay registered on ``logger`` once this
    # test has finished (pass or fail).
    stream_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stream_handler)
    self.addCleanup(logger.removeHandler, stream_handler)

    tmp_dir = self.get_auto_remove_tmp_dir()
    # The command line is split on whitespace, so the layout inside the
    # literal is irrelevant to the resulting argument list.
    testargs = f"""
        run_wav2vec2_pretraining_no_trainer.py
        --output_dir {tmp_dir}
        --model_name_or_path hf-internal-testing/tiny-random-wav2vec2
        --dataset_name patrickvonplaten/librispeech_asr_dummy
        --dataset_config_names clean
        --dataset_split_names validation
        --learning_rate 1e-4
        --per_device_train_batch_size 2
        --per_device_eval_batch_size 2
        --preprocessing_num_workers 16
        --max_train_steps 5
        --validation_split_percentage 5
        --seed 42
        """.split()

    # Mixed precision is only requested when the helper reports that both
    # CUDA and apex are available.
    if is_cuda_and_apex_available():
        testargs.append("--fp16")

    with patch.object(sys, "argv", testargs):
        run_wav2vec2_pretraining_no_trainer.main()
        # Loading raises if the script did not save a usable checkpoint.
        model = Wav2Vec2ForPreTraining.from_pretrained(tmp_dir)
        self.assertIsNotNone(model)