Fixed horizon_length for PPLM (#13886)
* fixed horizon_length * fixed horizon_length * fix style
This commit is contained in:
@@ -181,7 +181,14 @@ def perturb_past(
|
|||||||
for _ in range(horizon_length):
|
for _ in range(horizon_length):
|
||||||
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
|
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
|
||||||
lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=inputs_embeds)
|
lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=inputs_embeds)
|
||||||
curr_unpert_past, curr_all_hidden = lm_output["past_key_values"], lm_output["hidden_states"]
|
curr_all_logits, curr_unpert_past, curr_all_hidden = (
|
||||||
|
lm_output["logits"],
|
||||||
|
lm_output["past_key_values"],
|
||||||
|
lm_output["hidden_states"],
|
||||||
|
)
|
||||||
|
curr_logits = curr_all_logits[:, -1, :]
|
||||||
|
curr_probs = nn.functional.softmax(curr_logits, dim=-1)
|
||||||
|
curr_probs = torch.unsqueeze(curr_probs, dim=1)
|
||||||
curr_hidden = curr_all_hidden[-1]
|
curr_hidden = curr_all_hidden[-1]
|
||||||
new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)
|
new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user