Add semantic script no trainer, v2 (#16788)
* Add first draft from previous PR * First draft * Improve README and remove num_labels * Make script more aligned with other scripts * Improve README and apply suggestion from code review
This commit is contained in:
@@ -43,6 +43,7 @@ SRC_DIRS = [
|
||||
"audio-classification",
|
||||
"speech-pretraining",
|
||||
"image-pretraining",
|
||||
"semantic-segmentation",
|
||||
]
|
||||
]
|
||||
sys.path.extend(SRC_DIRS)
|
||||
@@ -54,6 +55,7 @@ if SRC_DIRS is not None:
|
||||
import run_mlm_no_trainer
|
||||
import run_ner_no_trainer
|
||||
import run_qa_no_trainer as run_squad_no_trainer
|
||||
import run_semantic_segmentation_no_trainer
|
||||
import run_summarization_no_trainer
|
||||
import run_swag_no_trainer
|
||||
import run_translation_no_trainer
|
||||
@@ -296,3 +298,26 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
||||
self.assertGreaterEqual(result["eval_bleu"], 30)
|
||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "translation_no_trainer")))
|
||||
|
||||
@slow
|
||||
def test_run_semantic_segmentation_no_trainer(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_semantic_segmentation_no_trainer.py
|
||||
--dataset_name huggingface/semantic-segmentation-test-sample
|
||||
--output_dir {tmp_dir}
|
||||
--max_train_steps=10
|
||||
--num_warmup_steps=2
|
||||
--learning_rate=2e-4
|
||||
--per_device_train_batch_size=2
|
||||
--per_device_eval_batch_size=1
|
||||
--checkpointing_steps epoch
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_semantic_segmentation_no_trainer.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)
|
||||
|
||||
Reference in New Issue
Block a user