[ViTMAE] Add image pretraining script (#15242)

* Add script

* Improve script

* Fix data collator

* Update README

* Add label_names argument

* Apply suggestions from code review

* Add config parameters

* Update script

* Fix bug

* Improve README

* Improve README and add test

* Fix import

* Add image_column_name
This commit is contained in:
NielsRogge
2022-01-21 12:11:08 +01:00
committed by GitHub
parent d43e308e7f
commit 6c7b68d414
5 changed files with 539 additions and 3 deletions

View File

@@ -23,7 +23,7 @@ from unittest.mock import patch
import torch
from transformers import Wav2Vec2ForPreTraining
from transformers import ViTMAEForPreTraining, Wav2Vec2ForPreTraining
from transformers.file_utils import is_apex_available
from transformers.testing_utils import CaptureLogger, TestCasePlus, get_gpu_count, slow, torch_device
@@ -43,6 +43,7 @@ SRC_DIRS = [
"speech-recognition",
"audio-classification",
"speech-pretraining",
"image-pretraining",
]
]
sys.path.extend(SRC_DIRS)
@@ -54,6 +55,7 @@ if SRC_DIRS is not None:
import run_generation
import run_glue
import run_image_classification
import run_mae
import run_mlm
import run_ner
import run_qa as run_squad
@@ -570,3 +572,34 @@ class ExamplesTests(TestCasePlus):
run_wav2vec2_pretraining_no_trainer.main()
model = Wav2Vec2ForPreTraining.from_pretrained(tmp_dir)
self.assertIsNotNone(model)
def test_run_vit_mae_pretraining(self):
    """Smoke-test the ViTMAE image pretraining example script.

    Invokes ``run_mae.main()`` with a short configuration (``--max_steps 10``)
    on the ``hf-internal-testing/cats_vs_dogs_sample`` dataset, then checks
    that ``ViTMAEForPreTraining.from_pretrained`` returns a model when pointed
    at the script's output directory.
    """
    # Mirror log records to stdout while this test runs. The handler is
    # detached again on cleanup so it does not pile up on the shared logger
    # (each leftover handler would duplicate every subsequent log line).
    stream_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stream_handler)
    self.addCleanup(logger.removeHandler, stream_handler)

    tmp_dir = self.get_auto_remove_tmp_dir()

    # Whitespace-separated command line; ``split()`` turns it into the argv
    # list, so the layout of the literal below is purely cosmetic.
    testargs = f"""
        run_mae.py
        --output_dir {tmp_dir}
        --dataset_name hf-internal-testing/cats_vs_dogs_sample
        --do_train
        --do_eval
        --learning_rate 1e-4
        --per_device_train_batch_size 2
        --per_device_eval_batch_size 1
        --remove_unused_columns False
        --overwrite_output_dir True
        --dataloader_num_workers 16
        --metric_for_best_model accuracy
        --max_steps 10
        --train_val_split 0.1
        --seed 42
        """.split()

    # Request fp16 only when the helper (defined elsewhere in this file)
    # says so.
    if is_cuda_and_apex_available():
        testargs.append("--fp16")

    # Temporarily replace sys.argv so the script parses the arguments above.
    with patch.object(sys, "argv", testargs):
        run_mae.main()
        model = ViTMAEForPreTraining.from_pretrained(tmp_dir)
        self.assertIsNotNone(model)