Cleaned generate_text_pplm. Identical output as before.
This commit is contained in:
@@ -471,59 +471,49 @@ def generate_text_pplm(
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
):
|
||||
output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(
|
||||
0) if context else None
|
||||
output_so_far = (
|
||||
torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
|
||||
if context
|
||||
else None
|
||||
)
|
||||
|
||||
# collect one hot vectors for bags of words
|
||||
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
|
||||
|
||||
grad_norms = None
|
||||
unpert_discrim_loss = 0
|
||||
loss_in_time = []
|
||||
for i in trange(length, ascii=True):
|
||||
|
||||
# Get past/probs for current output, except for last word
|
||||
# Note that GPT takes 2 inputs: past + current-token
|
||||
# Therefore, use everything from before current i/p token to generate relevant past
|
||||
# Note that GPT takes 2 inputs: past + current_token
|
||||
|
||||
if past is None and output is not None:
|
||||
prev = output[:, -1:]
|
||||
# _, past = model(output[:, :-1])
|
||||
# original_probs, true_past = model(output)
|
||||
# true_hidden = model.hidden_states
|
||||
# run model forward to obtain unperturbed
|
||||
if past is None and output_so_far is not None:
|
||||
last = output_so_far[:, -1:]
|
||||
_, past, _ = model(output_so_far[:, :-1])
|
||||
|
||||
# Piero modified model call
|
||||
_, past, _ = model(output[:, :-1])
|
||||
unpert_logits, unpert_past, unpert_all_hidden = model(output)
|
||||
true_hidden = unpert_all_hidden[-1]
|
||||
|
||||
else:
|
||||
# original_probs, true_past = model(output)
|
||||
# true_hidden = model.hidden_states
|
||||
|
||||
# Piero modified model call
|
||||
unpert_logits, unpert_past, unpert_all_hidden = model(output)
|
||||
true_hidden = unpert_all_hidden[-1]
|
||||
|
||||
# Modify the past if necessary
|
||||
unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
|
||||
unpert_last_hidden = unpert_all_hidden[-1]
|
||||
|
||||
# check if we are abowe grad max length
|
||||
if i >= grad_length:
|
||||
current_stepsize = stepsize * 0
|
||||
else:
|
||||
current_stepsize = stepsize
|
||||
|
||||
# modify the past if necessary
|
||||
if not perturb or num_iterations == 0:
|
||||
perturbed_past = past
|
||||
pert_past = past
|
||||
|
||||
else:
|
||||
# Piero modified model call
|
||||
# accumulated_hidden = model.hidden_states[:, :-1, :]
|
||||
accumulated_hidden = true_hidden[:, :-1, :]
|
||||
accumulated_hidden = unpert_last_hidden[:, :-1, :]
|
||||
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
|
||||
|
||||
perturbed_past, _, grad_norms, loss_per_iter = perturb_past(
|
||||
pert_past, _, grad_norms, loss_this_iter = perturb_past(
|
||||
past,
|
||||
model,
|
||||
prev,
|
||||
last,
|
||||
unpert_past=unpert_past,
|
||||
unpert_logits=unpert_logits,
|
||||
accumulated_hidden=accumulated_hidden,
|
||||
@@ -540,68 +530,59 @@ def generate_text_pplm(
|
||||
decay=decay,
|
||||
gamma=gamma,
|
||||
)
|
||||
loss_in_time.append(loss_per_iter)
|
||||
loss_in_time.append(loss_this_iter)
|
||||
|
||||
# Piero modified model call
|
||||
logits, past, pert_all_hidden = model(prev, past=perturbed_past)
|
||||
# test_logits = F.softmax(test_logits[:, -1, :], dim=-1)
|
||||
# likelywords = torch.topk(test_logits, k=10, dim=-1)
|
||||
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
|
||||
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
|
||||
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
|
||||
pert_probs = F.softmax(pert_logits, dim=-1)
|
||||
|
||||
if classifier is not None:
|
||||
ce_loss = torch.nn.CrossEntropyLoss()
|
||||
predicted_sentiment = classifier(torch.mean(true_hidden, dim=1))
|
||||
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
||||
label = torch.tensor([label_class], device='cuda',
|
||||
dtype=torch.long)
|
||||
true_discrim_loss = ce_loss(predicted_sentiment, label)
|
||||
print("true discrim loss", true_discrim_loss.data.cpu().numpy())
|
||||
unpert_discrim_loss = ce_loss(prediction, label)
|
||||
print(
|
||||
"unperturbed discrim loss",
|
||||
unpert_discrim_loss.data.cpu().numpy()
|
||||
)
|
||||
else:
|
||||
true_discrim_loss = 0
|
||||
|
||||
# Piero modified model call
|
||||
# hidden = model.hidden_states # update hidden
|
||||
# logits = model.forward_hidden(hidden)
|
||||
logits = logits[:, -1, :] / temperature # + SMALL_CONST
|
||||
|
||||
# logits = top_k_filter(logits, k=args.top_k) # + SMALL_CONST
|
||||
|
||||
log_probs = F.softmax(logits, dim=-1)
|
||||
unpert_discrim_loss = 0
|
||||
|
||||
# Fuse the modified model and original model
|
||||
if perturb:
|
||||
|
||||
# original_probs = top_k_filter(original_probs[:, -1, :]) #+ SMALL_CONST
|
||||
unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
# likelywords = torch.topk(original_probs, k=10, dim=-1)
|
||||
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
|
||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
|
||||
log_probs = ((log_probs ** gm_scale) * (
|
||||
unpert_logits ** (1 - gm_scale))) # + SMALL_CONST
|
||||
|
||||
log_probs = top_k_filter(log_probs, k=top_k,
|
||||
pert_probs = ((pert_probs ** gm_scale) * (
|
||||
unpert_probs ** (1 - gm_scale))) # + SMALL_CONST
|
||||
pert_probs = top_k_filter(pert_probs, k=top_k,
|
||||
probs=True) # + SMALL_CONST
|
||||
|
||||
if torch.sum(log_probs) <= 1:
|
||||
log_probs = log_probs / torch.sum(log_probs)
|
||||
# rescale
|
||||
if torch.sum(pert_probs) <= 1:
|
||||
pert_probs = pert_probs / torch.sum(pert_probs)
|
||||
|
||||
else:
|
||||
logits = top_k_filter(logits, k=top_k) # + SMALL_CONST
|
||||
log_probs = F.softmax(logits, dim=-1)
|
||||
pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST
|
||||
pert_probs = F.softmax(pert_logits, dim=-1)
|
||||
|
||||
# sample or greedy
|
||||
if sample:
|
||||
# likelywords = torch.topk(log_probs, k=args.top_k, dim=-1)
|
||||
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
|
||||
# print(likelywords[0].tolist())
|
||||
prev = torch.multinomial(log_probs, num_samples=1)
|
||||
else:
|
||||
_, prev = torch.topk(log_probs, k=1, dim=-1)
|
||||
# if perturb:
|
||||
# prev = future
|
||||
output = prev if output is None else torch.cat((output, prev),
|
||||
dim=1) # update output
|
||||
print(TOKENIZER.decode(output.tolist()[0]))
|
||||
last = torch.multinomial(pert_probs, num_samples=1)
|
||||
|
||||
return output, true_discrim_loss, loss_in_time
|
||||
else:
|
||||
_, last = torch.topk(pert_probs, k=1, dim=-1)
|
||||
|
||||
# update context/output_so_far appending the new token
|
||||
output_so_far = (
|
||||
last if output_so_far is None
|
||||
else torch.cat((output_so_far, last), dim=1)
|
||||
)
|
||||
|
||||
print(TOKENIZER.decode(output_so_far.tolist()[0]))
|
||||
|
||||
return output_so_far, unpert_discrim_loss, loss_in_time
|
||||
|
||||
|
||||
def run_model():
|
||||
|
||||
Reference in New Issue
Block a user