make tests of pytorch_example device agnostic (#27081)
This commit is contained in:
@@ -24,11 +24,16 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import torch
|
|
||||||
from accelerate.utils import write_basic_config
|
from accelerate.utils import write_basic_config
|
||||||
|
|
||||||
from transformers.testing_utils import TestCasePlus, get_gpu_count, run_command, slow, torch_device
|
from transformers.testing_utils import (
|
||||||
from transformers.utils import is_apex_available
|
TestCasePlus,
|
||||||
|
backend_device_count,
|
||||||
|
is_torch_fp16_available_on_device,
|
||||||
|
run_command,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
@@ -54,11 +59,6 @@ def get_results(output_dir):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def is_cuda_and_apex_available():
|
|
||||||
is_using_cuda = torch.cuda.is_available() and torch_device == "cuda"
|
|
||||||
return is_using_cuda and is_apex_available()
|
|
||||||
|
|
||||||
|
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
@@ -93,7 +93,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
--with_tracking
|
--with_tracking
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if is_cuda_and_apex_available():
|
if is_torch_fp16_available_on_device(torch_device):
|
||||||
testargs.append("--fp16")
|
testargs.append("--fp16")
|
||||||
|
|
||||||
run_command(self._launch_args + testargs)
|
run_command(self._launch_args + testargs)
|
||||||
@@ -119,7 +119,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
--with_tracking
|
--with_tracking
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
if backend_device_count(torch_device) > 1:
|
||||||
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
|
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -152,7 +152,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
|
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
|
||||||
def test_run_ner_no_trainer(self):
|
def test_run_ner_no_trainer(self):
|
||||||
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
|
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
|
||||||
epochs = 7 if get_gpu_count() > 1 else 2
|
epochs = 7 if backend_device_count(torch_device) > 1 else 2
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
@@ -326,7 +326,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
|||||||
--checkpointing_steps 1
|
--checkpointing_steps 1
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if is_cuda_and_apex_available():
|
if is_torch_fp16_available_on_device(torch_device):
|
||||||
testargs.append("--fp16")
|
testargs.append("--fp16")
|
||||||
|
|
||||||
run_command(self._launch_args + testargs)
|
run_command(self._launch_args + testargs)
|
||||||
|
|||||||
@@ -20,11 +20,15 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from transformers import ViTMAEForPreTraining, Wav2Vec2ForPreTraining
|
from transformers import ViTMAEForPreTraining, Wav2Vec2ForPreTraining
|
||||||
from transformers.testing_utils import CaptureLogger, TestCasePlus, get_gpu_count, slow, torch_device
|
from transformers.testing_utils import (
|
||||||
from transformers.utils import is_apex_available
|
CaptureLogger,
|
||||||
|
TestCasePlus,
|
||||||
|
backend_device_count,
|
||||||
|
is_torch_fp16_available_on_device,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
SRC_DIRS = [
|
SRC_DIRS = [
|
||||||
@@ -86,11 +90,6 @@ def get_results(output_dir):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def is_cuda_and_apex_available():
|
|
||||||
is_using_cuda = torch.cuda.is_available() and torch_device == "cuda"
|
|
||||||
return is_using_cuda and is_apex_available()
|
|
||||||
|
|
||||||
|
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
@@ -116,7 +115,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--max_seq_length=128
|
--max_seq_length=128
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if is_cuda_and_apex_available():
|
if is_torch_fp16_available_on_device(torch_device):
|
||||||
testargs.append("--fp16")
|
testargs.append("--fp16")
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
@@ -141,7 +140,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--overwrite_output_dir
|
--overwrite_output_dir
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
if backend_device_count(torch_device) > 1:
|
||||||
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
|
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -203,7 +202,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
|
|
||||||
def test_run_ner(self):
|
def test_run_ner(self):
|
||||||
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
|
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
|
||||||
epochs = 7 if get_gpu_count() > 1 else 2
|
epochs = 7 if backend_device_count(torch_device) > 1 else 2
|
||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
@@ -312,7 +311,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
def test_generation(self):
|
def test_generation(self):
|
||||||
testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
|
testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
|
||||||
|
|
||||||
if is_cuda_and_apex_available():
|
if is_torch_fp16_available_on_device(torch_device):
|
||||||
testargs.append("--fp16")
|
testargs.append("--fp16")
|
||||||
|
|
||||||
model_type, model_name = (
|
model_type, model_name = (
|
||||||
@@ -401,7 +400,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--seed 42
|
--seed 42
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if is_cuda_and_apex_available():
|
if is_torch_fp16_available_on_device(torch_device):
|
||||||
testargs.append("--fp16")
|
testargs.append("--fp16")
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
@@ -431,7 +430,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--seed 42
|
--seed 42
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if is_cuda_and_apex_available():
|
if is_torch_fp16_available_on_device(torch_device):
|
||||||
testargs.append("--fp16")
|
testargs.append("--fp16")
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
@@ -462,7 +461,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--seed 42
|
--seed 42
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if is_cuda_and_apex_available():
|
if is_torch_fp16_available_on_device(torch_device):
|
||||||
testargs.append("--fp16")
|
testargs.append("--fp16")
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
@@ -493,7 +492,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--seed 42
|
--seed 42
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if is_cuda_and_apex_available():
|
if is_torch_fp16_available_on_device(torch_device):
|
||||||
testargs.append("--fp16")
|
testargs.append("--fp16")
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
@@ -525,7 +524,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--seed 42
|
--seed 42
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if is_cuda_and_apex_available():
|
if is_torch_fp16_available_on_device(torch_device):
|
||||||
testargs.append("--fp16")
|
testargs.append("--fp16")
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
@@ -551,7 +550,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--seed 42
|
--seed 42
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if is_cuda_and_apex_available():
|
if is_torch_fp16_available_on_device(torch_device):
|
||||||
testargs.append("--fp16")
|
testargs.append("--fp16")
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
@@ -579,7 +578,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--seed 42
|
--seed 42
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if is_cuda_and_apex_available():
|
if is_torch_fp16_available_on_device(torch_device):
|
||||||
testargs.append("--fp16")
|
testargs.append("--fp16")
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
@@ -604,7 +603,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--seed 32
|
--seed 32
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if is_cuda_and_apex_available():
|
if is_torch_fp16_available_on_device(torch_device):
|
||||||
testargs.append("--fp16")
|
testargs.append("--fp16")
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
|
|||||||
Reference in New Issue
Block a user