Add is_torch_mps_available function to utils (#24660)
* Add mps function utils * black formating * format fix * Added MPS functionality to transformers * format fix
This commit is contained in:
@@ -116,6 +116,7 @@ from .utils import (
|
||||
is_torch_cuda_available,
|
||||
is_torch_fx_available,
|
||||
is_torch_fx_proxy,
|
||||
is_torch_mps_available,
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
is_torchaudio_available,
|
||||
|
||||
@@ -35,6 +35,7 @@ from .utils import (
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
is_torch_cuda_available,
|
||||
is_torch_mps_available,
|
||||
is_torch_tpu_available,
|
||||
requires_backends,
|
||||
)
|
||||
@@ -411,6 +412,11 @@ class TrainerMemoryTracker:
|
||||
if is_torch_cuda_available():
|
||||
import torch
|
||||
|
||||
self.torch = torch
|
||||
self.gpu = {}
|
||||
elif is_torch_mps_available():
|
||||
import torch
|
||||
|
||||
self.torch = torch
|
||||
self.gpu = {}
|
||||
else:
|
||||
|
||||
@@ -161,6 +161,7 @@ from .import_utils import (
|
||||
is_torch_cuda_available,
|
||||
is_torch_fx_available,
|
||||
is_torch_fx_proxy,
|
||||
is_torch_mps_available,
|
||||
is_torch_neuroncore_available,
|
||||
is_torch_tensorrt_fx_available,
|
||||
is_torch_tf32_available,
|
||||
|
||||
@@ -249,6 +249,15 @@ def is_torch_cuda_available():
|
||||
return False
|
||||
|
||||
|
||||
def is_torch_mps_available():
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if hasattr(torch.backends, "mps"):
|
||||
return torch.backends.mps.is_available()
|
||||
return False
|
||||
|
||||
|
||||
def is_torch_bf16_gpu_available():
|
||||
if not is_torch_available():
|
||||
return False
|
||||
|
||||
@@ -774,12 +774,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
uniform_init_parms = ["conv"]
|
||||
ignore_init = ["lstm"]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
if any(x in name for x in uniform_init_parms):
|
||||
self.assertTrue(
|
||||
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
elif not any([x in name for x in ignore_init]):
|
||||
elif not any(x in name for x in ignore_init):
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
|
||||
Reference in New Issue
Block a user