Add image classification script, no trainer (#16727)
* Add first draft * Improve README and run fixup * Make script aligned with other scripts, improve README * Improve script and add test * Remove print statement * Apply suggestions from code review * Add num_labels to make test pass * Improve README
This commit is contained in:
@@ -52,6 +52,7 @@ sys.path.extend(SRC_DIRS)
|
||||
if SRC_DIRS is not None:
|
||||
import run_clm_no_trainer
|
||||
import run_glue_no_trainer
|
||||
import run_image_classification_no_trainer
|
||||
import run_mlm_no_trainer
|
||||
import run_ner_no_trainer
|
||||
import run_qa_no_trainer as run_squad_no_trainer
|
||||
@@ -321,3 +322,25 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
||||
run_semantic_segmentation_no_trainer.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)
|
||||
|
||||
def test_run_image_classification_no_trainer(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_image_classification_no_trainer.py
|
||||
--dataset_name huggingface/image-classification-test-sample
|
||||
--output_dir {tmp_dir}
|
||||
--num_warmup_steps=8
|
||||
--learning_rate=3e-3
|
||||
--per_device_train_batch_size=2
|
||||
--per_device_eval_batch_size=1
|
||||
--checkpointing_steps epoch
|
||||
--with_tracking
|
||||
--seed 42
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_image_classification_no_trainer.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.50)
|
||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_classification_no_trainer")))
|
||||
|
||||
Reference in New Issue
Block a user