From 29c36e9f3678702e5ffd3fe2f1c9f6c1d6672578 Mon Sep 17 00:00:00 2001 From: songyouwei Date: Wed, 10 Jun 2020 07:14:27 +0800 Subject: [PATCH] run_pplm.py bug fix (#4867) `is_leaf` may become `False` after `.to(device=device)` function call. --- examples/text-generation/pplm/run_pplm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/text-generation/pplm/run_pplm.py b/examples/text-generation/pplm/run_pplm.py index abdd28c9af..bfbb79e438 100644 --- a/examples/text-generation/pplm/run_pplm.py +++ b/examples/text-generation/pplm/run_pplm.py @@ -148,6 +148,9 @@ def perturb_past( for i in range(num_iterations): print("Iteration ", i + 1) curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator] + # make sure p_.grad is not None + for p_ in curr_perturbation: + p_.retain_grad() # Compute hidden using perturbed past perturbed_past = list(map(add, past, curr_perturbation))