Add image classification script, no trainer (#16727)

* Add first draft

* Improve README and run fixup

* Make script aligned with other scripts, improve README

* Improve script and add test

* Remove print statement

* Apply suggestions from code review

* Add num_labels to make test pass

* Improve README
This commit is contained in:
NielsRogge
2022-04-19 16:32:08 +02:00
committed by GitHub
parent db9f189121
commit b96e82c80a
4 changed files with 604 additions and 16 deletions

View File

@@ -52,6 +52,7 @@ sys.path.extend(SRC_DIRS)
if SRC_DIRS is not None:
import run_clm_no_trainer
import run_glue_no_trainer
import run_image_classification_no_trainer
import run_mlm_no_trainer
import run_ner_no_trainer
import run_qa_no_trainer as run_squad_no_trainer
@@ -321,3 +322,25 @@ class ExamplesTestsNoTrainer(TestCasePlus):
run_semantic_segmentation_no_trainer.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)
def test_run_image_classification_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_image_classification_no_trainer.py
--dataset_name huggingface/image-classification-test-sample
--output_dir {tmp_dir}
--num_warmup_steps=8
--learning_rate=3e-3
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--checkpointing_steps epoch
--with_tracking
--seed 42
""".split()
with patch.object(sys, "argv", testargs):
run_image_classification_no_trainer.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.50)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_classification_no_trainer")))