Re-enable squad test (#23912)
* Re-enable squad test * [all-test] * [all-test] Fix all test command * Fix the all-test
This commit is contained in:
@@ -19,7 +19,6 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -240,7 +239,6 @@ class ExamplesTests(TestCasePlus):
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
|
||||
self.assertLess(result["eval_loss"], 0.5)
|
||||
|
||||
@unittest.skip("Broken, fix me Sourab")
|
||||
def test_run_squad(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
|
||||
@@ -852,10 +852,10 @@ if __name__ == "__main__":
|
||||
|
||||
if commit_flags["test_all"]:
|
||||
with open(args.output_file, "w", encoding="utf-8") as f:
|
||||
if args.filters is None:
|
||||
f.write("./tests/")
|
||||
else:
|
||||
f.write(" ".join(args.filters))
|
||||
f.write("tests")
|
||||
example_file = Path(args.output_file).parent / "examples_test_list.txt"
|
||||
with open(example_file, "w", encoding="utf-8") as f:
|
||||
f.write("all")
|
||||
|
||||
test_files_to_run = get_all_tests()
|
||||
create_json_map(test_files_to_run, args.json_output_file)
|
||||
|
||||
Reference in New Issue
Block a user