From 02173a1a0ae8b39990fea14edd815468fba7b8c8 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 15 Nov 2018 21:49:12 +0100 Subject: [PATCH] fixing error in isnan test for optimizer_on_cpu & fp16 --- run_squad.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/run_squad.py b/run_squad.py index 248b92c504..aa3eb48ccd 100644 --- a/run_squad.py +++ b/run_squad.py @@ -688,11 +688,14 @@ def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_n if name_opti != name_model: logger.error("name_opti != name_model: {} {}".format(name_opti, name_model)) raise ValueError - if test_nan and torch.isnan(param_model.grad).sum() > 0: - is_nan = True - if param_opti.grad is None: - param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size())) - param_opti.grad.data.copy_(param_model.grad.data) + if param_model.grad is not None: + if test_nan and torch.isnan(param_model.grad).sum() > 0: + is_nan = True + if param_opti.grad is None: + param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size())) + param_opti.grad.data.copy_(param_model.grad.data) + else: + param_opti.grad = None return is_nan def main():