Cleaned perturb_past. Identical output as before.
This commit is contained in:
@@ -112,7 +112,7 @@ def top_k_filter(logits, k, probs=False):
|
|||||||
def perturb_past(
|
def perturb_past(
|
||||||
past,
|
past,
|
||||||
model,
|
model,
|
||||||
prev,
|
last,
|
||||||
unpert_past=None,
|
unpert_past=None,
|
||||||
unpert_logits=None,
|
unpert_logits=None,
|
||||||
accumulated_hidden=None,
|
accumulated_hidden=None,
|
||||||
@@ -128,156 +128,174 @@ def perturb_past(
|
|||||||
horizon_length=1,
|
horizon_length=1,
|
||||||
decay=False,
|
decay=False,
|
||||||
gamma=1.5,
|
gamma=1.5,
|
||||||
|
device='cuda'
|
||||||
):
|
):
|
||||||
# Generate initial perturbed past
|
# Generate initial perturbed past
|
||||||
past_perturb_orig = [
|
grad_accumulator = [
|
||||||
(np.random.uniform(0.0, 0.0, p.shape).astype('float32'))
|
(np.zeros(p.shape).astype("float32"))
|
||||||
for p in past]
|
for p in past
|
||||||
|
]
|
||||||
|
|
||||||
if accumulated_hidden is None:
|
if accumulated_hidden is None:
|
||||||
accumulated_hidden = 0
|
accumulated_hidden = 0
|
||||||
|
|
||||||
if decay:
|
if decay:
|
||||||
decay_mask = torch.arange(0., 1.0 + SMALL_CONST, 1.0 / (window_length))[
|
decay_mask = torch.arange(
|
||||||
1:]
|
0.,
|
||||||
|
1.0 + SMALL_CONST,
|
||||||
|
1.0 / (window_length)
|
||||||
|
)[1:]
|
||||||
else:
|
else:
|
||||||
decay_mask = 1.0
|
decay_mask = 1.0
|
||||||
|
|
||||||
|
# TODO fix this comment (SUMANTH)
|
||||||
# Generate a mask that restricts the gradient perturbation to a window of the past
|
# Generate a mask that restricts the gradient perturbation to a window of the past
|
||||||
_, _, _, current_length, _ = past[0].shape
|
_, _, _, curr_length, _ = past[0].shape
|
||||||
|
|
||||||
if current_length > window_length and window_length > 0:
|
if curr_length > window_length and window_length > 0:
|
||||||
ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple(
|
ones_key_val_shape = (
|
||||||
[window_length]) + tuple(
|
tuple(past[0].shape[:-2])
|
||||||
past[0].shape[-1:])
|
+ tuple([window_length])
|
||||||
|
+ tuple(past[0].shape[-1:])
|
||||||
|
)
|
||||||
|
|
||||||
zeros_key_val_shape = tuple(past[0].shape[:-2]) + tuple(
|
zeros_key_val_shape = (
|
||||||
[current_length - window_length]) + tuple(
|
tuple(past[0].shape[:-2])
|
||||||
past[0].shape[-1:])
|
+ tuple([curr_length - window_length])
|
||||||
|
+ tuple(past[0].shape[-1:])
|
||||||
|
)
|
||||||
|
|
||||||
ones_mask = torch.ones(ones_key_val_shape)
|
ones_mask = torch.ones(ones_key_val_shape)
|
||||||
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
|
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
|
||||||
ones_mask = ones_mask.permute(0, 1, 2, 4, 3)
|
ones_mask = ones_mask.permute(0, 1, 2, 4, 3)
|
||||||
|
|
||||||
window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)),
|
window_mask = torch.cat(
|
||||||
dim=-2).cuda()
|
(ones_mask, torch.zeros(zeros_key_val_shape)),
|
||||||
|
dim=-2
|
||||||
|
).to(device)
|
||||||
else:
|
else:
|
||||||
window_mask = torch.ones_like(past[0]).cuda()
|
window_mask = torch.ones_like(past[0]).to(device)
|
||||||
|
|
||||||
|
# accumulate perturbations for num_iterations
|
||||||
loss_per_iter = []
|
loss_per_iter = []
|
||||||
|
new_accumulated_hidden = None
|
||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
print("Iteration ", i + 1)
|
print("Iteration ", i + 1)
|
||||||
past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
|
curr_perturbation = [
|
||||||
past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
|
to_var(torch.from_numpy(p_), requires_grad=True)
|
||||||
|
for p_ in grad_accumulator
|
||||||
|
]
|
||||||
|
|
||||||
perturbed_past = list(map(add, past, past_perturb))
|
# Compute hidden using perturbed past
|
||||||
|
perturbed_past = list(map(add, past, curr_perturbation))
|
||||||
_, _, _, current_length, _ = past_perturb[0].shape
|
_, _, _, curr_length, _ = curr_perturbation[0].shape
|
||||||
|
all_logits, _, all_hidden = model(last, past=perturbed_past)
|
||||||
# _, future_past = model(prev, past=perturbed_past)
|
|
||||||
# hidden = model.hidden_states
|
|
||||||
|
|
||||||
# Piero modified model call
|
|
||||||
logits, _, all_hidden = model(prev, past=perturbed_past)
|
|
||||||
hidden = all_hidden[-1]
|
hidden = all_hidden[-1]
|
||||||
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden,
|
new_accumulated_hidden = accumulated_hidden + torch.sum(
|
||||||
dim=1).detach()
|
hidden,
|
||||||
|
dim=1
|
||||||
|
).detach()
|
||||||
|
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
||||||
|
logits = all_logits[:, -1, :]
|
||||||
|
probs = F.softmax(logits, dim=-1)
|
||||||
|
|
||||||
# TODO: Check the layer-norm consistency of this with trained discriminator
|
|
||||||
logits = logits[:, -1, :]
|
|
||||||
probabs = F.softmax(logits, dim=-1)
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
loss_list = []
|
loss_list = []
|
||||||
if loss_type == 1 or loss_type == 3:
|
if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
|
||||||
for one_hot_good in one_hot_bows_vectors:
|
for one_hot_bow in one_hot_bows_vectors:
|
||||||
good_logits = torch.mm(probabs, torch.t(one_hot_good))
|
bow_logits = torch.mm(probs, torch.t(one_hot_bow))
|
||||||
loss_word = good_logits
|
bow_loss = -torch.log(torch.sum(bow_logits))
|
||||||
loss_word = torch.sum(loss_word)
|
loss += bow_loss
|
||||||
loss_word = -torch.log(loss_word)
|
loss_list.append(bow_loss)
|
||||||
# loss_word = torch.sum(loss_word) /torch.sum(one_hot_good)
|
|
||||||
loss += loss_word
|
|
||||||
loss_list.append(loss_word)
|
|
||||||
print(" pplm_bow_loss:", loss.data.cpu().numpy())
|
print(" pplm_bow_loss:", loss.data.cpu().numpy())
|
||||||
|
|
||||||
if loss_type == 2 or loss_type == 3:
|
if loss_type == 2 or loss_type == 3:
|
||||||
ce_loss = torch.nn.CrossEntropyLoss()
|
ce_loss = torch.nn.CrossEntropyLoss()
|
||||||
new_true_past = unpert_past
|
# TODO: why do we need this assignment instead of just using unpert_past? (Sumanth)
|
||||||
for i in range(horizon_length):
|
curr_unpert_past = unpert_past
|
||||||
future_probabs = F.softmax(logits, dim=-1) # Get softmax
|
curr_probs = torch.unsqueeze(probs, dim=1)
|
||||||
future_probabs = torch.unsqueeze(future_probabs, dim=1)
|
wte = model.resize_token_embeddings()
|
||||||
|
for _ in range(horizon_length):
|
||||||
# _, new_true_past = model(future_probabs, past=new_true_past)
|
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
|
||||||
# future_hidden = model.hidden_states # Get expected hidden states
|
_, curr_unpert_past, curr_all_hidden = model(
|
||||||
|
past=curr_unpert_past,
|
||||||
# Piero modified model call
|
|
||||||
wte = model.resize_token_embeddings()
|
|
||||||
inputs_embeds = torch.matmul(future_probabs, wte.weight.data)
|
|
||||||
_, new_true_past, future_hidden = model(
|
|
||||||
past=new_true_past,
|
|
||||||
inputs_embeds=inputs_embeds
|
inputs_embeds=inputs_embeds
|
||||||
)
|
)
|
||||||
future_hidden = future_hidden[-1]
|
curr_hidden = curr_all_hidden[-1]
|
||||||
|
|
||||||
new_accumulated_hidden = new_accumulated_hidden + torch.sum(
|
new_accumulated_hidden = new_accumulated_hidden + torch.sum(
|
||||||
future_hidden, dim=1)
|
curr_hidden, dim=1)
|
||||||
|
|
||||||
predicted_sentiment = classifier(new_accumulated_hidden / (
|
prediction = classifier(new_accumulated_hidden /
|
||||||
current_length + 1 + horizon_length))
|
(curr_length + 1 + horizon_length))
|
||||||
|
|
||||||
label = torch.tensor([label_class], device='cuda',
|
label = torch.tensor([label_class], device=device,
|
||||||
dtype=torch.long)
|
dtype=torch.long)
|
||||||
discrim_loss = ce_loss(predicted_sentiment, label)
|
discrim_loss = ce_loss(prediction, label)
|
||||||
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
|
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
|
||||||
loss += discrim_loss
|
loss += discrim_loss
|
||||||
loss_list.append(discrim_loss)
|
loss_list.append(discrim_loss)
|
||||||
|
|
||||||
kl_loss = 0.0
|
kl_loss = 0.0
|
||||||
if kl_scale > 0.0:
|
if kl_scale > 0.0:
|
||||||
p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
|
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||||
p = p + SMALL_CONST * (p <= SMALL_CONST).type(
|
unpert_probs = (
|
||||||
torch.FloatTensor).cuda().detach()
|
unpert_probs + SMALL_CONST *
|
||||||
correction = SMALL_CONST * (probabs <= SMALL_CONST).type(
|
(unpert_probs <= SMALL_CONST).float().to(device).detach()
|
||||||
torch.FloatTensor).cuda().detach()
|
)
|
||||||
corrected_probabs = probabs + correction.detach()
|
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
|
||||||
|
corrected_probs = probs + correction.detach()
|
||||||
kl_loss = kl_scale * (
|
kl_loss = kl_scale * (
|
||||||
(corrected_probabs * (corrected_probabs / p).log()).sum())
|
(corrected_probs * (corrected_probs / unpert_probs).log()).sum()
|
||||||
print(' kl_loss', (kl_loss).data.cpu().numpy())
|
)
|
||||||
loss += kl_loss # + discrim_loss
|
print(' kl_loss', kl_loss.data.cpu().numpy())
|
||||||
|
loss += kl_loss
|
||||||
|
|
||||||
loss_per_iter.append(loss.data.cpu().numpy())
|
loss_per_iter.append(loss.data.cpu().numpy())
|
||||||
|
|
||||||
print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
|
print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
|
||||||
|
|
||||||
|
# compute gradients
|
||||||
loss.backward()
|
loss.backward()
|
||||||
if grad_norms is not None and loss_type == 1:
|
|
||||||
|
# calculate gradient norms
|
||||||
|
if grad_norms is not None and loss_type == PPLM_BOW:
|
||||||
grad_norms = [
|
grad_norms = [
|
||||||
torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
|
torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
|
||||||
for index, p_ in
|
for index, p_ in enumerate(curr_perturbation)
|
||||||
enumerate(past_perturb)]
|
]
|
||||||
else:
|
else:
|
||||||
grad_norms = [(torch.norm(p_.grad * window_mask) + SMALL_CONST) for
|
grad_norms = [
|
||||||
index, p_ in enumerate(past_perturb)]
|
(torch.norm(p_.grad * window_mask) + SMALL_CONST)
|
||||||
|
for index, p_ in enumerate(curr_perturbation)
|
||||||
|
]
|
||||||
|
|
||||||
|
# normalize gradients
|
||||||
grad = [
|
grad = [
|
||||||
-stepsize * (p_.grad * window_mask / grad_norms[
|
-stepsize *
|
||||||
index] ** gamma).data.cpu().numpy()
|
(p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
|
||||||
for index, p_ in enumerate(past_perturb)]
|
for index, p_ in enumerate(curr_perturbation)
|
||||||
past_perturb_orig = list(map(add, grad, past_perturb_orig))
|
]
|
||||||
|
|
||||||
for p_ in past_perturb:
|
# accumulate gradient
|
||||||
|
grad_accumulator = list(map(add, grad, grad_accumulator))
|
||||||
|
|
||||||
|
# reset gradients, just to make sure
|
||||||
|
for p_ in curr_perturbation:
|
||||||
p_.grad.data.zero_()
|
p_.grad.data.zero_()
|
||||||
|
|
||||||
|
# removing past from the graph
|
||||||
new_past = []
|
new_past = []
|
||||||
for p in past:
|
for p_ in past:
|
||||||
new_past.append(p.detach())
|
new_past.append(p_.detach())
|
||||||
|
|
||||||
past = new_past
|
past = new_past
|
||||||
|
|
||||||
past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
|
# apply the accumulated perturbations to the past
|
||||||
past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
|
grad_accumulator = [
|
||||||
perturbed_past = list(map(add, past, past_perturb))
|
to_var(torch.from_numpy(p_), requires_grad=True)
|
||||||
|
for p_ in grad_accumulator
|
||||||
|
]
|
||||||
|
pert_past = list(map(add, past, grad_accumulator))
|
||||||
|
|
||||||
return perturbed_past, new_accumulated_hidden, grad_norms, loss_per_iter
|
return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
|
||||||
|
|
||||||
|
|
||||||
def get_classifier(
|
def get_classifier(
|
||||||
@@ -532,6 +550,7 @@ def generate_text_pplm(
|
|||||||
horizon_length=horizon_length,
|
horizon_length=horizon_length,
|
||||||
decay=decay,
|
decay=decay,
|
||||||
gamma=gamma,
|
gamma=gamma,
|
||||||
|
device=device
|
||||||
)
|
)
|
||||||
loss_in_time.append(loss_this_iter)
|
loss_in_time.append(loss_this_iter)
|
||||||
else:
|
else:
|
||||||
@@ -562,7 +581,7 @@ def generate_text_pplm(
|
|||||||
pert_probs = ((pert_probs ** gm_scale) * (
|
pert_probs = ((pert_probs ** gm_scale) * (
|
||||||
unpert_probs ** (1 - gm_scale))) # + SMALL_CONST
|
unpert_probs ** (1 - gm_scale))) # + SMALL_CONST
|
||||||
pert_probs = top_k_filter(pert_probs, k=top_k,
|
pert_probs = top_k_filter(pert_probs, k=top_k,
|
||||||
probs=True) # + SMALL_CONST
|
probs=True) # + SMALL_CONST
|
||||||
|
|
||||||
# rescale
|
# rescale
|
||||||
if torch.sum(pert_probs) <= 1:
|
if torch.sum(pert_probs) <= 1:
|
||||||
@@ -662,7 +681,8 @@ def run_model():
|
|||||||
parser.add_argument("--decay", action="store_true",
|
parser.add_argument("--decay", action="store_true",
|
||||||
help="whether to decay or not")
|
help="whether to decay or not")
|
||||||
parser.add_argument("--gamma", type=float, default=1.5)
|
parser.add_argument("--gamma", type=float, default=1.5)
|
||||||
parser.add_argument("--colorama", action="store_true", help="colors keywords")
|
parser.add_argument("--colorama", action="store_true",
|
||||||
|
help="colors keywords")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user