make tests of pytorch_example device agnostic (#27081)

This commit is contained in:
Hz, Ji
2023-10-30 22:56:41 +08:00
committed by GitHub
parent 6b466771b0
commit cd19b19378
2 changed files with 32 additions and 33 deletions

View File

@@ -24,11 +24,16 @@ import tempfile
import unittest
from unittest import mock
import torch
from accelerate.utils import write_basic_config
from transformers.testing_utils import TestCasePlus, get_gpu_count, run_command, slow, torch_device
from transformers.utils import is_apex_available
from transformers.testing_utils import (
TestCasePlus,
backend_device_count,
is_torch_fp16_available_on_device,
run_command,
slow,
torch_device,
)
logging.basicConfig(level=logging.DEBUG)
@@ -54,11 +59,6 @@ def get_results(output_dir):
return results
def is_cuda_and_apex_available():
is_using_cuda = torch.cuda.is_available() and torch_device == "cuda"
return is_using_cuda and is_apex_available()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
@@ -93,7 +93,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
--with_tracking
""".split()
if is_cuda_and_apex_available():
if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16")
run_command(self._launch_args + testargs)
@@ -119,7 +119,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
--with_tracking
""".split()
if torch.cuda.device_count() > 1:
if backend_device_count(torch_device) > 1:
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
return
@@ -152,7 +152,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
def test_run_ner_no_trainer(self):
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
epochs = 7 if get_gpu_count() > 1 else 2
epochs = 7 if backend_device_count(torch_device) > 1 else 2
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
@@ -326,7 +326,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
--checkpointing_steps 1
""".split()
if is_cuda_and_apex_available():
if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16")
run_command(self._launch_args + testargs)