Add is_torch_mps_available function to utils (#24660)
* Add mps function utils * black formating * format fix * Added MPS functionality to transformers * format fix
This commit is contained in:
@@ -774,12 +774,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
uniform_init_parms = ["conv"]
|
||||
ignore_init = ["lstm"]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
if any(x in name for x in uniform_init_parms):
|
||||
self.assertTrue(
|
||||
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
elif not any([x in name for x in ignore_init]):
|
||||
elif not any(x in name for x in ignore_init):
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
|
||||
Reference in New Issue
Block a user