diff --git a/examples/pytorch/test_pytorch_examples.py b/examples/pytorch/test_pytorch_examples.py index a94e23e614..269d7844f7 100644 --- a/examples/pytorch/test_pytorch_examples.py +++ b/examples/pytorch/test_pytorch_examples.py @@ -14,7 +14,6 @@ # limitations under the License. -import argparse import json import logging import os @@ -76,13 +75,6 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger() -def get_setup_file(): - parser = argparse.ArgumentParser() - parser.add_argument("-f") - args = parser.parse_args() - return args.f - - def get_results(output_dir): results = {} path = os.path.join(output_dir, "all_results.json") @@ -153,8 +145,8 @@ class ExamplesTests(TestCasePlus): # Skipping because there are not enough batches to train the model + would need a drop_last to work. return - if torch_device != "cuda": - testargs.append("--no_cuda") + if torch_device == "cpu": + testargs.append("--use_cpu") with patch.object(sys, "argv", testargs): run_clm.main() @@ -175,8 +167,8 @@ class ExamplesTests(TestCasePlus): --config_overrides n_embd=10,n_head=2 """.split() - if torch_device != "cuda": - testargs.append("--no_cuda") + if torch_device == "cpu": + testargs.append("--use_cpu") logger = run_clm.logger with patch.object(sys, "argv", testargs): @@ -201,8 +193,8 @@ class ExamplesTests(TestCasePlus): --num_train_epochs=1 """.split() - if torch_device != "cuda": - testargs.append("--no_cuda") + if torch_device == "cpu": + testargs.append("--use_cpu") with patch.object(sys, "argv", testargs): run_mlm.main() @@ -231,8 +223,8 @@ class ExamplesTests(TestCasePlus): --seed 7 """.split() - if torch_device != "cuda": - testargs.append("--no_cuda") + if torch_device == "cpu": + testargs.append("--use_cpu") with patch.object(sys, "argv", testargs): run_ner.main()