diff --git a/run_squad.py b/run_squad.py index 248b92c504..aa3eb48ccd 100644 --- a/run_squad.py +++ b/run_squad.py @@ -688,11 +688,14 @@ def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_n if name_opti != name_model: logger.error("name_opti != name_model: {} {}".format(name_opti, name_model)) raise ValueError - if test_nan and torch.isnan(param_model.grad).sum() > 0: - is_nan = True - if param_opti.grad is None: - param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size())) - param_opti.grad.data.copy_(param_model.grad.data) + if param_model.grad is not None: + if test_nan and torch.isnan(param_model.grad).sum() > 0: + is_nan = True + if param_opti.grad is None: + param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size())) + param_opti.grad.data.copy_(param_model.grad.data) + else: + param_opti.grad = None return is_nan def main():