Added max_sample_ arguments (#10551)
* reverted changes of logging and saving metrics * added max_sample arguments * fixed code * white space diff * reformetting code * reformatted code
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -64,6 +65,17 @@ def get_setup_file():
|
||||
return args.f
|
||||
|
||||
|
||||
def get_results(output_dir):
|
||||
results = {}
|
||||
path = os.path.join(output_dir, "all_results.json")
|
||||
if os.path.exists(path):
|
||||
with open(path, "r") as f:
|
||||
results = json.load(f)
|
||||
else:
|
||||
raise ValueError(f"can't find {path}")
|
||||
return results
|
||||
|
||||
|
||||
def is_cuda_and_apex_available():
|
||||
is_using_cuda = torch.cuda.is_available() and torch_device == "cuda"
|
||||
return is_using_cuda and is_apex_available()
|
||||
@@ -98,7 +110,8 @@ class ExamplesTests(TestCasePlus):
|
||||
testargs.append("--fp16")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_glue.main()
|
||||
run_glue.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
@@ -130,7 +143,8 @@ class ExamplesTests(TestCasePlus):
|
||||
testargs.append("--no_cuda")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_clm.main()
|
||||
run_clm.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertLess(result["perplexity"], 100)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
@@ -156,7 +170,8 @@ class ExamplesTests(TestCasePlus):
|
||||
testargs.append("--no_cuda")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_mlm.main()
|
||||
run_mlm.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertLess(result["perplexity"], 42)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
@@ -185,7 +200,8 @@ class ExamplesTests(TestCasePlus):
|
||||
testargs.append("--no_cuda")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_ner.main()
|
||||
run_ner.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
||||
self.assertGreaterEqual(result["eval_precision"], 0.75)
|
||||
self.assertLess(result["eval_loss"], 0.5)
|
||||
@@ -214,7 +230,8 @@ class ExamplesTests(TestCasePlus):
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_squad.main()
|
||||
run_squad.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["f1"], 30)
|
||||
self.assertGreaterEqual(result["exact"], 30)
|
||||
|
||||
@@ -241,7 +258,8 @@ class ExamplesTests(TestCasePlus):
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_swag.main()
|
||||
run_swag.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.8)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
@@ -288,8 +306,8 @@ class ExamplesTests(TestCasePlus):
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_seq2seq.main()
|
||||
|
||||
run_seq2seq.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_rouge1"], 10)
|
||||
self.assertGreaterEqual(result["eval_rouge2"], 2)
|
||||
self.assertGreaterEqual(result["eval_rougeL"], 7)
|
||||
@@ -323,5 +341,6 @@ class ExamplesTests(TestCasePlus):
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
result = run_seq2seq.main()
|
||||
run_seq2seq.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_bleu"], 30)
|
||||
|
||||
Reference in New Issue
Block a user