Pytorch 1.5.0 (#3973)

* Standard deviation can no longer be set to 0

* Remove torch pinned version

* 9th instead of 10th, silly me
This commit is contained in:
Lysandre Debut
2020-05-05 10:23:01 -04:00
committed by GitHub
parent 818463ee8e
commit 79b1c6966b
2 changed files with 3 additions and 3 deletions

View File

@@ -45,7 +45,7 @@ def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
for key in configs_no_init.__dict__.keys():
if "_range" in key or "_std" in key or "initializer_factor" in key:
setattr(configs_no_init, key, 0.0)
setattr(configs_no_init, key, 1e-10)
return configs_no_init
@@ -96,7 +96,7 @@ class ModelTesterMixin:
for name, param in model.named_parameters():
if param.requires_grad:
self.assertIn(
param.data.mean().item(),
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
)