Add semantic script, trainer (#16834)
* Add first draft * Improve script and README * Improve README * Apply suggestions from code review * Improve script, add link to resulting model * Add corresponding test * Adjust learning rate
This commit is contained in:
@@ -19,7 +19,6 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -45,6 +44,7 @@ SRC_DIRS = [
|
||||
"audio-classification",
|
||||
"speech-pretraining",
|
||||
"image-pretraining",
|
||||
"semantic-segmentation",
|
||||
]
|
||||
]
|
||||
sys.path.extend(SRC_DIRS)
|
||||
@@ -60,6 +60,7 @@ if SRC_DIRS is not None:
|
||||
import run_mlm
|
||||
import run_ner
|
||||
import run_qa as run_squad
|
||||
import run_semantic_segmentation
|
||||
import run_seq2seq_qa as run_squad_seq2seq
|
||||
import run_speech_recognition_ctc
|
||||
import run_speech_recognition_seq2seq
|
||||
@@ -386,7 +387,6 @@ class ExamplesTests(TestCasePlus):
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_bleu"], 30)
|
||||
|
||||
@unittest.skip("This is currently broken.")
|
||||
def test_run_image_classification(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
@@ -534,7 +534,6 @@ class ExamplesTests(TestCasePlus):
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained(tmp_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip("This is currently broken.")
|
||||
def test_run_vit_mae_pretraining(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
@@ -562,3 +561,28 @@ class ExamplesTests(TestCasePlus):
|
||||
run_mae.main()
|
||||
model = ViTMAEForPreTraining.from_pretrained(tmp_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_run_semantic_segmentation(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_semantic_segmentation.py
|
||||
--output_dir {tmp_dir}
|
||||
--dataset_name huggingface/semantic-segmentation-test-sample
|
||||
--do_train
|
||||
--do_eval
|
||||
--remove_unused_columns False
|
||||
--overwrite_output_dir True
|
||||
--max_steps 10
|
||||
--learning_rate=2e-4
|
||||
--per_device_train_batch_size=2
|
||||
--per_device_eval_batch_size=1
|
||||
--seed 32
|
||||
""".split()
|
||||
|
||||
if is_cuda_and_apex_available():
|
||||
testargs.append("--fp16")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_semantic_segmentation.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_overall_accuracy"], 0.1)
|
||||
|
||||
Reference in New Issue
Block a user