✨ Add PyTorch image classification example (#13134)
* ✨ add pytorch image classification example * 🔥 remove utils.py * 💄 fix flake8 style issues * 🔥 remove unnecessary line * ✨ limit dataset sizes * 📌 update reqs * 🎨 restructure - use datasets lib * 🎨 import transforms directly * 📝 add comments * 💄 style * 🔥 remove flag * 📌 update requirement warning * 📝 add vision README.md * 📝 update README.md * 📝 update README.md * 🎨 add image-classification tag to model card * 🚚 rename vision ➡️ image-classification * 📝 update image-classification README.md
This commit is contained in:
@@ -38,6 +38,7 @@ SRC_DIRS = [
|
||||
"question-answering",
|
||||
"summarization",
|
||||
"translation",
|
||||
"image-classification",
|
||||
]
|
||||
]
|
||||
sys.path.extend(SRC_DIRS)
|
||||
@@ -47,6 +48,7 @@ if SRC_DIRS is not None:
|
||||
import run_clm
|
||||
import run_generation
|
||||
import run_glue
|
||||
import run_image_classification
|
||||
import run_mlm
|
||||
import run_ner
|
||||
import run_qa as run_squad
|
||||
@@ -340,3 +342,35 @@ class ExamplesTests(TestCasePlus):
|
||||
run_translation.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_bleu"], 30)
|
||||
|
||||
def test_run_image_classification(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_image_classification.py
|
||||
--output_dir {tmp_dir}
|
||||
--model_name_or_path google/vit-base-patch16-224-in21k
|
||||
--train_dir tests/fixtures/tests_samples/cats_and_dogs/
|
||||
--do_train
|
||||
--do_eval
|
||||
--learning_rate 2e-5
|
||||
--per_device_train_batch_size 2
|
||||
--per_device_eval_batch_size 1
|
||||
--remove_unused_columns False
|
||||
--overwrite_output_dir True
|
||||
--dataloader_num_workers 16
|
||||
--metric_for_best_model accuracy
|
||||
--max_steps 30
|
||||
--train_val_split 0.1
|
||||
--seed 7
|
||||
""".split()
|
||||
|
||||
if is_cuda_and_apex_available():
|
||||
testargs.append("--fp16")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_image_classification.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.8)
|
||||
|
||||
Reference in New Issue
Block a user