Fix PPLM (#8779)
* Fix pplm
* fix style
* make style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -154,7 +154,8 @@ def perturb_past(
|
|||||||
# Compute hidden using perturbed past
|
# Compute hidden using perturbed past
|
||||||
perturbed_past = list(map(add, past, curr_perturbation))
|
perturbed_past = list(map(add, past, curr_perturbation))
|
||||||
_, _, _, curr_length, _ = curr_perturbation[0].shape
|
_, _, _, curr_length, _ = curr_perturbation[0].shape
|
||||||
all_logits, _, all_hidden = model(last, past=perturbed_past)
|
lm_output = model(last, past_key_values=perturbed_past)
|
||||||
|
all_logits, all_hidden = lm_output["logits"], lm_output["hidden_states"]
|
||||||
hidden = all_hidden[-1]
|
hidden = all_hidden[-1]
|
||||||
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
|
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
|
||||||
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
||||||
@@ -179,7 +180,8 @@ def perturb_past(
|
|||||||
wte = model.resize_token_embeddings()
|
wte = model.resize_token_embeddings()
|
||||||
for _ in range(horizon_length):
|
for _ in range(horizon_length):
|
||||||
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
|
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
|
||||||
_, curr_unpert_past, curr_all_hidden = model(past=curr_unpert_past, inputs_embeds=inputs_embeds)
|
lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=inputs_embeds)
|
||||||
|
curr_unpert_past, curr_all_hidden = lm_output["past_key_values"], lm_output["hidden_states"]
|
||||||
curr_hidden = curr_all_hidden[-1]
|
curr_hidden = curr_all_hidden[-1]
|
||||||
new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)
|
new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)
|
||||||
|
|
||||||
@@ -462,9 +464,14 @@ def generate_text_pplm(
|
|||||||
if past is None and output_so_far is not None:
|
if past is None and output_so_far is not None:
|
||||||
last = output_so_far[:, -1:]
|
last = output_so_far[:, -1:]
|
||||||
if output_so_far.shape[1] > 1:
|
if output_so_far.shape[1] > 1:
|
||||||
_, past, _ = model(output_so_far[:, :-1])
|
past = model(output_so_far[:, :-1])["past_key_values"]
|
||||||
|
|
||||||
unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
|
lm_output = model(output_so_far)
|
||||||
|
unpert_logits, unpert_past, unpert_all_hidden = (
|
||||||
|
lm_output["logits"],
|
||||||
|
lm_output["past_key_values"],
|
||||||
|
lm_output["hidden_states"],
|
||||||
|
)
|
||||||
unpert_last_hidden = unpert_all_hidden[-1]
|
unpert_last_hidden = unpert_all_hidden[-1]
|
||||||
|
|
||||||
# check if we are above grad max length
|
# check if we are above grad max length
|
||||||
@@ -507,7 +514,11 @@ def generate_text_pplm(
|
|||||||
else:
|
else:
|
||||||
pert_past = past
|
pert_past = past
|
||||||
|
|
||||||
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
|
lm_output = model(last, past_key_values=pert_past)
|
||||||
|
pert_logits, past = (
|
||||||
|
lm_output["logits"],
|
||||||
|
lm_output["past_key_values"],
|
||||||
|
)
|
||||||
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
|
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
|
||||||
|
|
||||||
for token_idx in set(output_so_far[0].tolist()):
|
for token_idx in set(output_so_far[0].tolist()):
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ class Discriminator(torch.nn.Module):
|
|||||||
|
|
||||||
def avg_representation(self, x):
|
def avg_representation(self, x):
|
||||||
mask = x.ne(0).unsqueeze(2).repeat(1, 1, self.embed_size).float().to(self.device).detach()
|
mask = x.ne(0).unsqueeze(2).repeat(1, 1, self.embed_size).float().to(self.device).detach()
|
||||||
hidden, _ = self.encoder.transformer(x)
|
hidden = self.encoder.transformer(x)["last_hidden_state"]
|
||||||
masked_hidden = hidden * mask
|
masked_hidden = hidden * mask
|
||||||
avg_hidden = torch.sum(masked_hidden, dim=1) / (torch.sum(mask, dim=1).detach() + EPSILON)
|
avg_hidden = torch.sum(masked_hidden, dim=1) / (torch.sum(mask, dim=1).detach() + EPSILON)
|
||||||
return avg_hidden
|
return avg_hidden
|
||||||
|
|||||||
Reference in New Issue
Block a user