[ViTMAE] Add image pretraining script (#15242)
* Add script * Improve script * Fix data collator * Update README * Add label_names argument * Apply suggestions from code review * Add config parameters * Update script * Fix bug * Improve README * Improve README and add test * Fix import * Add image_column_name
This commit is contained in:
@@ -23,7 +23,7 @@ from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import Wav2Vec2ForPreTraining
|
||||
from transformers import ViTMAEForPreTraining, Wav2Vec2ForPreTraining
|
||||
from transformers.file_utils import is_apex_available
|
||||
from transformers.testing_utils import CaptureLogger, TestCasePlus, get_gpu_count, slow, torch_device
|
||||
|
||||
@@ -43,6 +43,7 @@ SRC_DIRS = [
|
||||
"speech-recognition",
|
||||
"audio-classification",
|
||||
"speech-pretraining",
|
||||
"image-pretraining",
|
||||
]
|
||||
]
|
||||
sys.path.extend(SRC_DIRS)
|
||||
@@ -54,6 +55,7 @@ if SRC_DIRS is not None:
|
||||
import run_generation
|
||||
import run_glue
|
||||
import run_image_classification
|
||||
import run_mae
|
||||
import run_mlm
|
||||
import run_ner
|
||||
import run_qa as run_squad
|
||||
@@ -570,3 +572,34 @@ class ExamplesTests(TestCasePlus):
|
||||
run_wav2vec2_pretraining_no_trainer.main()
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained(tmp_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_run_vit_mae_pretraining(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_mae.py
|
||||
--output_dir {tmp_dir}
|
||||
--dataset_name hf-internal-testing/cats_vs_dogs_sample
|
||||
--do_train
|
||||
--do_eval
|
||||
--learning_rate 1e-4
|
||||
--per_device_train_batch_size 2
|
||||
--per_device_eval_batch_size 1
|
||||
--remove_unused_columns False
|
||||
--overwrite_output_dir True
|
||||
--dataloader_num_workers 16
|
||||
--metric_for_best_model accuracy
|
||||
--max_steps 10
|
||||
--train_val_split 0.1
|
||||
--seed 42
|
||||
""".split()
|
||||
|
||||
if is_cuda_and_apex_available():
|
||||
testargs.append("--fp16")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_mae.main()
|
||||
model = ViTMAEForPreTraining.from_pretrained(tmp_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
Reference in New Issue
Block a user