Re-enable squad test (#23912)
* Re-enable squad test * [all-test] * [all-test] Fix all test command * Fix the all-test
This commit is contained in:
@@ -19,7 +19,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -240,7 +239,6 @@ class ExamplesTests(TestCasePlus):
|
|||||||
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
||||||
self.assertLess(result["eval_loss"], 0.5)
|
self.assertLess(result["eval_loss"], 0.5)
|
||||||
|
|
||||||
@unittest.skip("Broken, fix me Sourab")
|
|
||||||
def test_run_squad(self):
|
def test_run_squad(self):
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
|
|||||||
@@ -852,10 +852,10 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
if commit_flags["test_all"]:
|
if commit_flags["test_all"]:
|
||||||
with open(args.output_file, "w", encoding="utf-8") as f:
|
with open(args.output_file, "w", encoding="utf-8") as f:
|
||||||
if args.filters is None:
|
f.write("tests")
|
||||||
f.write("./tests/")
|
example_file = Path(args.output_file).parent / "examples_test_list.txt"
|
||||||
else:
|
with open(example_file, "w", encoding="utf-8") as f:
|
||||||
f.write(" ".join(args.filters))
|
f.write("all")
|
||||||
|
|
||||||
test_files_to_run = get_all_tests()
|
test_files_to_run = get_all_tests()
|
||||||
create_json_map(test_files_to_run, args.json_output_file)
|
create_json_map(test_files_to_run, args.json_output_file)
|
||||||
|
|||||||
Reference in New Issue
Block a user