diff --git a/examples/run_pplm.py b/examples/run_pplm.py index 217c131b8f..0d9ed86f45 100644 --- a/examples/run_pplm.py +++ b/examples/run_pplm.py @@ -109,20 +109,40 @@ def top_k_filter(logits, k, probs=False): logits) -def perturb_past(past, model, prev, args, classifier, good_index=None, - stepsize=0.01, vocab_size=50257, - original_probs=None, accumulated_hidden=None, true_past=None, - grad_norms=None): - window_length = args.window_length - gm_scale, kl_scale = args.gm_scale, args.kl_scale - one_hot_vectors = [] - for good_list in good_index: - good_list = list(filter(lambda x: len(x) <= 1, good_list)) - good_list = torch.tensor(good_list).cuda() - num_good = good_list.shape[0] - one_hot_good = torch.zeros(num_good, vocab_size).cuda() - one_hot_good.scatter_(1, good_list, 1) - one_hot_vectors.append(one_hot_good) +def perturb_past( + past, + model, + prev, + unpert_past=None, + unpert_logits=None, + accumulated_hidden=None, + grad_norms=None, + stepsize=0.01, + classifier=None, + label_class=None, + one_hot_bows_vectors=None, + loss_type=0, + num_iterations=3, + kl_scale=0.01, + window_length=0, + horizon_length=1, + decay=False, + gamma=1.5, +): + + #def perturb_past(past, model, prev, classifier, good_index=None, + # stepsize=0.01, vocab_size=50257, + # original_probs=None, accumulated_hidden=None, true_past=None, + # grad_norms=None): + + # one_hot_bows_vectors = [] + # for good_list in good_index: + # good_list = list(filter(lambda x: len(x) <= 1, good_list)) + # good_list = torch.tensor(good_list).cuda() + # num_good = good_list.shape[0] + # one_hot_good = torch.zeros(num_good, vocab_size).cuda() + # one_hot_good.scatter_(1, good_list, 1) + # one_hot_bows_vectors.append(one_hot_good) # Generate inital perturbed past past_perturb_orig = [ @@ -132,7 +152,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None, if accumulated_hidden is None: accumulated_hidden = 0 - if args.decay: + if decay: decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[ 1:] else: @@ -160,7 +180,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None, window_mask = torch.ones_like(past[0]).cuda() loss_per_iter = [] - for i in range(args.num_iterations): + for i in range(num_iterations): print("Iteration ", i + 1) past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig] past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb] @@ -183,8 +203,8 @@ def perturb_past(past, model, prev, args, classifier, good_index=None, probabs = F.softmax(logits, dim=-1) loss = 0.0 loss_list = [] - if args.loss_type == 1 or args.loss_type == 3: - for one_hot_good in one_hot_vectors: + if loss_type == 1 or loss_type == 3: + for one_hot_good in one_hot_bows_vectors: good_logits = torch.mm(probabs, torch.t(one_hot_good)) loss_word = good_logits loss_word = torch.sum(loss_word) @@ -194,10 +214,10 @@ def perturb_past(past, model, prev, args, classifier, good_index=None, loss_list.append(loss_word) print(" pplm_bow_loss:", loss.data.cpu().numpy()) - if args.loss_type == 2 or args.loss_type == 3: + if loss_type == 2 or loss_type == 3: ce_loss = torch.nn.CrossEntropyLoss() - new_true_past = true_past - for i in range(args.horizon_length): + new_true_past = unpert_past + for i in range(horizon_length): future_probabs = F.softmax(logits, dim=-1) # Get softmax future_probabs = torch.unsqueeze(future_probabs, dim=1) @@ -217,9 +237,9 @@ def perturb_past(past, model, prev, args, classifier, good_index=None, future_hidden, dim=1) predicted_sentiment = classifier(new_accumulated_hidden / ( - current_length + 1 + args.horizon_length)) + current_length + 1 + horizon_length)) - label = torch.tensor([args.label_class], device='cuda', + label = torch.tensor([label_class], device='cuda', dtype=torch.long) discrim_loss = ce_loss(predicted_sentiment, label) print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy()) @@ -228,7 +248,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None, kl_loss = 0.0 if kl_scale > 0.0: - p = (F.softmax(original_probs[:, -1, :], dim=-1)) + p = (F.softmax(unpert_logits[:, -1, :], dim=-1)) p = p + SmallConst * (p <= SmallConst).type( torch.FloatTensor).cuda().detach() correction = SmallConst * (probabs <= SmallConst).type( @@ -244,7 +264,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None, print(' pplm_loss', (loss - kl_loss).data.cpu().numpy()) loss.backward() - if grad_norms is not None and args.loss_type == 1: + if grad_norms is not None and loss_type == 1: grad_norms = [ torch.max(grad_norms[index], torch.norm(p_.grad * window_mask)) for index, p_ in @@ -255,7 +275,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None, grad = [ -stepsize * (p_.grad * window_mask / grad_norms[ - index] ** args.gamma).data.cpu().numpy() + index] ** gamma).data.cpu().numpy() for index, p_ in enumerate(past_perturb)] past_perturb_orig = list(map(add, grad, past_perturb_orig)) @@ -347,10 +367,32 @@ def build_bows_one_hot_vectors(bow_indices): return one_hot_bows_vectors -def latent_perturb(model, args, context=None, sample=True, device='cuda'): +def full_text_generation( + model, + context=None, + num_samples=1, + device="cuda", + sample=True, + discrim=None, + label_class=None, + bag_of_words=None, + length=100, + grad_length=10000, + stepsize=0.02, + num_iterations=3, + temperature=1.0, + gm_scale=0.9, + kl_scale=0.01, + top_k=10, + window_length=0, + horizon_length=1, + decay=False, + gamma=1.5, + **kwargs + ): classifier, class_id = get_classifier( - args.discrim, - args.label_class, + discrim, + label_class, device ) @@ -422,49 +464,68 @@ def latent_perturb(model, args, context=None, sample=True, device='cuda'): # good_list = list(filter(lambda x: len(x) <= 1, good_list)) # actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list] - good_index = [] + bow_indices = [] actual_words = None - if args.bag_of_words: - good_index = get_bag_of_words_indices(args.bag_of_words.split(";")) + if bag_of_words: + bow_indices = get_bag_of_words_indices(bag_of_words.split(";")) - for good_list in good_index: + for good_list in bow_indices: good_list = list(filter(lambda x: len(x) <= 1, good_list)) actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in good_list] - if args.bag_of_words and classifier: + if bag_of_words and classifier: print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.") - args.loss_type = PPLM_BOW_DISCRIM + loss_type = PPLM_BOW_DISCRIM - elif args.bag_of_words: - args.loss_type = PPLM_BOW + elif bag_of_words: + loss_type = PPLM_BOW print("Using PPLM-BoW") elif classifier is not None: - args.loss_type = PPLM_DISCRIM + loss_type = PPLM_DISCRIM print("Using PPLM-Discrim") else: raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)") - original, _, _ = sample_from_hidden(model=model, args=args, context=context, - device=device, - perturb=False, good_index=good_index, - classifier=classifier) + original, _, _ = generate_text_pplm( + model=model, + context=context, + device=device, + length=length, + perturb=False + ) torch.cuda.empty_cache() perturbed_list = [] discrim_loss_list = [] loss_in_time_list = [] - for i in range(args.num_samples): - perturbed, discrim_loss, loss_in_time = sample_from_hidden(model=model, - args=args, - context=context, - device=device, - perturb=True, - good_index=good_index, - classifier=classifier) + for i in range(num_samples): + perturbed, discrim_loss, loss_in_time = generate_text_pplm( + model=model, + context=context, + device=device, + sample=sample, + perturb=True, + bow_indices=bow_indices, + classifier=classifier, + label_class=class_id, + loss_type=loss_type, + length=length, + grad_length=grad_length, + stepsize=stepsize, + num_iterations=num_iterations, + temperature=temperature, + gm_scale=gm_scale, + kl_scale=kl_scale, + top_k=top_k, + window_length=window_length, + horizon_length=horizon_length, + decay=decay, + gamma=gamma, + ) perturbed_list.append(perturbed) if classifier is not None: discrim_loss_list.append(discrim_loss.data.cpu().numpy()) @@ -475,15 +536,40 @@ def latent_perturb(model, args, context=None, sample=True, device='cuda'): return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words -def sample_from_hidden(model, args, classifier, context=None, past=None, - device='cuda', - sample=True, perturb=True, good_index=None): + +def generate_text_pplm( + model, + context=None, + past=None, + device="cuda", + sample=True, + perturb=True, + classifier=None, + label_class=None, + bow_indices=None, + loss_type=0, + length=100, + grad_length=10000, + stepsize=0.02, + num_iterations=3, + temperature=1.0, + gm_scale=0.9, + kl_scale=0.01, + top_k=10, + window_length=0, + horizon_length=1, + decay=False, + gamma=1.5, +): output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze( 0) if context else None + # collect one hot vectors for bags of words + one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices) + grad_norms = None loss_in_time = [] - for i in trange(args.length, ascii=True): + for i in trange(length, ascii=True): # Get past/probs for current output, except for last word # Note that GPT takes 2 inputs: past + current-token @@ -497,7 +583,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None, # Piero modified model call _, past, _ = model(output[:, :-1]) - original_probs, true_past, unpert_all_hidden = model(output) + unpert_logits, unpert_past, unpert_all_hidden = model(output) true_hidden = unpert_all_hidden[-1] else: @@ -505,17 +591,17 @@ def sample_from_hidden(model, args, classifier, context=None, past=None, # true_hidden = model.hidden_states # Piero modified model call - original_probs, true_past, unpert_all_hidden = model(output) + unpert_logits, unpert_past, unpert_all_hidden = model(output) true_hidden = unpert_all_hidden[-1] # Modify the past if necessary - if i >= args.grad_length: - current_stepsize = args.stepsize * 0 + if i >= grad_length: + current_stepsize = stepsize * 0 else: - current_stepsize = args.stepsize + current_stepsize = stepsize - if not perturb or args.num_iterations == 0: + if not perturb or num_iterations == 0: perturbed_past = past else: @@ -524,17 +610,26 @@ def sample_from_hidden(model, args, classifier, context=None, past=None, accumulated_hidden = true_hidden[:, :-1, :] accumulated_hidden = torch.sum(accumulated_hidden, dim=1) - perturbed_past, _, grad_norms, loss_per_iter = perturb_past(past, - model, - prev, - args, - good_index=good_index, - stepsize=current_stepsize, - original_probs=original_probs, - true_past=true_past, - accumulated_hidden=accumulated_hidden, - classifier=classifier, - grad_norms=grad_norms) + perturbed_past, _, grad_norms, loss_per_iter = perturb_past( + past, + model, + prev, + unpert_past=unpert_past, + unpert_logits=unpert_logits, + accumulated_hidden=accumulated_hidden, + grad_norms=grad_norms, + stepsize=current_stepsize, + classifier=classifier, + label_class=label_class, + one_hot_bows_vectors=one_hot_bows_vectors, + loss_type=loss_type, + num_iterations=num_iterations, + kl_scale=kl_scale, + window_length=window_length, + horizon_length=horizon_length, + decay=decay, + gamma=gamma, + ) loss_in_time.append(loss_per_iter) # Piero modified model call @@ -546,7 +641,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None, if classifier is not None: ce_loss = torch.nn.CrossEntropyLoss() predicted_sentiment = classifier(torch.mean(true_hidden, dim=1)) - label = torch.tensor([args.label_class], device='cuda', + label = torch.tensor([label_class], device='cuda', dtype=torch.long) true_discrim_loss = ce_loss(predicted_sentiment, label) print("true discrim loss", true_discrim_loss.data.cpu().numpy()) @@ -556,7 +651,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None, # Piero modified model call # hidden = model.hidden_states # update hidden # logits = model.forward_hidden(hidden) - logits = logits[:, -1, :] / args.temperature # + SmallConst + logits = logits[:, -1, :] / temperature # + SmallConst # logits = top_k_filter(logits, k=args.top_k) # + SmallConst @@ -566,22 +661,21 @@ def sample_from_hidden(model, args, classifier, context=None, past=None, if perturb: # original_probs = top_k_filter(original_probs[:, -1, :]) #+ SmallConst - original_probs = F.softmax(original_probs[:, -1, :], dim=-1) + unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1) # likelywords = torch.topk(original_probs, k=10, dim=-1) # print(TOKENIZER.decode(likelywords[1].tolist()[0])) - gm_scale = args.gm_scale log_probs = ((log_probs ** gm_scale) * ( - original_probs ** (1 - gm_scale))) # + SmallConst + unpert_logits ** (1 - gm_scale))) # + SmallConst - log_probs = top_k_filter(log_probs, k=args.top_k, + log_probs = top_k_filter(log_probs, k=top_k, probs=True) # + SmallConst if torch.sum(log_probs) <= 1: log_probs = log_probs / torch.sum(log_probs) else: - logits = top_k_filter(logits, k=args.top_k) # + SmallConst + logits = top_k_filter(logits, k=top_k) # + SmallConst log_probs = F.softmax(logits, dim=-1) if sample: @@ -673,16 +767,16 @@ def run_model(): collect_gen = dict() current_index = 0 - for out in seq: + for tokenized_cond_text in seq: - text = TOKENIZER.decode(out) + text = TOKENIZER.decode(tokenized_cond_text) print("=" * 40 + " Prefix of sentence " + "=" * 40) print(text) print("=" * 80) - out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = latent_perturb( - model=model, args=args, context=out, - device=device) + out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = full_text_generation( + model=model, context=tokenized_cond_text, device=device, **vars(args) + ) text_whole = TOKENIZER.decode(out1.tolist()[0]) @@ -712,20 +806,20 @@ def run_model(): import colorama text_whole = '' - for out in output_tokens: - if out in keyword_tokens: + for tokenized_cond_text in output_tokens: + if tokenized_cond_text in keyword_tokens: text_whole += '%s%s%s' % ( - colorama.Fore.GREEN, TOKENIZER.decode([out]), + colorama.Fore.GREEN, TOKENIZER.decode([tokenized_cond_text]), colorama.Style.RESET_ALL) else: - text_whole += TOKENIZER.decode([out]) + text_whole += TOKENIZER.decode([tokenized_cond_text]) else: text_whole = TOKENIZER.decode(out_perturb.tolist()[0]) print(text_whole) print("=" * 80) - collect_gen[current_index] = [out, out_perturb, out1] + collect_gen[current_index] = [tokenized_cond_text, out_perturb, out1] current_index = current_index + 1