Add is_torch_mps_available function to utils (#24660)
* Add mps function utils * black formating * format fix * Added MPS functionality to transformers * format fix
This commit is contained in:
@@ -116,6 +116,7 @@ from .utils import (
|
|||||||
is_torch_cuda_available,
|
is_torch_cuda_available,
|
||||||
is_torch_fx_available,
|
is_torch_fx_available,
|
||||||
is_torch_fx_proxy,
|
is_torch_fx_proxy,
|
||||||
|
is_torch_mps_available,
|
||||||
is_torch_tf32_available,
|
is_torch_tf32_available,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from .utils import (
|
|||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torch_cuda_available,
|
is_torch_cuda_available,
|
||||||
|
is_torch_mps_available,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
)
|
)
|
||||||
@@ -411,6 +412,11 @@ class TrainerMemoryTracker:
|
|||||||
if is_torch_cuda_available():
|
if is_torch_cuda_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
self.torch = torch
|
||||||
|
self.gpu = {}
|
||||||
|
elif is_torch_mps_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
self.torch = torch
|
self.torch = torch
|
||||||
self.gpu = {}
|
self.gpu = {}
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -161,6 +161,7 @@ from .import_utils import (
|
|||||||
is_torch_cuda_available,
|
is_torch_cuda_available,
|
||||||
is_torch_fx_available,
|
is_torch_fx_available,
|
||||||
is_torch_fx_proxy,
|
is_torch_fx_proxy,
|
||||||
|
is_torch_mps_available,
|
||||||
is_torch_neuroncore_available,
|
is_torch_neuroncore_available,
|
||||||
is_torch_tensorrt_fx_available,
|
is_torch_tensorrt_fx_available,
|
||||||
is_torch_tf32_available,
|
is_torch_tf32_available,
|
||||||
|
|||||||
@@ -249,6 +249,15 @@ def is_torch_cuda_available():
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_torch_mps_available():
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if hasattr(torch.backends, "mps"):
|
||||||
|
return torch.backends.mps.is_available()
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_torch_bf16_gpu_available():
|
def is_torch_bf16_gpu_available():
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -774,12 +774,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
uniform_init_parms = ["conv"]
|
uniform_init_parms = ["conv"]
|
||||||
ignore_init = ["lstm"]
|
ignore_init = ["lstm"]
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
if any([x in name for x in uniform_init_parms]):
|
if any(x in name for x in uniform_init_parms):
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
elif not any([x in name for x in ignore_init]):
|
elif not any(x in name for x in ignore_init):
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
[0.0, 1.0],
|
[0.0, 1.0],
|
||||||
|
|||||||
Reference in New Issue
Block a user