diff --git a/examples/research_projects/pplm/run_pplm.py b/examples/research_projects/pplm/run_pplm.py index 4be4f01fd4..4872118433 100644 --- a/examples/research_projects/pplm/run_pplm.py +++ b/examples/research_projects/pplm/run_pplm.py @@ -181,7 +181,14 @@ def perturb_past( for _ in range(horizon_length): inputs_embeds = torch.matmul(curr_probs, wte.weight.data) lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=inputs_embeds) - curr_unpert_past, curr_all_hidden = lm_output["past_key_values"], lm_output["hidden_states"] + curr_all_logits, curr_unpert_past, curr_all_hidden = ( + lm_output["logits"], + lm_output["past_key_values"], + lm_output["hidden_states"], + ) + curr_logits = curr_all_logits[:, -1, :] + curr_probs = nn.functional.softmax(curr_logits, dim=-1) + curr_probs = torch.unsqueeze(curr_probs, dim=1) curr_hidden = curr_all_hidden[-1] new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)