Add semantic segmentation script, trainer (#16834)

* Add first draft

* Improve script and README

* Improve README

* Apply suggestions from code review

* Improve script, add link to resulting model

* Add corresponding test

* Adjust learning rate
This commit is contained in:
NielsRogge
2022-04-27 10:12:18 +02:00
committed by GitHub
parent a4a88fa09f
commit 479fdc4925
5 changed files with 603 additions and 19 deletions

View File

@@ -19,7 +19,6 @@ import json
import logging
import os
import sys
import unittest
from unittest.mock import patch
import torch
@@ -45,6 +44,7 @@ SRC_DIRS = [
"audio-classification",
"speech-pretraining",
"image-pretraining",
"semantic-segmentation",
]
]
sys.path.extend(SRC_DIRS)
@@ -60,6 +60,7 @@ if SRC_DIRS is not None:
import run_mlm
import run_ner
import run_qa as run_squad
import run_semantic_segmentation
import run_seq2seq_qa as run_squad_seq2seq
import run_speech_recognition_ctc
import run_speech_recognition_seq2seq
@@ -386,7 +387,6 @@ class ExamplesTests(TestCasePlus):
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_bleu"], 30)
@unittest.skip("This is currently broken.")
def test_run_image_classification(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
@@ -534,7 +534,6 @@ class ExamplesTests(TestCasePlus):
model = Wav2Vec2ForPreTraining.from_pretrained(tmp_dir)
self.assertIsNotNone(model)
@unittest.skip("This is currently broken.")
def test_run_vit_mae_pretraining(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
@@ -562,3 +561,28 @@ class ExamplesTests(TestCasePlus):
run_mae.main()
model = ViTMAEForPreTraining.from_pretrained(tmp_dir)
self.assertIsNotNone(model)
def test_run_semantic_segmentation(self):
    """Smoke-test the ``run_semantic_segmentation`` example script.

    Calls the script's ``main()`` with ``sys.argv`` patched to request
    training (capped at 10 steps) followed by evaluation, then asserts that
    the ``eval_overall_accuracy`` entry of the collected results is >= 0.1.
    """
    tmp_dir = self.get_auto_remove_tmp_dir()

    # One flag (and value) per line for readability; ``split()`` collapses
    # all whitespace, so only the token sequence reaches the script.
    cli_args = f"""
        run_semantic_segmentation.py
        --output_dir {tmp_dir}
        --dataset_name huggingface/semantic-segmentation-test-sample
        --do_train
        --do_eval
        --remove_unused_columns False
        --overwrite_output_dir True
        --max_steps 10
        --learning_rate=2e-4
        --per_device_train_batch_size=2
        --per_device_eval_batch_size=1
        --seed 32
        """.split()

    # NOTE(review): helper is defined outside this block; presumably it
    # reports whether CUDA and apex are both usable -- confirm.
    if is_cuda_and_apex_available():
        cli_args += ["--fp16"]

    # The script reads its options from sys.argv, so swap it for the
    # duration of the run.
    with patch.object(sys, "argv", cli_args):
        run_semantic_segmentation.main()
        metrics = get_results(tmp_dir)
        self.assertGreaterEqual(metrics["eval_overall_accuracy"], 0.1)