From 6c9c1317800499bd5aaf3f1a7492a72961bbfa91 Mon Sep 17 00:00:00 2001 From: piero Date: Wed, 27 Nov 2019 17:27:39 -0800 Subject: [PATCH] More cleanup for run_model. Identical output as before. --- examples/run_pplm.py | 308 ++++++++++++++++++++++++------------------- 1 file changed, 171 insertions(+), 137 deletions(-) diff --git a/examples/run_pplm.py b/examples/run_pplm.py index 0d9ed86f45..27ead3c3c5 100644 --- a/examples/run_pplm.py +++ b/examples/run_pplm.py @@ -39,7 +39,6 @@ from transformers import GPT2Tokenizer from transformers.file_utils import cached_path from transformers.modeling_gpt2 import GPT2LMHeadModel - PPLM_BOW = 1 PPLM_DISCRIM = 2 PPLM_BOW_DISCRIM = 3 @@ -129,8 +128,7 @@ def perturb_past( decay=False, gamma=1.5, ): - - #def perturb_past(past, model, prev, classifier, good_index=None, + # def perturb_past(past, model, prev, classifier, good_index=None, # stepsize=0.01, vocab_size=50257, # original_probs=None, accumulated_hidden=None, true_past=None, # grad_norms=None): @@ -237,7 +235,7 @@ def perturb_past( future_hidden, dim=1) predicted_sentiment = classifier(new_accumulated_hidden / ( - current_length + 1 + horizon_length)) + current_length + 1 + horizon_length)) label = torch.tensor([label_class], device='cuda', dtype=torch.long) @@ -349,6 +347,13 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[ bow_indices.append( [TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in words]) + + #bow_words = set() + #for bow_list in bow_indices: + # bow_list = list(filter(lambda x: len(x) <= 1, bow_list)) + # bow_words.update( + # (TOKENIZER.decode(word).strip(), word) for word in bow_list) + return bow_indices @@ -368,28 +373,28 @@ def build_bows_one_hot_vectors(bow_indices): def full_text_generation( - model, - context=None, - num_samples=1, - device="cuda", - sample=True, - discrim=None, - label_class=None, - bag_of_words=None, - length=100, - grad_length=10000, - stepsize=0.02, - num_iterations=3, - temperature=1.0, - gm_scale=0.9, - kl_scale=0.01, - top_k=10, - window_length=0, - horizon_length=1, - decay=False, - gamma=1.5, - **kwargs - ): + model, + context=None, + num_samples=1, + device="cuda", + sample=True, + discrim=None, + label_class=None, + bag_of_words=None, + length=100, + grad_length=10000, + stepsize=0.02, + num_iterations=3, + temperature=1.0, + gm_scale=0.9, + kl_scale=0.01, + top_k=10, + window_length=0, + horizon_length=1, + decay=False, + gamma=1.5, + **kwargs +): classifier, class_id = get_classifier( discrim, label_class, @@ -465,15 +470,9 @@ def full_text_generation( # actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list] bow_indices = [] - actual_words = None if bag_of_words: bow_indices = get_bag_of_words_indices(bag_of_words.split(";")) - for good_list in bow_indices: - good_list = list(filter(lambda x: len(x) <= 1, good_list)) - actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in - good_list] - if bag_of_words and classifier: print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.") loss_type = PPLM_BOW_DISCRIM @@ -533,8 +532,7 @@ def full_text_generation( torch.cuda.empty_cache() - return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words - + return original, perturbed_list, discrim_loss_list, loss_in_time_list def generate_text_pplm( @@ -611,25 +609,25 @@ def generate_text_pplm( accumulated_hidden = torch.sum(accumulated_hidden, dim=1) perturbed_past, _, grad_norms, loss_per_iter = perturb_past( - past, - model, - prev, - unpert_past=unpert_past, - unpert_logits=unpert_logits, - accumulated_hidden=accumulated_hidden, - grad_norms=grad_norms, - stepsize=current_stepsize, - classifier=classifier, - label_class=label_class, - one_hot_bows_vectors=one_hot_bows_vectors, - loss_type=loss_type, - num_iterations=num_iterations, - kl_scale=kl_scale, - window_length=window_length, - horizon_length=horizon_length, - decay=decay, - gamma=gamma, - ) + past, + model, + prev, + unpert_past=unpert_past, + unpert_logits=unpert_logits, + accumulated_hidden=accumulated_hidden, + grad_norms=grad_norms, + stepsize=current_stepsize, + classifier=classifier, + label_class=label_class, + one_hot_bows_vectors=one_hot_bows_vectors, + loss_type=loss_type, + num_iterations=num_iterations, + kl_scale=kl_scale, + window_length=window_length, + horizon_length=horizon_length, + decay=decay, + gamma=gamma, + ) loss_in_time.append(loss_per_iter) # Piero modified model call @@ -666,7 +664,7 @@ def generate_text_pplm( # print(TOKENIZER.decode(likelywords[1].tolist()[0])) log_probs = ((log_probs ** gm_scale) * ( - unpert_logits ** (1 - gm_scale))) # + SmallConst + unpert_logits ** (1 - gm_scale))) # + SmallConst log_probs = top_k_filter(log_probs, k=top_k, probs=True) # + SmallConst @@ -696,53 +694,88 @@ def generate_text_pplm( def run_model(): parser = argparse.ArgumentParser() - parser.add_argument('--model_path', '-M', type=str, default='gpt2-medium', - help='pretrained model name or path to local checkpoint') - parser.add_argument('--bag-of-words', '-B', type=str, default=None, - help='Bags of words used for PPLM-BoW. Multiple BoWs separated by ;') - parser.add_argument('--discrim', '-D', type=str, default=None, - choices=( - 'clickbait', 'sentiment', 'toxicity', 'generic'), - help='Discriminator to use for loss-type 2') - parser.add_argument('--discrim_weights', type=str, default=None, - help='Weights for the generic discriminator') - parser.add_argument('--discrim_meta', type=str, default=None, - help='Meta information for the generic discriminator') - parser.add_argument('--label_class', type=int, default=-1, - help='Class label used for the discriminator') - parser.add_argument('--stepsize', type=float, default=0.02) + parser.add_argument( + "--model_path", + "-M", + type=str, + default="gpt2-medium", + help="pretrained model name or path to local checkpoint", + ) + parser.add_argument( + "--bag_of_words", + "-B", + type=str, + default=None, + help="Bags of words used for PPLM-BoW. " + "Either a BOW id (see list in code) or a filepath. " + "Multiple BoWs separated by ;", + ) + parser.add_argument( + "--discrim", + "-D", + type=str, + default=None, + choices=("clickbait", "sentiment", "toxicity"), + help="Discriminator to use for loss-type 2", + ) + parser.add_argument( + "--label_class", + type=int, + default=-1, + help="Class label used for the discriminator", + ) + parser.add_argument("--stepsize", type=float, default=0.02) parser.add_argument("--length", type=int, default=100) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--top_k", type=int, default=10) parser.add_argument("--gm_scale", type=float, default=0.9) parser.add_argument("--kl_scale", type=float, default=0.01) - parser.add_argument('--nocuda', action='store_true', help='no cuda') - parser.add_argument('--uncond', action='store_true', - help='Generate from end-of-text as prefix') - parser.add_argument("--cond_text", type=str, default='The lake', - help='Prefix texts to condition on') - parser.add_argument('--num_iterations', type=int, default=3) - parser.add_argument('--grad_length', type=int, default=10000) - parser.add_argument('--num_samples', type=int, default=1, - help='Number of samples to generate from the modified latents') - parser.add_argument('--horizon_length', type=int, default=1, - help='Length of future to optimize over') - # parser.add_argument('--force-token', action='store_true', help='no cuda') - parser.add_argument('--window_length', type=int, default=0, - help='Length of past which is being optimizer; 0 corresponds to infinite window length') - parser.add_argument('--decay', action='store_true', - help='whether to decay or not') - parser.add_argument('--gamma', type=float, default=1.5) - parser.add_argument('--colorama', action='store_true', help='no cuda') + parser.add_argument("--no_cuda", action="store_true", help="no cuda") + parser.add_argument( + "--uncond", action="store_true", + help="Generate from end-of-text as prefix" + ) + parser.add_argument( + "--cond_text", type=str, default="The lake", + help="Prefix texts to condition on" + ) + parser.add_argument("--num_iterations", type=int, default=3) + parser.add_argument("--grad_length", type=int, default=10000) + parser.add_argument( + "--num_samples", + type=int, + default=1, + help="Number of samples to generate from the modified latents", + ) + parser.add_argument( + "--horizon_length", + type=int, + default=1, + help="Length of future to optimize over", + ) + parser.add_argument( + "--window_length", + type=int, + default=0, + help="Length of past which is being optimized; " + "0 corresponds to infinite window length", + ) + parser.add_argument("--decay", action="store_true", + help="whether to decay or not") + parser.add_argument("--gamma", type=float, default=1.5) + parser.add_argument("--colorama", action="store_true", help="colors keywords") args = parser.parse_args() + # set Random seed torch.manual_seed(args.seed) np.random.seed(args.seed) - device = 'cpu' if args.nocuda else 'cuda' + # set the device + device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" + # load pretrained model model = GPT2LMHeadModel.from_pretrained( args.model_path, output_hidden_states=True @@ -753,76 +786,77 @@ def run_model(): # Freeze GPT-2 weights for param in model.parameters(): param.requires_grad = False - pass + # figure out conditioning text if args.uncond: - seq = [[50256, 50256]] - + tokenized_cond_text = TOKENIZER.encode( + [TOKENIZER.bos_token] + ) else: raw_text = args.cond_text while not raw_text: - print('Did you forget to add `--cond-text`? ') + print("Did you forget to add `--cond_text`? ") raw_text = input("Model prompt >>> ") - seq = [[50256] + TOKENIZER.encode(raw_text)] + tokenized_cond_text = TOKENIZER.encode(TOKENIZER.bos_token + raw_text) - collect_gen = dict() - current_index = 0 - for tokenized_cond_text in seq: + print("= Prefix of sentence =") + print(TOKENIZER.decode(tokenized_cond_text)) + print() - text = TOKENIZER.decode(tokenized_cond_text) - print("=" * 40 + " Prefix of sentence " + "=" * 40) - print(text) - print("=" * 80) + # generate unperturbed and perturbed texts - out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = full_text_generation( - model=model, context=tokenized_cond_text, device=device, **vars(args) - ) + # full_text_generation returns: + # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time + unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation( + model=model, context=tokenized_cond_text, device=device, **vars(args) + ) - text_whole = TOKENIZER.decode(out1.tolist()[0]) + # untokenize unperturbed text + unpert_gen_text = TOKENIZER.decode(unpert_gen_tok_text.tolist()[0]) - print("=" * 80) - print("=" * 40 + " Whole sentence (Original)" + "=" * 40) - print(text_whole) - print("=" * 80) + print("=" * 80) + print("= Unperturbed generated text =") + print(unpert_gen_text) + print() - out_perturb_copy = out_perturb + generated_texts = [] - for out_perturb in out_perturb_copy: - # try: - # print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40) - # text_whole = TOKENIZER.decode(out_perturb.tolist()[0]) - # print(text_whole) - # print("=" * 80) - # except: - # pass - # collect_gen[current_index] = [out, out_perturb, out1] - ## Save the prefix, perturbed seq, original seq for each index - print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40) - keyword_tokens = [aa[-1][0] for aa in - actual_words] if actual_words else [] - output_tokens = out_perturb.tolist()[0] + bow_words = set() + bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";")) + for bow_list in bow_indices: + filtered = list(filter(lambda x: len(x) <= 1, bow_list)) + bow_words.update(w[0] for w in filtered) + # iterate through the perturbed texts + for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts): + try: + # untokenize unperturbed text if args.colorama: import colorama - text_whole = '' - for tokenized_cond_text in output_tokens: - if tokenized_cond_text in keyword_tokens: - text_whole += '%s%s%s' % ( - colorama.Fore.GREEN, TOKENIZER.decode([tokenized_cond_text]), - colorama.Style.RESET_ALL) + pert_gen_text = '' + for word_id in pert_gen_tok_text.tolist()[0]: + if word_id in bow_words: + pert_gen_text += '{}{}{}'.format( + colorama.Fore.RED, + TOKENIZER.decode([word_id]), + colorama.Style.RESET_ALL + ) else: - text_whole += TOKENIZER.decode([tokenized_cond_text]) + pert_gen_text += TOKENIZER.decode([word_id]) else: - text_whole = TOKENIZER.decode(out_perturb.tolist()[0]) + pert_gen_text = TOKENIZER.decode(pert_gen_tok_text.tolist()[0]) - print(text_whole) - print("=" * 80) - - collect_gen[current_index] = [tokenized_cond_text, out_perturb, out1] - - current_index = current_index + 1 + print("= Perturbed generated text {} =".format(i + 1)) + print(pert_gen_text) + print() + except: + pass + # keep the prefix, perturbed seq, original seq for each index + generated_texts.append( + (tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text) + ) return