From 52708d263737e91e539d074a8ca2acd705210dff Mon Sep 17 00:00:00 2001 From: chutaklee Date: Fri, 27 Nov 2020 05:23:36 +0800 Subject: [PATCH] Fix PPLM (#8779) * Fix pplm * fix style * make style Co-authored-by: Patrick von Platen --- examples/text-generation/pplm/run_pplm.py | 21 ++++++++++++++----- .../pplm/run_pplm_discrim_train.py | 2 +- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/examples/text-generation/pplm/run_pplm.py b/examples/text-generation/pplm/run_pplm.py index 96aee8be06..8d605fac49 100644 --- a/examples/text-generation/pplm/run_pplm.py +++ b/examples/text-generation/pplm/run_pplm.py @@ -154,7 +154,8 @@ def perturb_past( # Compute hidden using perturbed past perturbed_past = list(map(add, past, curr_perturbation)) _, _, _, curr_length, _ = curr_perturbation[0].shape - all_logits, _, all_hidden = model(last, past=perturbed_past) + lm_output = model(last, past_key_values=perturbed_past) + all_logits, all_hidden = lm_output["logits"], lm_output["hidden_states"] hidden = all_hidden[-1] new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach() # TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth) @@ -179,7 +180,8 @@ def perturb_past( wte = model.resize_token_embeddings() for _ in range(horizon_length): inputs_embeds = torch.matmul(curr_probs, wte.weight.data) - _, curr_unpert_past, curr_all_hidden = model(past=curr_unpert_past, inputs_embeds=inputs_embeds) + lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=inputs_embeds) + curr_unpert_past, curr_all_hidden = lm_output["past_key_values"], lm_output["hidden_states"] curr_hidden = curr_all_hidden[-1] new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1) @@ -462,9 +464,14 @@ def generate_text_pplm( if past is None and output_so_far is not None: last = output_so_far[:, -1:] if output_so_far.shape[1] > 1: - _, past, _ = model(output_so_far[:, :-1]) + past = model(output_so_far[:, :-1])["past_key_values"] - unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far) + lm_output = model(output_so_far) + unpert_logits, unpert_past, unpert_all_hidden = ( + lm_output["logits"], + lm_output["past_key_values"], + lm_output["hidden_states"], + ) unpert_last_hidden = unpert_all_hidden[-1] # check if we are abowe grad max length @@ -507,7 +514,11 @@ def generate_text_pplm( else: pert_past = past - pert_logits, past, pert_all_hidden = model(last, past=pert_past) + lm_output = model(last, past_key_values=pert_past) + pert_logits, past = ( + lm_output["logits"], + lm_output["past_key_values"], + ) pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST for token_idx in set(output_so_far[0].tolist()): diff --git a/examples/text-generation/pplm/run_pplm_discrim_train.py b/examples/text-generation/pplm/run_pplm_discrim_train.py index 306e519b52..51cdb56773 100644 --- a/examples/text-generation/pplm/run_pplm_discrim_train.py +++ b/examples/text-generation/pplm/run_pplm_discrim_train.py @@ -64,7 +64,7 @@ class Discriminator(torch.nn.Module): def avg_representation(self, x): mask = x.ne(0).unsqueeze(2).repeat(1, 1, self.embed_size).float().to(self.device).detach() - hidden, _ = self.encoder.transformer(x) + hidden = self.encoder.transformer(x)["last_hidden_state"] masked_hidden = hidden * mask avg_hidden = torch.sum(masked_hidden, dim=1) / (torch.sum(mask, dim=1).detach() + EPSILON) return avg_hidden