[cleanup] examples test_run_squad uses tiny model (#5059)
This commit is contained in:
@@ -55,7 +55,7 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
|
|
||||||
testargs = """
|
testargs = """
|
||||||
run_glue.py
|
run_glue.py
|
||||||
--model_name_or_path bert-base-uncased
|
--model_name_or_path distilbert-base-uncased
|
||||||
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
--data_dir ./tests/fixtures/tests_samples/MRPC/
|
||||||
--task_name mrpc
|
--task_name mrpc
|
||||||
--do_train
|
--do_train
|
||||||
@@ -79,6 +79,7 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
def test_run_language_modeling(self):
|
def test_run_language_modeling(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
# TODO: switch to smaller model like sshleifer/tiny-distilroberta-base
|
||||||
|
|
||||||
testargs = """
|
testargs = """
|
||||||
run_language_modeling.py
|
run_language_modeling.py
|
||||||
@@ -105,10 +106,9 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
|
|
||||||
testargs = """
|
testargs = """
|
||||||
run_squad.py
|
run_squad.py
|
||||||
--model_type=bert
|
--model_type=distilbert
|
||||||
--model_name_or_path=bert-base-uncased
|
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
|
||||||
--data_dir=./tests/fixtures/tests_samples/SQUAD
|
--data_dir=./tests/fixtures/tests_samples/SQUAD
|
||||||
--model_name=bert-base-uncased
|
|
||||||
--output_dir=./tests/fixtures/tests_samples/temp_dir
|
--output_dir=./tests/fixtures/tests_samples/temp_dir
|
||||||
--max_steps=10
|
--max_steps=10
|
||||||
--warmup_steps=2
|
--warmup_steps=2
|
||||||
@@ -123,15 +123,15 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
""".split()
|
""".split()
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
result = run_squad.main()
|
result = run_squad.main()
|
||||||
self.assertGreaterEqual(result["f1"], 30)
|
self.assertGreaterEqual(result["f1"], 25)
|
||||||
self.assertGreaterEqual(result["exact"], 30)
|
self.assertGreaterEqual(result["exact"], 21)
|
||||||
|
|
||||||
def test_generation(self):
|
def test_generation(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
|
testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
|
||||||
model_type, model_name = ("--model_type=openai-gpt", "--model_name_or_path=openai-gpt")
|
model_type, model_name = ("--model_type=gpt2", "--model_name_or_path=sshleifer/tiny-gpt2")
|
||||||
with patch.object(sys, "argv", testargs + [model_type, model_name]):
|
with patch.object(sys, "argv", testargs + [model_type, model_name]):
|
||||||
result = run_generation.main()
|
result = run_generation.main()
|
||||||
self.assertGreaterEqual(len(result[0]), 10)
|
self.assertGreaterEqual(len(result[0]), 10)
|
||||||
|
|||||||
Reference in New Issue
Block a user