Fix for making unconditioned generation work. Identical output as before.
This commit is contained in:
@@ -481,6 +481,7 @@ def generate_text_pplm(
|
||||
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
|
||||
|
||||
grad_norms = None
|
||||
last = None
|
||||
unpert_discrim_loss = 0
|
||||
loss_in_time = []
|
||||
for i in trange(length, ascii=True):
|
||||
@@ -491,7 +492,8 @@ def generate_text_pplm(
|
||||
# run model forward to obtain unperturbed
|
||||
if past is None and output_so_far is not None:
|
||||
last = output_so_far[:, -1:]
|
||||
_, past, _ = model(output_so_far[:, :-1])
|
||||
if output_so_far.shape[1] > 1:
|
||||
_, past, _ = model(output_so_far[:, :-1])
|
||||
|
||||
unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
|
||||
unpert_last_hidden = unpert_all_hidden[-1]
|
||||
@@ -510,27 +512,30 @@ def generate_text_pplm(
|
||||
accumulated_hidden = unpert_last_hidden[:, :-1, :]
|
||||
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
|
||||
|
||||
pert_past, _, grad_norms, loss_this_iter = perturb_past(
|
||||
past,
|
||||
model,
|
||||
last,
|
||||
unpert_past=unpert_past,
|
||||
unpert_logits=unpert_logits,
|
||||
accumulated_hidden=accumulated_hidden,
|
||||
grad_norms=grad_norms,
|
||||
stepsize=current_stepsize,
|
||||
classifier=classifier,
|
||||
label_class=label_class,
|
||||
one_hot_bows_vectors=one_hot_bows_vectors,
|
||||
loss_type=loss_type,
|
||||
num_iterations=num_iterations,
|
||||
kl_scale=kl_scale,
|
||||
window_length=window_length,
|
||||
horizon_length=horizon_length,
|
||||
decay=decay,
|
||||
gamma=gamma,
|
||||
)
|
||||
loss_in_time.append(loss_this_iter)
|
||||
if past is not None:
|
||||
pert_past, _, grad_norms, loss_this_iter = perturb_past(
|
||||
past,
|
||||
model,
|
||||
last,
|
||||
unpert_past=unpert_past,
|
||||
unpert_logits=unpert_logits,
|
||||
accumulated_hidden=accumulated_hidden,
|
||||
grad_norms=grad_norms,
|
||||
stepsize=current_stepsize,
|
||||
classifier=classifier,
|
||||
label_class=label_class,
|
||||
one_hot_bows_vectors=one_hot_bows_vectors,
|
||||
loss_type=loss_type,
|
||||
num_iterations=num_iterations,
|
||||
kl_scale=kl_scale,
|
||||
window_length=window_length,
|
||||
horizon_length=horizon_length,
|
||||
decay=decay,
|
||||
gamma=gamma,
|
||||
)
|
||||
loss_in_time.append(loss_this_iter)
|
||||
else:
|
||||
pert_past = past
|
||||
|
||||
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
|
||||
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
|
||||
|
||||
Reference in New Issue
Block a user