Fix PPLM (#8779)
* Fix pplm
* fix style
* make style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -154,7 +154,8 @@ def perturb_past(
|
||||
# Compute hidden using perturbed past
|
||||
perturbed_past = list(map(add, past, curr_perturbation))
|
||||
_, _, _, curr_length, _ = curr_perturbation[0].shape
|
||||
all_logits, _, all_hidden = model(last, past=perturbed_past)
|
||||
lm_output = model(last, past_key_values=perturbed_past)
|
||||
all_logits, all_hidden = lm_output["logits"], lm_output["hidden_states"]
|
||||
hidden = all_hidden[-1]
|
||||
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
|
||||
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
||||
@@ -179,7 +180,8 @@ def perturb_past(
|
||||
wte = model.resize_token_embeddings()
|
||||
for _ in range(horizon_length):
|
||||
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
|
||||
_, curr_unpert_past, curr_all_hidden = model(past=curr_unpert_past, inputs_embeds=inputs_embeds)
|
||||
lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=inputs_embeds)
|
||||
curr_unpert_past, curr_all_hidden = lm_output["past_key_values"], lm_output["hidden_states"]
|
||||
curr_hidden = curr_all_hidden[-1]
|
||||
new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)
|
||||
|
||||
@@ -462,9 +464,14 @@ def generate_text_pplm(
|
||||
if past is None and output_so_far is not None:
|
||||
last = output_so_far[:, -1:]
|
||||
if output_so_far.shape[1] > 1:
|
||||
_, past, _ = model(output_so_far[:, :-1])
|
||||
past = model(output_so_far[:, :-1])["past_key_values"]
|
||||
|
||||
unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
|
||||
lm_output = model(output_so_far)
|
||||
unpert_logits, unpert_past, unpert_all_hidden = (
|
||||
lm_output["logits"],
|
||||
lm_output["past_key_values"],
|
||||
lm_output["hidden_states"],
|
||||
)
|
||||
unpert_last_hidden = unpert_all_hidden[-1]
|
||||
|
||||
# check if we are above grad max length
|
||||
@@ -507,7 +514,11 @@ def generate_text_pplm(
|
||||
else:
|
||||
pert_past = past
|
||||
|
||||
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
|
||||
lm_output = model(last, past_key_values=pert_past)
|
||||
pert_logits, past = (
|
||||
lm_output["logits"],
|
||||
lm_output["past_key_values"],
|
||||
)
|
||||
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
|
||||
|
||||
for token_idx in set(output_so_far[0].tolist()):
|
||||
|
||||
@@ -64,7 +64,7 @@ class Discriminator(torch.nn.Module):
|
||||
|
||||
def avg_representation(self, x):
|
||||
mask = x.ne(0).unsqueeze(2).repeat(1, 1, self.embed_size).float().to(self.device).detach()
|
||||
hidden, _ = self.encoder.transformer(x)
|
||||
hidden = self.encoder.transformer(x)["last_hidden_state"]
|
||||
masked_hidden = hidden * mask
|
||||
avg_hidden = torch.sum(masked_hidden, dim=1) / (torch.sum(mask, dim=1).detach() + EPSILON)
|
||||
return avg_hidden
|
||||
|
||||
Reference in New Issue
Block a user