From bd9dfc23b9954b3b623425b31be181698d916a09 Mon Sep 17 00:00:00 2001 From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> Date: Wed, 5 Jul 2023 19:32:20 +0530 Subject: [PATCH] Add `is_torch_mps_available` function to utils (#24660) * Add mps function utils * black formating * format fix * Added MPS functionality to transformers * format fix --- src/transformers/file_utils.py | 1 + src/transformers/trainer_utils.py | 6 ++++++ src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 9 +++++++++ tests/models/musicgen/test_modeling_musicgen.py | 4 ++-- 5 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 38f4db0581..63230b5b84 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -116,6 +116,7 @@ from .utils import ( is_torch_cuda_available, is_torch_fx_available, is_torch_fx_proxy, + is_torch_mps_available, is_torch_tf32_available, is_torch_tpu_available, is_torchaudio_available, diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index a497ef1ee9..74f01ad927 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -35,6 +35,7 @@ from .utils import ( is_tf_available, is_torch_available, is_torch_cuda_available, + is_torch_mps_available, is_torch_tpu_available, requires_backends, ) @@ -411,6 +412,11 @@ class TrainerMemoryTracker: if is_torch_cuda_available(): import torch + self.torch = torch + self.gpu = {} + elif is_torch_mps_available(): + import torch + self.torch = torch self.gpu = {} else: diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 764ca38b72..4d27e3084b 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -161,6 +161,7 @@ from .import_utils import ( is_torch_cuda_available, is_torch_fx_available, is_torch_fx_proxy, + is_torch_mps_available, is_torch_neuroncore_available, is_torch_tensorrt_fx_available, is_torch_tf32_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index ed356a05b7..27700c6598 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -249,6 +249,15 @@ def is_torch_cuda_available(): return False +def is_torch_mps_available(): + if is_torch_available(): + import torch + + if hasattr(torch.backends, "mps"): + return torch.backends.mps.is_available() + return False + + def is_torch_bf16_gpu_available(): if not is_torch_available(): return False diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index 00f249b09d..3d59becc75 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -774,12 +774,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, uniform_init_parms = ["conv"] ignore_init = ["lstm"] if param.requires_grad: - if any([x in name for x in uniform_init_parms]): + if any(x in name for x in uniform_init_parms): self.assertTrue( -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - elif not any([x in name for x in ignore_init]): + elif not any(x in name for x in ignore_init): self.assertIn( ((param.data.mean() * 1e9).round() / 1e9).item(), [0.0, 1.0],