diff --git a/examples/run_pplm.py b/examples/run_pplm.py index e337add46d..77758759d9 100644 --- a/examples/run_pplm.py +++ b/examples/run_pplm.py @@ -112,7 +112,7 @@ def top_k_filter(logits, k, probs=False): def perturb_past( past, model, - prev, + last, unpert_past=None, unpert_logits=None, accumulated_hidden=None, @@ -128,156 +128,174 @@ def perturb_past( horizon_length=1, decay=False, gamma=1.5, + device='cuda' ): # Generate inital perturbed past - past_perturb_orig = [ - (np.random.uniform(0.0, 0.0, p.shape).astype('float32')) - for p in past] + grad_accumulator = [ + (np.zeros(p.shape).astype("float32")) + for p in past + ] if accumulated_hidden is None: accumulated_hidden = 0 if decay: - decay_mask = torch.arange(0., 1.0 + SMALL_CONST, 1.0 / (window_length))[ - 1:] + decay_mask = torch.arange( + 0., + 1.0 + SMALL_CONST, + 1.0 / (window_length) + )[1:] else: decay_mask = 1.0 + # TODO fix this comment (SUMANTH) # Generate a mask is gradient perturbated is based on a past window - _, _, _, current_length, _ = past[0].shape + _, _, _, curr_length, _ = past[0].shape - if current_length > window_length and window_length > 0: - ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple( - [window_length]) + tuple( - past[0].shape[-1:]) + if curr_length > window_length and window_length > 0: + ones_key_val_shape = ( + tuple(past[0].shape[:-2]) + + tuple([window_length]) + + tuple(past[0].shape[-1:]) + ) - zeros_key_val_shape = tuple(past[0].shape[:-2]) + tuple( - [current_length - window_length]) + tuple( - past[0].shape[-1:]) + zeros_key_val_shape = ( + tuple(past[0].shape[:-2]) + + tuple([curr_length - window_length]) + + tuple(past[0].shape[-1:]) + ) ones_mask = torch.ones(ones_key_val_shape) ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3) ones_mask = ones_mask.permute(0, 1, 2, 4, 3) - window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)), - dim=-2).cuda() + window_mask = torch.cat( + (ones_mask, torch.zeros(zeros_key_val_shape)), + dim=-2 + ).to(device) else: - window_mask = torch.ones_like(past[0]).cuda() + window_mask = torch.ones_like(past[0]).to(device) + # accumulate perturbations for num_iterations loss_per_iter = [] + new_accumulated_hidden = None for i in range(num_iterations): print("Iteration ", i + 1) - past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig] - past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb] + curr_perturbation = [ + to_var(torch.from_numpy(p_), requires_grad=True) + for p_ in grad_accumulator + ] - perturbed_past = list(map(add, past, past_perturb)) - - _, _, _, current_length, _ = past_perturb[0].shape - - # _, future_past = model(prev, past=perturbed_past) - # hidden = model.hidden_states - - # Piero modified model call - logits, _, all_hidden = model(prev, past=perturbed_past) + # Compute hidden using perturbed past + perturbed_past = list(map(add, past, curr_perturbation)) + _, _, _, curr_length, _ = curr_perturbation[0].shape + all_logits, _, all_hidden = model(last, past=perturbed_past) hidden = all_hidden[-1] - new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, - dim=1).detach() + new_accumulated_hidden = accumulated_hidden + torch.sum( + hidden, + dim=1 + ).detach() + # TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth) + logits = all_logits[:, -1, :] + probs = F.softmax(logits, dim=-1) - # TODO: Check the layer-norm consistency of this with trained discriminator - logits = logits[:, -1, :] - probabs = F.softmax(logits, dim=-1) loss = 0.0 loss_list = [] - if loss_type == 1 or loss_type == 3: - for one_hot_good in one_hot_bows_vectors: - good_logits = torch.mm(probabs, torch.t(one_hot_good)) - loss_word = good_logits - loss_word = torch.sum(loss_word) - loss_word = -torch.log(loss_word) - # loss_word = torch.sum(loss_word) /torch.sum(one_hot_good) - loss += loss_word - loss_list.append(loss_word) + if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM: + for one_hot_bow in one_hot_bows_vectors: + bow_logits = torch.mm(probs, torch.t(one_hot_bow)) + bow_loss = -torch.log(torch.sum(bow_logits)) + loss += bow_loss + loss_list.append(bow_loss) print(" pplm_bow_loss:", loss.data.cpu().numpy()) if loss_type == 2 or loss_type == 3: ce_loss = torch.nn.CrossEntropyLoss() - new_true_past = unpert_past - for i in range(horizon_length): - future_probabs = F.softmax(logits, dim=-1) # Get softmax - future_probabs = torch.unsqueeze(future_probabs, dim=1) - - # _, new_true_past = model(future_probabs, past=new_true_past) - # future_hidden = model.hidden_states # Get expected hidden states - - # Piero modified model call - wte = model.resize_token_embeddings() - inputs_embeds = torch.matmul(future_probabs, wte.weight.data) - _, new_true_past, future_hidden = model( - past=new_true_past, + # TODO why we need to do this assignment and not just using unpert_past? (Sumanth) + curr_unpert_past = unpert_past + curr_probs = torch.unsqueeze(probs, dim=1) + wte = model.resize_token_embeddings() + for _ in range(horizon_length): + inputs_embeds = torch.matmul(curr_probs, wte.weight.data) + _, curr_unpert_past, curr_all_hidden = model( + past=curr_unpert_past, inputs_embeds=inputs_embeds ) - future_hidden = future_hidden[-1] - + curr_hidden = curr_all_hidden[-1] new_accumulated_hidden = new_accumulated_hidden + torch.sum( - future_hidden, dim=1) + curr_hidden, dim=1) - predicted_sentiment = classifier(new_accumulated_hidden / ( - current_length + 1 + horizon_length)) + prediction = classifier(new_accumulated_hidden / + (curr_length + 1 + horizon_length)) - label = torch.tensor([label_class], device='cuda', + label = torch.tensor([label_class], device=device, dtype=torch.long) - discrim_loss = ce_loss(predicted_sentiment, label) + discrim_loss = ce_loss(prediction, label) print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy()) loss += discrim_loss loss_list.append(discrim_loss) kl_loss = 0.0 if kl_scale > 0.0: - p = (F.softmax(unpert_logits[:, -1, :], dim=-1)) - p = p + SMALL_CONST * (p <= SMALL_CONST).type( - torch.FloatTensor).cuda().detach() - correction = SMALL_CONST * (probabs <= SMALL_CONST).type( - torch.FloatTensor).cuda().detach() - corrected_probabs = probabs + correction.detach() + unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1) + unpert_probs = ( + unpert_probs + SMALL_CONST * + (unpert_probs <= SMALL_CONST).float().to(device).detach() + ) + correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach() + corrected_probs = probs + correction.detach() kl_loss = kl_scale * ( - (corrected_probabs * (corrected_probabs / p).log()).sum()) - print(' kl_loss', (kl_loss).data.cpu().numpy()) - loss += kl_loss # + discrim_loss + (corrected_probs * (corrected_probs / unpert_probs).log()).sum() + ) + print(' kl_loss', kl_loss.data.cpu().numpy()) + loss += kl_loss loss_per_iter.append(loss.data.cpu().numpy()) - print(' pplm_loss', (loss - kl_loss).data.cpu().numpy()) + # compute gradients loss.backward() - if grad_norms is not None and loss_type == 1: + + # calculate gradient norms + if grad_norms is not None and loss_type == PPLM_BOW: grad_norms = [ torch.max(grad_norms[index], torch.norm(p_.grad * window_mask)) - for index, p_ in - enumerate(past_perturb)] + for index, p_ in enumerate(curr_perturbation) + ] else: - grad_norms = [(torch.norm(p_.grad * window_mask) + SMALL_CONST) for - index, p_ in enumerate(past_perturb)] + grad_norms = [ + (torch.norm(p_.grad * window_mask) + SMALL_CONST) + for index, p_ in enumerate(curr_perturbation) + ] + # normalize gradients grad = [ - -stepsize * (p_.grad * window_mask / grad_norms[ - index] ** gamma).data.cpu().numpy() - for index, p_ in enumerate(past_perturb)] - past_perturb_orig = list(map(add, grad, past_perturb_orig)) + -stepsize * + (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy() + for index, p_ in enumerate(curr_perturbation) + ] - for p_ in past_perturb: + # accumulate gradient + grad_accumulator = list(map(add, grad, grad_accumulator)) + + # reset gradients, just to make sure + for p_ in curr_perturbation: p_.grad.data.zero_() + # removing past from the graph new_past = [] - for p in past: - new_past.append(p.detach()) - + for p_ in past: + new_past.append(p_.detach()) past = new_past - past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig] - past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb] - perturbed_past = list(map(add, past, past_perturb)) + # apply the accumulated perturbations to the past + grad_accumulator = [ + to_var(torch.from_numpy(p_), requires_grad=True) + for p_ in grad_accumulator + ] + pert_past = list(map(add, past, grad_accumulator)) - return perturbed_past, new_accumulated_hidden, grad_norms, loss_per_iter + return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter def get_classifier( @@ -532,6 +550,7 @@ def generate_text_pplm( horizon_length=horizon_length, decay=decay, gamma=gamma, + device=device ) loss_in_time.append(loss_this_iter) else: @@ -562,7 +581,7 @@ def generate_text_pplm( pert_probs = ((pert_probs ** gm_scale) * ( unpert_probs ** (1 - gm_scale))) # + SMALL_CONST pert_probs = top_k_filter(pert_probs, k=top_k, - probs=True) # + SMALL_CONST + probs=True) # + SMALL_CONST # rescale if torch.sum(pert_probs) <= 1: @@ -662,7 +681,8 @@ def run_model(): parser.add_argument("--decay", action="store_true", help="whether to decay or not") parser.add_argument("--gamma", type=float, default=1.5) - parser.add_argument("--colorama", action="store_true", help="colors keywords") + parser.add_argument("--colorama", action="store_true", + help="colors keywords") args = parser.parse_args()