From 3ff443a6d96ec29f7e7b395db90ba4de558ac3f3 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 31 May 2023 13:44:26 -0400 Subject: [PATCH] Re-enable squad test (#23912) * Re-enable squad test * [all-test] * [all-test] Fix all test command * Fix the all-test --- examples/pytorch/test_pytorch_examples.py | 2 -- utils/tests_fetcher.py | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/test_pytorch_examples.py b/examples/pytorch/test_pytorch_examples.py index 6b992612e8..f4682b8933 100644 --- a/examples/pytorch/test_pytorch_examples.py +++ b/examples/pytorch/test_pytorch_examples.py @@ -19,7 +19,6 @@ import json import logging import os import sys -import unittest from unittest.mock import patch import torch @@ -240,7 +239,6 @@ class ExamplesTests(TestCasePlus): self.assertGreaterEqual(result["eval_accuracy"], 0.75) self.assertLess(result["eval_loss"], 0.5) - @unittest.skip("Broken, fix me Sourab") def test_run_squad(self): tmp_dir = self.get_auto_remove_tmp_dir() testargs = f""" diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index d8373da5ef..0480e44c3e 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -852,10 +852,10 @@ if __name__ == "__main__": if commit_flags["test_all"]: with open(args.output_file, "w", encoding="utf-8") as f: - if args.filters is None: - f.write("./tests/") - else: - f.write(" ".join(args.filters)) + f.write("tests") + example_file = Path(args.output_file).parent / "examples_test_list.txt" + with open(example_file, "w", encoding="utf-8") as f: + f.write("all") test_files_to_run = get_all_tests() create_json_map(test_files_to_run, args.json_output_file)