run_pplm.py bug fix (#4867)
`is_leaf` may become `False` after `.to(device=device)` function call.
This commit is contained in:
@@ -148,6 +148,9 @@ def perturb_past(
|
|||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
print("Iteration ", i + 1)
|
print("Iteration ", i + 1)
|
||||||
curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
|
curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
|
||||||
|
# make sure p_.grad is not None
|
||||||
|
for p_ in curr_perturbation:
|
||||||
|
p_.retain_grad()
|
||||||
|
|
||||||
# Compute hidden using perturbed past
|
# Compute hidden using perturbed past
|
||||||
perturbed_past = list(map(add, past, curr_perturbation))
|
perturbed_past = list(map(add, past, curr_perturbation))
|
||||||
|
|||||||
Reference in New Issue
Block a user