Cleaned perturb_past. Identical output as before.
This commit is contained in:
@@ -112,7 +112,7 @@ def top_k_filter(logits, k, probs=False):
|
||||
def perturb_past(
|
||||
past,
|
||||
model,
|
||||
prev,
|
||||
last,
|
||||
unpert_past=None,
|
||||
unpert_logits=None,
|
||||
accumulated_hidden=None,
|
||||
@@ -128,156 +128,174 @@ def perturb_past(
|
||||
horizon_length=1,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
device='cuda'
|
||||
):
|
||||
# Generate inital perturbed past
|
||||
past_perturb_orig = [
|
||||
(np.random.uniform(0.0, 0.0, p.shape).astype('float32'))
|
||||
for p in past]
|
||||
grad_accumulator = [
|
||||
(np.zeros(p.shape).astype("float32"))
|
||||
for p in past
|
||||
]
|
||||
|
||||
if accumulated_hidden is None:
|
||||
accumulated_hidden = 0
|
||||
|
||||
if decay:
|
||||
decay_mask = torch.arange(0., 1.0 + SMALL_CONST, 1.0 / (window_length))[
|
||||
1:]
|
||||
decay_mask = torch.arange(
|
||||
0.,
|
||||
1.0 + SMALL_CONST,
|
||||
1.0 / (window_length)
|
||||
)[1:]
|
||||
else:
|
||||
decay_mask = 1.0
|
||||
|
||||
# TODO fix this comment (SUMANTH)
|
||||
# Generate a mask is gradient perturbated is based on a past window
|
||||
_, _, _, current_length, _ = past[0].shape
|
||||
_, _, _, curr_length, _ = past[0].shape
|
||||
|
||||
if current_length > window_length and window_length > 0:
|
||||
ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple(
|
||||
[window_length]) + tuple(
|
||||
past[0].shape[-1:])
|
||||
if curr_length > window_length and window_length > 0:
|
||||
ones_key_val_shape = (
|
||||
tuple(past[0].shape[:-2])
|
||||
+ tuple([window_length])
|
||||
+ tuple(past[0].shape[-1:])
|
||||
)
|
||||
|
||||
zeros_key_val_shape = tuple(past[0].shape[:-2]) + tuple(
|
||||
[current_length - window_length]) + tuple(
|
||||
past[0].shape[-1:])
|
||||
zeros_key_val_shape = (
|
||||
tuple(past[0].shape[:-2])
|
||||
+ tuple([curr_length - window_length])
|
||||
+ tuple(past[0].shape[-1:])
|
||||
)
|
||||
|
||||
ones_mask = torch.ones(ones_key_val_shape)
|
||||
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
|
||||
ones_mask = ones_mask.permute(0, 1, 2, 4, 3)
|
||||
|
||||
window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)),
|
||||
dim=-2).cuda()
|
||||
window_mask = torch.cat(
|
||||
(ones_mask, torch.zeros(zeros_key_val_shape)),
|
||||
dim=-2
|
||||
).to(device)
|
||||
else:
|
||||
window_mask = torch.ones_like(past[0]).cuda()
|
||||
window_mask = torch.ones_like(past[0]).to(device)
|
||||
|
||||
# accumulate perturbations for num_iterations
|
||||
loss_per_iter = []
|
||||
new_accumulated_hidden = None
|
||||
for i in range(num_iterations):
|
||||
print("Iteration ", i + 1)
|
||||
past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
|
||||
past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
|
||||
curr_perturbation = [
|
||||
to_var(torch.from_numpy(p_), requires_grad=True)
|
||||
for p_ in grad_accumulator
|
||||
]
|
||||
|
||||
perturbed_past = list(map(add, past, past_perturb))
|
||||
|
||||
_, _, _, current_length, _ = past_perturb[0].shape
|
||||
|
||||
# _, future_past = model(prev, past=perturbed_past)
|
||||
# hidden = model.hidden_states
|
||||
|
||||
# Piero modified model call
|
||||
logits, _, all_hidden = model(prev, past=perturbed_past)
|
||||
# Compute hidden using perturbed past
|
||||
perturbed_past = list(map(add, past, curr_perturbation))
|
||||
_, _, _, curr_length, _ = curr_perturbation[0].shape
|
||||
all_logits, _, all_hidden = model(last, past=perturbed_past)
|
||||
hidden = all_hidden[-1]
|
||||
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden,
|
||||
dim=1).detach()
|
||||
new_accumulated_hidden = accumulated_hidden + torch.sum(
|
||||
hidden,
|
||||
dim=1
|
||||
).detach()
|
||||
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
||||
logits = all_logits[:, -1, :]
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
|
||||
# TODO: Check the layer-norm consistency of this with trained discriminator
|
||||
logits = logits[:, -1, :]
|
||||
probabs = F.softmax(logits, dim=-1)
|
||||
loss = 0.0
|
||||
loss_list = []
|
||||
if loss_type == 1 or loss_type == 3:
|
||||
for one_hot_good in one_hot_bows_vectors:
|
||||
good_logits = torch.mm(probabs, torch.t(one_hot_good))
|
||||
loss_word = good_logits
|
||||
loss_word = torch.sum(loss_word)
|
||||
loss_word = -torch.log(loss_word)
|
||||
# loss_word = torch.sum(loss_word) /torch.sum(one_hot_good)
|
||||
loss += loss_word
|
||||
loss_list.append(loss_word)
|
||||
if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
|
||||
for one_hot_bow in one_hot_bows_vectors:
|
||||
bow_logits = torch.mm(probs, torch.t(one_hot_bow))
|
||||
bow_loss = -torch.log(torch.sum(bow_logits))
|
||||
loss += bow_loss
|
||||
loss_list.append(bow_loss)
|
||||
print(" pplm_bow_loss:", loss.data.cpu().numpy())
|
||||
|
||||
if loss_type == 2 or loss_type == 3:
|
||||
ce_loss = torch.nn.CrossEntropyLoss()
|
||||
new_true_past = unpert_past
|
||||
for i in range(horizon_length):
|
||||
future_probabs = F.softmax(logits, dim=-1) # Get softmax
|
||||
future_probabs = torch.unsqueeze(future_probabs, dim=1)
|
||||
|
||||
# _, new_true_past = model(future_probabs, past=new_true_past)
|
||||
# future_hidden = model.hidden_states # Get expected hidden states
|
||||
|
||||
# Piero modified model call
|
||||
wte = model.resize_token_embeddings()
|
||||
inputs_embeds = torch.matmul(future_probabs, wte.weight.data)
|
||||
_, new_true_past, future_hidden = model(
|
||||
past=new_true_past,
|
||||
# TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
|
||||
curr_unpert_past = unpert_past
|
||||
curr_probs = torch.unsqueeze(probs, dim=1)
|
||||
wte = model.resize_token_embeddings()
|
||||
for _ in range(horizon_length):
|
||||
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
|
||||
_, curr_unpert_past, curr_all_hidden = model(
|
||||
past=curr_unpert_past,
|
||||
inputs_embeds=inputs_embeds
|
||||
)
|
||||
future_hidden = future_hidden[-1]
|
||||
|
||||
curr_hidden = curr_all_hidden[-1]
|
||||
new_accumulated_hidden = new_accumulated_hidden + torch.sum(
|
||||
future_hidden, dim=1)
|
||||
curr_hidden, dim=1)
|
||||
|
||||
predicted_sentiment = classifier(new_accumulated_hidden / (
|
||||
current_length + 1 + horizon_length))
|
||||
prediction = classifier(new_accumulated_hidden /
|
||||
(curr_length + 1 + horizon_length))
|
||||
|
||||
label = torch.tensor([label_class], device='cuda',
|
||||
label = torch.tensor([label_class], device=device,
|
||||
dtype=torch.long)
|
||||
discrim_loss = ce_loss(predicted_sentiment, label)
|
||||
discrim_loss = ce_loss(prediction, label)
|
||||
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
|
||||
loss += discrim_loss
|
||||
loss_list.append(discrim_loss)
|
||||
|
||||
kl_loss = 0.0
|
||||
if kl_scale > 0.0:
|
||||
p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
|
||||
p = p + SMALL_CONST * (p <= SMALL_CONST).type(
|
||||
torch.FloatTensor).cuda().detach()
|
||||
correction = SMALL_CONST * (probabs <= SMALL_CONST).type(
|
||||
torch.FloatTensor).cuda().detach()
|
||||
corrected_probabs = probabs + correction.detach()
|
||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
unpert_probs = (
|
||||
unpert_probs + SMALL_CONST *
|
||||
(unpert_probs <= SMALL_CONST).float().to(device).detach()
|
||||
)
|
||||
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
|
||||
corrected_probs = probs + correction.detach()
|
||||
kl_loss = kl_scale * (
|
||||
(corrected_probabs * (corrected_probabs / p).log()).sum())
|
||||
print(' kl_loss', (kl_loss).data.cpu().numpy())
|
||||
loss += kl_loss # + discrim_loss
|
||||
(corrected_probs * (corrected_probs / unpert_probs).log()).sum()
|
||||
)
|
||||
print(' kl_loss', kl_loss.data.cpu().numpy())
|
||||
loss += kl_loss
|
||||
|
||||
loss_per_iter.append(loss.data.cpu().numpy())
|
||||
|
||||
print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
|
||||
|
||||
# compute gradients
|
||||
loss.backward()
|
||||
if grad_norms is not None and loss_type == 1:
|
||||
|
||||
# calculate gradient norms
|
||||
if grad_norms is not None and loss_type == PPLM_BOW:
|
||||
grad_norms = [
|
||||
torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
|
||||
for index, p_ in
|
||||
enumerate(past_perturb)]
|
||||
for index, p_ in enumerate(curr_perturbation)
|
||||
]
|
||||
else:
|
||||
grad_norms = [(torch.norm(p_.grad * window_mask) + SMALL_CONST) for
|
||||
index, p_ in enumerate(past_perturb)]
|
||||
grad_norms = [
|
||||
(torch.norm(p_.grad * window_mask) + SMALL_CONST)
|
||||
for index, p_ in enumerate(curr_perturbation)
|
||||
]
|
||||
|
||||
# normalize gradients
|
||||
grad = [
|
||||
-stepsize * (p_.grad * window_mask / grad_norms[
|
||||
index] ** gamma).data.cpu().numpy()
|
||||
for index, p_ in enumerate(past_perturb)]
|
||||
past_perturb_orig = list(map(add, grad, past_perturb_orig))
|
||||
-stepsize *
|
||||
(p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
|
||||
for index, p_ in enumerate(curr_perturbation)
|
||||
]
|
||||
|
||||
for p_ in past_perturb:
|
||||
# accumulate gradient
|
||||
grad_accumulator = list(map(add, grad, grad_accumulator))
|
||||
|
||||
# reset gradients, just to make sure
|
||||
for p_ in curr_perturbation:
|
||||
p_.grad.data.zero_()
|
||||
|
||||
# removing past from the graph
|
||||
new_past = []
|
||||
for p in past:
|
||||
new_past.append(p.detach())
|
||||
|
||||
for p_ in past:
|
||||
new_past.append(p_.detach())
|
||||
past = new_past
|
||||
|
||||
past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
|
||||
past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
|
||||
perturbed_past = list(map(add, past, past_perturb))
|
||||
# apply the accumulated perturbations to the past
|
||||
grad_accumulator = [
|
||||
to_var(torch.from_numpy(p_), requires_grad=True)
|
||||
for p_ in grad_accumulator
|
||||
]
|
||||
pert_past = list(map(add, past, grad_accumulator))
|
||||
|
||||
return perturbed_past, new_accumulated_hidden, grad_norms, loss_per_iter
|
||||
return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
|
||||
|
||||
|
||||
def get_classifier(
|
||||
@@ -532,6 +550,7 @@ def generate_text_pplm(
|
||||
horizon_length=horizon_length,
|
||||
decay=decay,
|
||||
gamma=gamma,
|
||||
device=device
|
||||
)
|
||||
loss_in_time.append(loss_this_iter)
|
||||
else:
|
||||
@@ -562,7 +581,7 @@ def generate_text_pplm(
|
||||
pert_probs = ((pert_probs ** gm_scale) * (
|
||||
unpert_probs ** (1 - gm_scale))) # + SMALL_CONST
|
||||
pert_probs = top_k_filter(pert_probs, k=top_k,
|
||||
probs=True) # + SMALL_CONST
|
||||
probs=True) # + SMALL_CONST
|
||||
|
||||
# rescale
|
||||
if torch.sum(pert_probs) <= 1:
|
||||
@@ -662,7 +681,8 @@ def run_model():
|
||||
parser.add_argument("--decay", action="store_true",
|
||||
help="whether to decay or not")
|
||||
parser.add_argument("--gamma", type=float, default=1.5)
|
||||
parser.add_argument("--colorama", action="store_true", help="colors keywords")
|
||||
parser.add_argument("--colorama", action="store_true",
|
||||
help="colors keywords")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user