replace no_cuda with use_cpu in test_pytorch_examples (#24944)
* replace no_cuda with use_cpu in test_pytorch_examples * remove codes that never be used * fix style
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -76,13 +75,6 @@ logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def get_setup_file():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-f")
|
||||
args = parser.parse_args()
|
||||
return args.f
|
||||
|
||||
|
||||
def get_results(output_dir):
|
||||
results = {}
|
||||
path = os.path.join(output_dir, "all_results.json")
|
||||
@@ -153,8 +145,8 @@ class ExamplesTests(TestCasePlus):
|
||||
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
|
||||
return
|
||||
|
||||
if torch_device != "cuda":
|
||||
testargs.append("--no_cuda")
|
||||
if torch_device == "cpu":
|
||||
testargs.append("--use_cpu")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_clm.main()
|
||||
@@ -175,8 +167,8 @@ class ExamplesTests(TestCasePlus):
|
||||
--config_overrides n_embd=10,n_head=2
|
||||
""".split()
|
||||
|
||||
if torch_device != "cuda":
|
||||
testargs.append("--no_cuda")
|
||||
if torch_device == "cpu":
|
||||
testargs.append("--use_cpu")
|
||||
|
||||
logger = run_clm.logger
|
||||
with patch.object(sys, "argv", testargs):
|
||||
@@ -201,8 +193,8 @@ class ExamplesTests(TestCasePlus):
|
||||
--num_train_epochs=1
|
||||
""".split()
|
||||
|
||||
if torch_device != "cuda":
|
||||
testargs.append("--no_cuda")
|
||||
if torch_device == "cpu":
|
||||
testargs.append("--use_cpu")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_mlm.main()
|
||||
@@ -231,8 +223,8 @@ class ExamplesTests(TestCasePlus):
|
||||
--seed 7
|
||||
""".split()
|
||||
|
||||
if torch_device != "cuda":
|
||||
testargs.append("--no_cuda")
|
||||
if torch_device == "cpu":
|
||||
testargs.append("--use_cpu")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_ner.main()
|
||||
|
||||
Reference in New Issue
Block a user