replace no_cuda with use_cpu in test_pytorch_examples (#24944)

* replace no_cuda with use_cpu in test_pytorch_examples

* remove codes that never be used

* fix style
This commit is contained in:
statelesshz
2023-07-20 19:09:04 +08:00
committed by GitHub
parent 79444f370f
commit 37d8611ac9

View File

@@ -14,7 +14,6 @@
# limitations under the License.
import argparse
import json
import logging
import os
@@ -76,13 +75,6 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
def get_setup_file():
parser = argparse.ArgumentParser()
parser.add_argument("-f")
args = parser.parse_args()
return args.f
def get_results(output_dir):
results = {}
path = os.path.join(output_dir, "all_results.json")
@@ -153,8 +145,8 @@ class ExamplesTests(TestCasePlus):
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
return
if torch_device != "cuda":
testargs.append("--no_cuda")
if torch_device == "cpu":
testargs.append("--use_cpu")
with patch.object(sys, "argv", testargs):
run_clm.main()
@@ -175,8 +167,8 @@ class ExamplesTests(TestCasePlus):
--config_overrides n_embd=10,n_head=2
""".split()
if torch_device != "cuda":
testargs.append("--no_cuda")
if torch_device == "cpu":
testargs.append("--use_cpu")
logger = run_clm.logger
with patch.object(sys, "argv", testargs):
@@ -201,8 +193,8 @@ class ExamplesTests(TestCasePlus):
--num_train_epochs=1
""".split()
if torch_device != "cuda":
testargs.append("--no_cuda")
if torch_device == "cpu":
testargs.append("--use_cpu")
with patch.object(sys, "argv", testargs):
run_mlm.main()
@@ -231,8 +223,8 @@ class ExamplesTests(TestCasePlus):
--seed 7
""".split()
if torch_device != "cuda":
testargs.append("--no_cuda")
if torch_device == "cpu":
testargs.append("--use_cpu")
with patch.object(sys, "argv", testargs):
run_ner.main()