Cleaned up generate_text_pplm. Output is identical to before.
This commit is contained in:
@@ -471,59 +471,49 @@ def generate_text_pplm(
|
|||||||
decay=False,
|
decay=False,
|
||||||
gamma=1.5,
|
gamma=1.5,
|
||||||
):
|
):
|
||||||
output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(
|
output_so_far = (
|
||||||
0) if context else None
|
torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
|
||||||
|
if context
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
# collect one hot vectors for bags of words
|
# collect one hot vectors for bags of words
|
||||||
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
|
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
|
||||||
|
|
||||||
grad_norms = None
|
grad_norms = None
|
||||||
|
unpert_discrim_loss = 0
|
||||||
loss_in_time = []
|
loss_in_time = []
|
||||||
for i in trange(length, ascii=True):
|
for i in trange(length, ascii=True):
|
||||||
|
|
||||||
# Get past/probs for current output, except for last word
|
# Get past/probs for current output, except for last word
|
||||||
# Note that GPT takes 2 inputs: past + current-token
|
# Note that GPT takes 2 inputs: past + current_token
|
||||||
# Therefore, use everything from before current i/p token to generate relevant past
|
|
||||||
|
|
||||||
if past is None and output is not None:
|
# run model forward to obtain unperturbed
|
||||||
prev = output[:, -1:]
|
if past is None and output_so_far is not None:
|
||||||
# _, past = model(output[:, :-1])
|
last = output_so_far[:, -1:]
|
||||||
# original_probs, true_past = model(output)
|
_, past, _ = model(output_so_far[:, :-1])
|
||||||
# true_hidden = model.hidden_states
|
|
||||||
|
|
||||||
# Piero modified model call
|
unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
|
||||||
_, past, _ = model(output[:, :-1])
|
unpert_last_hidden = unpert_all_hidden[-1]
|
||||||
unpert_logits, unpert_past, unpert_all_hidden = model(output)
|
|
||||||
true_hidden = unpert_all_hidden[-1]
|
|
||||||
|
|
||||||
else:
|
|
||||||
# original_probs, true_past = model(output)
|
|
||||||
# true_hidden = model.hidden_states
|
|
||||||
|
|
||||||
# Piero modified model call
|
|
||||||
unpert_logits, unpert_past, unpert_all_hidden = model(output)
|
|
||||||
true_hidden = unpert_all_hidden[-1]
|
|
||||||
|
|
||||||
# Modify the past if necessary
|
|
||||||
|
|
||||||
|
# check if we are above grad max length
|
||||||
if i >= grad_length:
|
if i >= grad_length:
|
||||||
current_stepsize = stepsize * 0
|
current_stepsize = stepsize * 0
|
||||||
else:
|
else:
|
||||||
current_stepsize = stepsize
|
current_stepsize = stepsize
|
||||||
|
|
||||||
|
# modify the past if necessary
|
||||||
if not perturb or num_iterations == 0:
|
if not perturb or num_iterations == 0:
|
||||||
perturbed_past = past
|
pert_past = past
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Piero modified model call
|
accumulated_hidden = unpert_last_hidden[:, :-1, :]
|
||||||
# accumulated_hidden = model.hidden_states[:, :-1, :]
|
|
||||||
accumulated_hidden = true_hidden[:, :-1, :]
|
|
||||||
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
|
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
|
||||||
|
|
||||||
perturbed_past, _, grad_norms, loss_per_iter = perturb_past(
|
pert_past, _, grad_norms, loss_this_iter = perturb_past(
|
||||||
past,
|
past,
|
||||||
model,
|
model,
|
||||||
prev,
|
last,
|
||||||
unpert_past=unpert_past,
|
unpert_past=unpert_past,
|
||||||
unpert_logits=unpert_logits,
|
unpert_logits=unpert_logits,
|
||||||
accumulated_hidden=accumulated_hidden,
|
accumulated_hidden=accumulated_hidden,
|
||||||
@@ -540,68 +530,59 @@ def generate_text_pplm(
|
|||||||
decay=decay,
|
decay=decay,
|
||||||
gamma=gamma,
|
gamma=gamma,
|
||||||
)
|
)
|
||||||
loss_in_time.append(loss_per_iter)
|
loss_in_time.append(loss_this_iter)
|
||||||
|
|
||||||
# Piero modified model call
|
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
|
||||||
logits, past, pert_all_hidden = model(prev, past=perturbed_past)
|
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
|
||||||
# test_logits = F.softmax(test_logits[:, -1, :], dim=-1)
|
pert_probs = F.softmax(pert_logits, dim=-1)
|
||||||
# likelywords = torch.topk(test_logits, k=10, dim=-1)
|
|
||||||
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
|
|
||||||
|
|
||||||
if classifier is not None:
|
if classifier is not None:
|
||||||
ce_loss = torch.nn.CrossEntropyLoss()
|
ce_loss = torch.nn.CrossEntropyLoss()
|
||||||
predicted_sentiment = classifier(torch.mean(true_hidden, dim=1))
|
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
||||||
label = torch.tensor([label_class], device='cuda',
|
label = torch.tensor([label_class], device='cuda',
|
||||||
dtype=torch.long)
|
dtype=torch.long)
|
||||||
true_discrim_loss = ce_loss(predicted_sentiment, label)
|
unpert_discrim_loss = ce_loss(prediction, label)
|
||||||
print("true discrim loss", true_discrim_loss.data.cpu().numpy())
|
print(
|
||||||
|
"unperturbed discrim loss",
|
||||||
|
unpert_discrim_loss.data.cpu().numpy()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
true_discrim_loss = 0
|
unpert_discrim_loss = 0
|
||||||
|
|
||||||
# Piero modified model call
|
|
||||||
# hidden = model.hidden_states # update hidden
|
|
||||||
# logits = model.forward_hidden(hidden)
|
|
||||||
logits = logits[:, -1, :] / temperature # + SMALL_CONST
|
|
||||||
|
|
||||||
# logits = top_k_filter(logits, k=args.top_k) # + SMALL_CONST
|
|
||||||
|
|
||||||
log_probs = F.softmax(logits, dim=-1)
|
|
||||||
|
|
||||||
# Fuse the modified model and original model
|
# Fuse the modified model and original model
|
||||||
if perturb:
|
if perturb:
|
||||||
|
|
||||||
# original_probs = top_k_filter(original_probs[:, -1, :]) #+ SMALL_CONST
|
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||||
unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
|
||||||
# likelywords = torch.topk(original_probs, k=10, dim=-1)
|
|
||||||
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
|
|
||||||
|
|
||||||
log_probs = ((log_probs ** gm_scale) * (
|
pert_probs = ((pert_probs ** gm_scale) * (
|
||||||
unpert_logits ** (1 - gm_scale))) # + SMALL_CONST
|
unpert_probs ** (1 - gm_scale))) # + SMALL_CONST
|
||||||
|
pert_probs = top_k_filter(pert_probs, k=top_k,
|
||||||
log_probs = top_k_filter(log_probs, k=top_k,
|
|
||||||
probs=True) # + SMALL_CONST
|
probs=True) # + SMALL_CONST
|
||||||
|
|
||||||
if torch.sum(log_probs) <= 1:
|
# rescale
|
||||||
log_probs = log_probs / torch.sum(log_probs)
|
if torch.sum(pert_probs) <= 1:
|
||||||
|
pert_probs = pert_probs / torch.sum(pert_probs)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logits = top_k_filter(logits, k=top_k) # + SMALL_CONST
|
pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST
|
||||||
log_probs = F.softmax(logits, dim=-1)
|
pert_probs = F.softmax(pert_logits, dim=-1)
|
||||||
|
|
||||||
|
# sample or greedy
|
||||||
if sample:
|
if sample:
|
||||||
# likelywords = torch.topk(log_probs, k=args.top_k, dim=-1)
|
last = torch.multinomial(pert_probs, num_samples=1)
|
||||||
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
|
|
||||||
# print(likelywords[0].tolist())
|
|
||||||
prev = torch.multinomial(log_probs, num_samples=1)
|
|
||||||
else:
|
|
||||||
_, prev = torch.topk(log_probs, k=1, dim=-1)
|
|
||||||
# if perturb:
|
|
||||||
# prev = future
|
|
||||||
output = prev if output is None else torch.cat((output, prev),
|
|
||||||
dim=1) # update output
|
|
||||||
print(TOKENIZER.decode(output.tolist()[0]))
|
|
||||||
|
|
||||||
return output, true_discrim_loss, loss_in_time
|
else:
|
||||||
|
_, last = torch.topk(pert_probs, k=1, dim=-1)
|
||||||
|
|
||||||
|
# update context/output_so_far appending the new token
|
||||||
|
output_so_far = (
|
||||||
|
last if output_so_far is None
|
||||||
|
else torch.cat((output_so_far, last), dim=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
print(TOKENIZER.decode(output_so_far.tolist()[0]))
|
||||||
|
|
||||||
|
return output_so_far, unpert_discrim_loss, loss_in_time
|
||||||
|
|
||||||
|
|
||||||
def run_model():
|
def run_model():
|
||||||
|
|||||||
Reference in New Issue
Block a user