From c3e607496c28b6e4c41a0aeb2a4c465b4c07f66a Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 16 Jun 2020 14:06:45 -0400 Subject: [PATCH] [cleanup] examples test_run_squad uses tiny model (#5059) --- examples/test_examples.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/test_examples.py b/examples/test_examples.py index af1897738f..54431316e7 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -55,7 +55,7 @@ class ExamplesTests(unittest.TestCase): testargs = """ run_glue.py - --model_name_or_path bert-base-uncased + --model_name_or_path distilbert-base-uncased --data_dir ./tests/fixtures/tests_samples/MRPC/ --task_name mrpc --do_train @@ -79,6 +79,7 @@ class ExamplesTests(unittest.TestCase): def test_run_language_modeling(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) + # TODO: switch to smaller model like sshleifer/tiny-distilroberta-base testargs = """ run_language_modeling.py @@ -105,10 +106,9 @@ class ExamplesTests(unittest.TestCase): testargs = """ run_squad.py - --model_type=bert - --model_name_or_path=bert-base-uncased + --model_type=distilbert + --model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad --data_dir=./tests/fixtures/tests_samples/SQUAD - --model_name=bert-base-uncased --output_dir=./tests/fixtures/tests_samples/temp_dir --max_steps=10 --warmup_steps=2 @@ -123,15 +123,15 @@ class ExamplesTests(unittest.TestCase): """.split() with patch.object(sys, "argv", testargs): result = run_squad.main() - self.assertGreaterEqual(result["f1"], 30) - self.assertGreaterEqual(result["exact"], 30) + self.assertGreaterEqual(result["f1"], 25) + self.assertGreaterEqual(result["exact"], 21) def test_generation(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"] - model_type, model_name = ("--model_type=openai-gpt", "--model_name_or_path=openai-gpt") + model_type, model_name = ("--model_type=gpt2", "--model_name_or_path=sshleifer/tiny-gpt2") with patch.object(sys, "argv", testargs + [model_type, model_name]): result = run_generation.main() self.assertGreaterEqual(len(result[0]), 10)