[Speech Examples] Add pytorch speech pretraining (#13877)
* adapt wav2vec2 * add example * add files * adapt * remove bogus file * Apply suggestions from code review * adapt files more * upload changes * del old files * up * up * up * up * up * correct gradient checkpointing * add readme * finish * finish * up * more fixes * up * up * add demo run to readme * up
This commit is contained in:
committed by
GitHub
parent
3499728dc4
commit
d45fc7da3d
@@ -23,6 +23,7 @@ from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import Wav2Vec2ForPreTraining
|
||||
from transformers.file_utils import is_apex_available
|
||||
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
|
||||
|
||||
@@ -41,6 +42,7 @@ SRC_DIRS = [
|
||||
"image-classification",
|
||||
"speech-recognition",
|
||||
"audio-classification",
|
||||
"speech-pretraining",
|
||||
]
|
||||
]
|
||||
sys.path.extend(SRC_DIRS)
|
||||
@@ -59,6 +61,7 @@ if SRC_DIRS is not None:
|
||||
import run_summarization
|
||||
import run_swag
|
||||
import run_translation
|
||||
import run_wav2vec2_pretraining_no_trainer
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@@ -447,3 +450,32 @@ class ExamplesTests(TestCasePlus):
|
||||
run_audio_classification.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||
|
||||
def test_run_wav2vec2_pretraining(self):
    """Run the wav2vec2 pretraining example script and reload its output.

    Calls ``run_wav2vec2_pretraining_no_trainer.main()`` with ``sys.argv``
    patched to the argument list built below, then checks that
    ``Wav2Vec2ForPreTraining.from_pretrained`` returns a model when pointed
    at the script's output directory.
    """
    # Echo log records to stdout as well.
    logger.addHandler(logging.StreamHandler(sys.stdout))

    output_dir = self.get_auto_remove_tmp_dir()

    # Whitespace-split the command line into an argv-style token list.
    cli_args = f"""
        run_wav2vec2_pretraining_no_trainer.py
        --output_dir {output_dir}
        --model_name_or_path hf-internal-testing/tiny-random-wav2vec2
        --dataset_name patrickvonplaten/librispeech_asr_dummy
        --dataset_config_names clean
        --dataset_split_names validation
        --learning_rate 1e-4
        --per_device_train_batch_size 2
        --per_device_eval_batch_size 2
        --preprocessing_num_workers 16
        --max_train_steps 5
        --validation_split_percentage 5
        --seed 42
        """.split()

    # Request --fp16 only when is_cuda_and_apex_available() is truthy.
    if is_cuda_and_apex_available():
        cli_args += ["--fp16"]

    with patch.object(sys, "argv", cli_args):
        run_wav2vec2_pretraining_no_trainer.main()
        pretrained = Wav2Vec2ForPreTraining.from_pretrained(output_dir)
        self.assertIsNotNone(pretrained)
|
||||
|
||||
Reference in New Issue
Block a user