Fix for making unconditioned generation work. Output is identical to before.
This commit is contained in:
@@ -481,6 +481,7 @@ def generate_text_pplm(
|
|||||||
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
|
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
|
||||||
|
|
||||||
grad_norms = None
|
grad_norms = None
|
||||||
|
last = None
|
||||||
unpert_discrim_loss = 0
|
unpert_discrim_loss = 0
|
||||||
loss_in_time = []
|
loss_in_time = []
|
||||||
for i in trange(length, ascii=True):
|
for i in trange(length, ascii=True):
|
||||||
@@ -491,6 +492,7 @@ def generate_text_pplm(
|
|||||||
# run model forward to obtain unperturbed
|
# run model forward to obtain unperturbed
|
||||||
if past is None and output_so_far is not None:
|
if past is None and output_so_far is not None:
|
||||||
last = output_so_far[:, -1:]
|
last = output_so_far[:, -1:]
|
||||||
|
if output_so_far.shape[1] > 1:
|
||||||
_, past, _ = model(output_so_far[:, :-1])
|
_, past, _ = model(output_so_far[:, :-1])
|
||||||
|
|
||||||
unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
|
unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
|
||||||
@@ -510,6 +512,7 @@ def generate_text_pplm(
|
|||||||
accumulated_hidden = unpert_last_hidden[:, :-1, :]
|
accumulated_hidden = unpert_last_hidden[:, :-1, :]
|
||||||
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
|
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
|
||||||
|
|
||||||
|
if past is not None:
|
||||||
pert_past, _, grad_norms, loss_this_iter = perturb_past(
|
pert_past, _, grad_norms, loss_this_iter = perturb_past(
|
||||||
past,
|
past,
|
||||||
model,
|
model,
|
||||||
@@ -531,6 +534,8 @@ def generate_text_pplm(
|
|||||||
gamma=gamma,
|
gamma=gamma,
|
||||||
)
|
)
|
||||||
loss_in_time.append(loss_this_iter)
|
loss_in_time.append(loss_this_iter)
|
||||||
|
else:
|
||||||
|
pert_past = past
|
||||||
|
|
||||||
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
|
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
|
||||||
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
|
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
|
||||||
|
|||||||
Reference in New Issue
Block a user