Fix for making unditioned generation work. Identical output as before.

This commit is contained in:
piero
2019-11-27 18:30:42 -08:00
committed by Julien Chaumond
parent 9f693a0c48
commit ffc2935405

View File

@@ -481,6 +481,7 @@ def generate_text_pplm(
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
grad_norms = None
last = None
unpert_discrim_loss = 0
loss_in_time = []
for i in trange(length, ascii=True):
@@ -491,7 +492,8 @@ def generate_text_pplm(
# run model forward to obtain unperturbed
if past is None and output_so_far is not None:
last = output_so_far[:, -1:]
_, past, _ = model(output_so_far[:, :-1])
if output_so_far.shape[1] > 1:
_, past, _ = model(output_so_far[:, :-1])
unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
unpert_last_hidden = unpert_all_hidden[-1]
@@ -510,27 +512,30 @@ def generate_text_pplm(
accumulated_hidden = unpert_last_hidden[:, :-1, :]
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
pert_past, _, grad_norms, loss_this_iter = perturb_past(
past,
model,
last,
unpert_past=unpert_past,
unpert_logits=unpert_logits,
accumulated_hidden=accumulated_hidden,
grad_norms=grad_norms,
stepsize=current_stepsize,
classifier=classifier,
label_class=label_class,
one_hot_bows_vectors=one_hot_bows_vectors,
loss_type=loss_type,
num_iterations=num_iterations,
kl_scale=kl_scale,
window_length=window_length,
horizon_length=horizon_length,
decay=decay,
gamma=gamma,
)
loss_in_time.append(loss_this_iter)
if past is not None:
pert_past, _, grad_norms, loss_this_iter = perturb_past(
past,
model,
last,
unpert_past=unpert_past,
unpert_logits=unpert_logits,
accumulated_hidden=accumulated_hidden,
grad_norms=grad_norms,
stepsize=current_stepsize,
classifier=classifier,
label_class=label_class,
one_hot_bows_vectors=one_hot_bows_vectors,
loss_type=loss_type,
num_iterations=num_iterations,
kl_scale=kl_scale,
window_length=window_length,
horizon_length=horizon_length,
decay=decay,
gamma=gamma,
)
loss_in_time.append(loss_this_iter)
else:
pert_past = past
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST