Add new run_swag example (#9175)
* Add new run_swag example * Add check * Add sample * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Very important change to make Lysandre happy Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -33,6 +33,7 @@ SRC_DIRS = [
|
||||
"text-classification",
|
||||
"token-classification",
|
||||
"language-modeling",
|
||||
"multiple-choice",
|
||||
"question-answering",
|
||||
]
|
||||
]
|
||||
@@ -46,6 +47,7 @@ if SRC_DIRS is not None:
|
||||
import run_mlm
|
||||
import run_ner
|
||||
import run_qa as run_squad
|
||||
import run_swag
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@@ -216,6 +218,32 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertGreaterEqual(result["f1"], 30)
|
||||
self.assertGreaterEqual(result["exact"], 30)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_run_swag(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_swag.py
|
||||
--model_name_or_path bert-base-uncased
|
||||
--train_file tests/fixtures/tests_samples/swag/sample.json
|
||||
--validation_file tests/fixtures/tests_samples/swag/sample.json
|
||||
--output_dir {tmp_dir}
|
||||
--overwrite_output_dir
|
||||
--max_steps=20
|
||||
--warmup_steps=2
|
||||
--do_train
|
||||
--do_eval
|
||||
--learning_rate=2e-4
|
||||
--per_device_train_batch_size=2
|
||||
--per_device_eval_batch_size=1
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_swag.main()
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.8)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_generation(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
|
||||
Reference in New Issue
Block a user