From cd19b193785fd9d224b8d2b509c63387bb48bc14 Mon Sep 17 00:00:00 2001 From: "Hz, Ji" Date: Mon, 30 Oct 2023 22:56:41 +0800 Subject: [PATCH] make tests of pytorch_example device agnostic (#27081) --- examples/pytorch/test_accelerate_examples.py | 24 ++++++------ examples/pytorch/test_pytorch_examples.py | 41 ++++++++++---------- 2 files changed, 32 insertions(+), 33 deletions(-) diff --git a/examples/pytorch/test_accelerate_examples.py b/examples/pytorch/test_accelerate_examples.py index 4cfe45b022..d5e20d820e 100644 --- a/examples/pytorch/test_accelerate_examples.py +++ b/examples/pytorch/test_accelerate_examples.py @@ -24,11 +24,16 @@ import tempfile import unittest from unittest import mock -import torch from accelerate.utils import write_basic_config -from transformers.testing_utils import TestCasePlus, get_gpu_count, run_command, slow, torch_device -from transformers.utils import is_apex_available +from transformers.testing_utils import ( + TestCasePlus, + backend_device_count, + is_torch_fp16_available_on_device, + run_command, + slow, + torch_device, +) logging.basicConfig(level=logging.DEBUG) @@ -54,11 +59,6 @@ def get_results(output_dir): return results -def is_cuda_and_apex_available(): - is_using_cuda = torch.cuda.is_available() and torch_device == "cuda" - return is_using_cuda and is_apex_available() - - stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) @@ -93,7 +93,7 @@ class ExamplesTestsNoTrainer(TestCasePlus): --with_tracking """.split() - if is_cuda_and_apex_available(): + if is_torch_fp16_available_on_device(torch_device): testargs.append("--fp16") run_command(self._launch_args + testargs) @@ -119,7 +119,7 @@ class ExamplesTestsNoTrainer(TestCasePlus): --with_tracking """.split() - if torch.cuda.device_count() > 1: + if backend_device_count(torch_device) > 1: # Skipping because there are not enough batches to train the model + would need a drop_last to work. return @@ -152,7 +152,7 @@ class ExamplesTestsNoTrainer(TestCasePlus): @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"}) def test_run_ner_no_trainer(self): # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu - epochs = 7 if get_gpu_count() > 1 else 2 + epochs = 7 if backend_device_count(torch_device) > 1 else 2 tmp_dir = self.get_auto_remove_tmp_dir() testargs = f""" @@ -326,7 +326,7 @@ class ExamplesTestsNoTrainer(TestCasePlus): --checkpointing_steps 1 """.split() - if is_cuda_and_apex_available(): + if is_torch_fp16_available_on_device(torch_device): testargs.append("--fp16") run_command(self._launch_args + testargs) diff --git a/examples/pytorch/test_pytorch_examples.py b/examples/pytorch/test_pytorch_examples.py index 269d7844f7..7d27804a73 100644 --- a/examples/pytorch/test_pytorch_examples.py +++ b/examples/pytorch/test_pytorch_examples.py @@ -20,11 +20,15 @@ import os import sys from unittest.mock import patch -import torch - from transformers import ViTMAEForPreTraining, Wav2Vec2ForPreTraining -from transformers.testing_utils import CaptureLogger, TestCasePlus, get_gpu_count, slow, torch_device -from transformers.utils import is_apex_available +from transformers.testing_utils import ( + CaptureLogger, + TestCasePlus, + backend_device_count, + is_torch_fp16_available_on_device, + slow, + torch_device, +) SRC_DIRS = [ @@ -86,11 +90,6 @@ def get_results(output_dir): return results -def is_cuda_and_apex_available(): - is_using_cuda = torch.cuda.is_available() and torch_device == "cuda" - return is_using_cuda and is_apex_available() - - stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) @@ -116,7 +115,7 @@ class ExamplesTests(TestCasePlus): --max_seq_length=128 """.split() - if is_cuda_and_apex_available(): + if is_torch_fp16_available_on_device(torch_device): testargs.append("--fp16") with patch.object(sys, "argv", testargs): @@ -141,7 +140,7 @@ class ExamplesTests(TestCasePlus): --overwrite_output_dir """.split() - if torch.cuda.device_count() > 1: + if backend_device_count(torch_device) > 1: # Skipping because there are not enough batches to train the model + would need a drop_last to work. return @@ -203,7 +202,7 @@ class ExamplesTests(TestCasePlus): def test_run_ner(self): # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu - epochs = 7 if get_gpu_count() > 1 else 2 + epochs = 7 if backend_device_count(torch_device) > 1 else 2 tmp_dir = self.get_auto_remove_tmp_dir() testargs = f""" @@ -312,7 +311,7 @@ class ExamplesTests(TestCasePlus): def test_generation(self): testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"] - if is_cuda_and_apex_available(): + if is_torch_fp16_available_on_device(torch_device): testargs.append("--fp16") model_type, model_name = ( @@ -401,7 +400,7 @@ class ExamplesTests(TestCasePlus): --seed 42 """.split() - if is_cuda_and_apex_available(): + if is_torch_fp16_available_on_device(torch_device): testargs.append("--fp16") with patch.object(sys, "argv", testargs): @@ -431,7 +430,7 @@ class ExamplesTests(TestCasePlus): --seed 42 """.split() - if is_cuda_and_apex_available(): + if is_torch_fp16_available_on_device(torch_device): testargs.append("--fp16") with patch.object(sys, "argv", testargs): @@ -462,7 +461,7 @@ class ExamplesTests(TestCasePlus): --seed 42 """.split() - if is_cuda_and_apex_available(): + if is_torch_fp16_available_on_device(torch_device): testargs.append("--fp16") with patch.object(sys, "argv", testargs): @@ -493,7 +492,7 @@ class ExamplesTests(TestCasePlus): --seed 42 """.split() - if is_cuda_and_apex_available(): + if is_torch_fp16_available_on_device(torch_device): testargs.append("--fp16") with patch.object(sys, "argv", testargs): @@ -525,7 +524,7 @@ class ExamplesTests(TestCasePlus): --seed 42 """.split() - if is_cuda_and_apex_available(): + if is_torch_fp16_available_on_device(torch_device): testargs.append("--fp16") with patch.object(sys, "argv", testargs): @@ -551,7 +550,7 @@ class ExamplesTests(TestCasePlus): --seed 42 """.split() - if is_cuda_and_apex_available(): + if is_torch_fp16_available_on_device(torch_device): testargs.append("--fp16") with patch.object(sys, "argv", testargs): @@ -579,7 +578,7 @@ class ExamplesTests(TestCasePlus): --seed 42 """.split() - if is_cuda_and_apex_available(): + if is_torch_fp16_available_on_device(torch_device): testargs.append("--fp16") with patch.object(sys, "argv", testargs): @@ -604,7 +603,7 @@ class ExamplesTests(TestCasePlus): --seed 32 """.split() - if is_cuda_and_apex_available(): + if is_torch_fp16_available_on_device(torch_device): testargs.append("--fp16") with patch.object(sys, "argv", testargs):