From cba85a67b9c2897593321012f3c6a575545e49e2 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 15 Nov 2018 21:47:41 +0100 Subject: [PATCH] fix nan in optimizer_on_cpu --- examples/run_squad.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index 1011e836fd..6f06fb5c64 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -687,11 +687,12 @@ def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_n if name_opti != name_model: logger.error("name_opti != name_model: {} {}".format(name_opti, name_model)) raise ValueError - if test_nan and torch.isnan(param_model.grad).sum() > 0: - is_nan = True - if param_opti.grad is None: - param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size())) - param_opti.grad.data.copy_(param_model.grad.data) + if param_model.grad is not None: + if test_nan and torch.isnan(param_model.grad).sum() > 0: + is_nan = True + if param_opti.grad is None: + param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size())) + param_opti.grad.data.copy_(param_model.grad.data) return is_nan def main():