diff --git a/examples/run_pplm.py b/examples/run_pplm.py index 4d335a9241..bd03bbe5e0 100644 --- a/examples/run_pplm.py +++ b/examples/run_pplm.py @@ -471,59 +471,49 @@ def generate_text_pplm( decay=False, gamma=1.5, ): - output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze( - 0) if context else None + output_so_far = ( + torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0) + if context + else None + ) # collect one hot vectors for bags of words one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices) grad_norms = None + unpert_discrim_loss = 0 loss_in_time = [] for i in trange(length, ascii=True): # Get past/probs for current output, except for last word - # Note that GPT takes 2 inputs: past + current-token - # Therefore, use everything from before current i/p token to generate relevant past + # Note that GPT takes 2 inputs: past + current_token - if past is None and output is not None: - prev = output[:, -1:] - # _, past = model(output[:, :-1]) - # original_probs, true_past = model(output) - # true_hidden = model.hidden_states + # run model forward to obtain unperturbed + if past is None and output_so_far is not None: + last = output_so_far[:, -1:] + _, past, _ = model(output_so_far[:, :-1]) - # Piero modified model call - _, past, _ = model(output[:, :-1]) - unpert_logits, unpert_past, unpert_all_hidden = model(output) - true_hidden = unpert_all_hidden[-1] - - else: - # original_probs, true_past = model(output) - # true_hidden = model.hidden_states - - # Piero modified model call - unpert_logits, unpert_past, unpert_all_hidden = model(output) - true_hidden = unpert_all_hidden[-1] - - # Modify the past if necessary + unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far) + unpert_last_hidden = unpert_all_hidden[-1] + # check if we are abowe grad max length if i >= grad_length: current_stepsize = stepsize * 0 else: current_stepsize = stepsize + # modify the past if necessary if not perturb or num_iterations == 0: - perturbed_past = past + pert_past = past else: - # Piero modified model call - # accumulated_hidden = model.hidden_states[:, :-1, :] - accumulated_hidden = true_hidden[:, :-1, :] + accumulated_hidden = unpert_last_hidden[:, :-1, :] accumulated_hidden = torch.sum(accumulated_hidden, dim=1) - perturbed_past, _, grad_norms, loss_per_iter = perturb_past( + pert_past, _, grad_norms, loss_this_iter = perturb_past( past, model, - prev, + last, unpert_past=unpert_past, unpert_logits=unpert_logits, accumulated_hidden=accumulated_hidden, @@ -540,68 +530,59 @@ def generate_text_pplm( decay=decay, gamma=gamma, ) - loss_in_time.append(loss_per_iter) + loss_in_time.append(loss_this_iter) - # Piero modified model call - logits, past, pert_all_hidden = model(prev, past=perturbed_past) - # test_logits = F.softmax(test_logits[:, -1, :], dim=-1) - # likelywords = torch.topk(test_logits, k=10, dim=-1) - # print(TOKENIZER.decode(likelywords[1].tolist()[0])) + pert_logits, past, pert_all_hidden = model(last, past=pert_past) + pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST + pert_probs = F.softmax(pert_logits, dim=-1) if classifier is not None: ce_loss = torch.nn.CrossEntropyLoss() - predicted_sentiment = classifier(torch.mean(true_hidden, dim=1)) + prediction = classifier(torch.mean(unpert_last_hidden, dim=1)) label = torch.tensor([label_class], device='cuda', dtype=torch.long) - true_discrim_loss = ce_loss(predicted_sentiment, label) - print("true discrim loss", true_discrim_loss.data.cpu().numpy()) + unpert_discrim_loss = ce_loss(prediction, label) + print( + "unperturbed discrim loss", + unpert_discrim_loss.data.cpu().numpy() + ) else: - true_discrim_loss = 0 - - # Piero modified model call - # hidden = model.hidden_states # update hidden - # logits = model.forward_hidden(hidden) - logits = logits[:, -1, :] / temperature # + SMALL_CONST - - # logits = top_k_filter(logits, k=args.top_k) # + SMALL_CONST - - log_probs = F.softmax(logits, dim=-1) + unpert_discrim_loss = 0 # Fuse the modified model and original model if perturb: - # original_probs = top_k_filter(original_probs[:, -1, :]) #+ SMALL_CONST - unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1) - # likelywords = torch.topk(original_probs, k=10, dim=-1) - # print(TOKENIZER.decode(likelywords[1].tolist()[0])) + unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1) - log_probs = ((log_probs ** gm_scale) * ( - unpert_logits ** (1 - gm_scale))) # + SMALL_CONST - - log_probs = top_k_filter(log_probs, k=top_k, + pert_probs = ((pert_probs ** gm_scale) * ( + unpert_probs ** (1 - gm_scale))) # + SMALL_CONST + pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST - if torch.sum(log_probs) <= 1: - log_probs = log_probs / torch.sum(log_probs) + # rescale + if torch.sum(pert_probs) <= 1: + pert_probs = pert_probs / torch.sum(pert_probs) else: - logits = top_k_filter(logits, k=top_k) # + SMALL_CONST - log_probs = F.softmax(logits, dim=-1) + pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST + pert_probs = F.softmax(pert_logits, dim=-1) + # sample or greedy if sample: - # likelywords = torch.topk(log_probs, k=args.top_k, dim=-1) - # print(TOKENIZER.decode(likelywords[1].tolist()[0])) - # print(likelywords[0].tolist()) - prev = torch.multinomial(log_probs, num_samples=1) - else: - _, prev = torch.topk(log_probs, k=1, dim=-1) - # if perturb: - # prev = future - output = prev if output is None else torch.cat((output, prev), - dim=1) # update output - print(TOKENIZER.decode(output.tolist()[0])) + last = torch.multinomial(pert_probs, num_samples=1) - return output, true_discrim_loss, loss_in_time + else: + _, last = torch.topk(pert_probs, k=1, dim=-1) + + # update context/output_so_far appending the new token + output_so_far = ( + last if output_so_far is None + else torch.cat((output_so_far, last), dim=1) + ) + + print(TOKENIZER.decode(output_so_far.tolist()[0])) + + return output_so_far, unpert_discrim_loss, loss_in_time def run_model():