From 79b1c6966b2f0d63269eacbe87fade530ee4f05c Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Tue, 5 May 2020 10:23:01 -0400 Subject: [PATCH] Pytorch 1.5.0 (#3973) * Standard deviation can no longer be set to 0 * Remove torch pinned version * 9th instead of 10th, silly me --- setup.py | 2 +- tests/test_modeling_common.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 17f231f851..a3d36ca124 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ extras["mecab"] = ["mecab-python3"] extras["sklearn"] = ["scikit-learn"] extras["tf"] = ["tensorflow"] extras["tf-cpu"] = ["tensorflow-cpu"] -extras["torch"] = ["torch==1.4.0"] +extras["torch"] = ["torch"] extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"] extras["all"] = extras["serving"] + ["tensorflow", "torch"] diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 9fd1e0a23b..2909e17e9e 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -45,7 +45,7 @@ def _config_zero_init(config): configs_no_init = copy.deepcopy(config) for key in configs_no_init.__dict__.keys(): if "_range" in key or "_std" in key or "initializer_factor" in key: - setattr(configs_no_init, key, 0.0) + setattr(configs_no_init, key, 1e-10) return configs_no_init @@ -96,7 +96,7 @@ class ModelTesterMixin: for name, param in model.named_parameters(): if param.requires_grad: self.assertIn( - param.data.mean().item(), + ((param.data.mean() * 1e9).round() / 1e9).item(), [0.0, 1.0], msg="Parameter {} of model {} seems not properly initialized".format(name, model_class), )