From ffc29354051ccd4fa3fd12010abacb0ff2e0733e Mon Sep 17 00:00:00 2001 From: piero Date: Wed, 27 Nov 2019 18:30:42 -0800 Subject: [PATCH] Fix for making unconditioned generation work. Identical output as before. --- examples/run_pplm.py | 49 ++++++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/examples/run_pplm.py b/examples/run_pplm.py index bd03bbe5e0..e337add46d 100644 --- a/examples/run_pplm.py +++ b/examples/run_pplm.py @@ -481,6 +481,7 @@ def generate_text_pplm( one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices) grad_norms = None + last = None unpert_discrim_loss = 0 loss_in_time = [] for i in trange(length, ascii=True): @@ -491,7 +492,8 @@ def generate_text_pplm( # run model forward to obtain unperturbed if past is None and output_so_far is not None: last = output_so_far[:, -1:] - _, past, _ = model(output_so_far[:, :-1]) + if output_so_far.shape[1] > 1: + _, past, _ = model(output_so_far[:, :-1]) unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far) unpert_last_hidden = unpert_all_hidden[-1] @@ -510,27 +512,30 @@ def generate_text_pplm( accumulated_hidden = unpert_last_hidden[:, :-1, :] accumulated_hidden = torch.sum(accumulated_hidden, dim=1) - pert_past, _, grad_norms, loss_this_iter = perturb_past( - past, - model, - last, - unpert_past=unpert_past, - unpert_logits=unpert_logits, - accumulated_hidden=accumulated_hidden, - grad_norms=grad_norms, - stepsize=current_stepsize, - classifier=classifier, - label_class=label_class, - one_hot_bows_vectors=one_hot_bows_vectors, - loss_type=loss_type, - num_iterations=num_iterations, - kl_scale=kl_scale, - window_length=window_length, - horizon_length=horizon_length, - decay=decay, - gamma=gamma, - ) - loss_in_time.append(loss_this_iter) + if past is not None: + pert_past, _, grad_norms, loss_this_iter = perturb_past( + past, + model, + last, + unpert_past=unpert_past, + unpert_logits=unpert_logits, 
accumulated_hidden=accumulated_hidden, + grad_norms=grad_norms, + stepsize=current_stepsize, + classifier=classifier, + label_class=label_class, + one_hot_bows_vectors=one_hot_bows_vectors, + loss_type=loss_type, + num_iterations=num_iterations, + kl_scale=kl_scale, + window_length=window_length, + horizon_length=horizon_length, + decay=decay, + gamma=gamma, + ) + loss_in_time.append(loss_this_iter) + else: + pert_past = past pert_logits, past, pert_all_hidden = model(last, past=pert_past) pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST