replace no_cuda with use_cpu in test_pytorch_examples (#24944)
* replace no_cuda with use_cpu in test_pytorch_examples * remove codes that never be used * fix style
This commit is contained in:
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -76,13 +75,6 @@ logging.basicConfig(level=logging.DEBUG)
|
|||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
def get_setup_file():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("-f")
|
|
||||||
args = parser.parse_args()
|
|
||||||
return args.f
|
|
||||||
|
|
||||||
|
|
||||||
def get_results(output_dir):
|
def get_results(output_dir):
|
||||||
results = {}
|
results = {}
|
||||||
path = os.path.join(output_dir, "all_results.json")
|
path = os.path.join(output_dir, "all_results.json")
|
||||||
@@ -153,8 +145,8 @@ class ExamplesTests(TestCasePlus):
|
|||||||
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
|
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
|
||||||
return
|
return
|
||||||
|
|
||||||
if torch_device != "cuda":
|
if torch_device == "cpu":
|
||||||
testargs.append("--no_cuda")
|
testargs.append("--use_cpu")
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
run_clm.main()
|
run_clm.main()
|
||||||
@@ -175,8 +167,8 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--config_overrides n_embd=10,n_head=2
|
--config_overrides n_embd=10,n_head=2
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if torch_device != "cuda":
|
if torch_device == "cpu":
|
||||||
testargs.append("--no_cuda")
|
testargs.append("--use_cpu")
|
||||||
|
|
||||||
logger = run_clm.logger
|
logger = run_clm.logger
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
@@ -201,8 +193,8 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--num_train_epochs=1
|
--num_train_epochs=1
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if torch_device != "cuda":
|
if torch_device == "cpu":
|
||||||
testargs.append("--no_cuda")
|
testargs.append("--use_cpu")
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
run_mlm.main()
|
run_mlm.main()
|
||||||
@@ -231,8 +223,8 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--seed 7
|
--seed 7
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if torch_device != "cuda":
|
if torch_device == "cpu":
|
||||||
testargs.append("--no_cuda")
|
testargs.append("--use_cpu")
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
run_ner.main()
|
run_ner.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user