Add is_torch_mps_available function to utils (#24660)

* Add mps function utils

* black formating

* format fix

* Added MPS functionality to transformers

* format fix
This commit is contained in:
Nripesh Niketan
2023-07-05 19:32:20 +05:30
committed by GitHub
parent ee339bad01
commit bd9dfc23b9
5 changed files with 19 additions and 2 deletions

View File

@@ -774,12 +774,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
uniform_init_parms = ["conv"]
ignore_init = ["lstm"]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
if any(x in name for x in uniform_init_parms):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
elif not any([x in name for x in ignore_init]):
elif not any(x in name for x in ignore_init):
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],