diff --git a/examples/run_pplm.py b/examples/run_pplm.py index 28aa66cc7d..8516454f86 100644 --- a/examples/run_pplm.py +++ b/examples/run_pplm.py @@ -43,7 +43,6 @@ PPLM_DISCRIM = 2 PPLM_BOW_DISCRIM = 3 SMALL_CONST = 1e-15 BIG_CONST = 1e10 -TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium") BAG_OF_WORDS_ARCHIVE_MAP = { 'kitchen': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/kitchen.txt", @@ -65,6 +64,7 @@ DISCRIMINATOR_MODELS_PARAMS = { "embed_size": 1024, "class_vocab": {"non_clickbait": 0, "clickbait": 1}, "default_class": 1, + "pretrained_model": "gpt2-medium", }, "sentiment": { "url": "http://s.yosinski.com/SST_classifier_head.pt", @@ -72,6 +72,7 @@ DISCRIMINATOR_MODELS_PARAMS = { "embed_size": 1024, "class_vocab": {"very_positive": 2, "very_negative": 3}, "default_class": 3, + "pretrained_model": "gpt2-medium", }, "toxicity": { "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/toxicity_classifierhead.pt", @@ -79,6 +80,7 @@ DISCRIMINATOR_MODELS_PARAMS = { "embed_size": 1024, "class_vocab": {"non_toxic": 0, "toxic": 1}, "default_class": 0, + "pretrained_model": "gpt2-medium", }, } @@ -345,8 +347,9 @@ def get_classifier( return classifier, label_id -def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[ - List[List[int]]]: +def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \ + List[ + List[List[int]]]: bow_indices = [] for id_or_path in bag_of_words_ids_or_paths: if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP: @@ -356,12 +359,12 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[ with open(filepath, "r") as f: words = f.read().strip().split("\n") bow_indices.append( - [TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in + [tokenizer.encode(word.strip(), add_prefix_space=True) for word in words]) return bow_indices -def build_bows_one_hot_vectors(bow_indices, device='cuda'): +def build_bows_one_hot_vectors(bow_indices, 
tokenizer, device='cuda'): if bow_indices is None: return None @@ -370,7 +373,7 @@ def build_bows_one_hot_vectors(bow_indices, device='cuda'): single_bow = list(filter(lambda x: len(x) <= 1, single_bow)) single_bow = torch.tensor(single_bow).to(device) num_words = single_bow.shape[0] - one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).to(device) + one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device) one_hot_bow.scatter_(1, single_bow, 1) one_hot_bows_vectors.append(one_hot_bow) return one_hot_bows_vectors @@ -378,10 +381,11 @@ def build_bows_one_hot_vectors(bow_indices, device='cuda'): def full_text_generation( model, + tokenizer, context=None, num_samples=1, device="cuda", - sample=True, + sample=False, discrim=None, class_label=None, bag_of_words=None, @@ -407,7 +411,8 @@ def full_text_generation( bow_indices = [] if bag_of_words: - bow_indices = get_bag_of_words_indices(bag_of_words.split(";")) + bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), + tokenizer) if bag_of_words and classifier: print("Both PPLM-BoW and PPLM-Discrim are on. 
def _single_token_ids(bow_indices):
    """Collect the ids of bag-of-words entries encoded as exactly one token.

    Entries that encode to zero tokens or to several tokens are ignored, so
    indexing the first element is always safe.
    """
    return {
        word_ids[0]
        for single_bow in bow_indices
        for word_ids in single_bow
        if len(word_ids) == 1
    }


def _decode_highlighted(tokenizer, token_ids, highlight_ids, color_module):
    """Decode ``token_ids`` one token at a time, colouring highlighted ids red.

    ``color_module`` is the imported ``colorama`` module; it is passed in so
    the caller decides whether (and how) the import happens.
    """
    pieces = []
    for token_id in token_ids:
        piece = tokenizer.decode([token_id])
        if token_id in highlight_ids:
            piece = '{}{}{}'.format(
                color_module.Fore.RED,
                piece,
                color_module.Style.RESET_ALL
            )
        pieces.append(piece)
    return ''.join(pieces)


def run_pplm_example(
        pretrained_model="gpt2-medium",
        cond_text="",
        uncond=False,
        num_samples=1,
        bag_of_words=None,
        discrim=None,
        discrim_weights=None,
        discrim_meta=None,
        class_label=-1,
        length=100,
        stepsize=0.02,
        temperature=1.0,
        top_k=10,
        sample=False,
        num_iterations=3,
        grad_length=10000,
        horizon_length=1,
        window_length=0,
        decay=False,
        gamma=1.5,
        gm_scale=0.9,
        kl_scale=0.01,
        seed=0,
        no_cuda=False,
        colorama=False
):
    """Load a GPT-2 model, generate unperturbed and perturbed text, print both.

    Args:
        pretrained_model: name or path handed to ``from_pretrained`` for both
            the model and the tokenizer. When ``discrim`` is given and its
            entry in ``DISCRIMINATOR_MODELS_PARAMS`` has a
            ``"pretrained_model"`` key, that value replaces this argument.
        cond_text: prompt text. If it is empty and ``uncond`` is false, the
            user is asked on stdin until a non-empty prompt is entered.
        uncond: if true, condition on the tokenizer's BOS token only.
        discrim, discrim_weights, discrim_meta: discriminator selection. For
            ``discrim == 'generic'`` the weights/meta are registered through
            ``set_generic_model_params`` before the lookup above.
        seed: seed for ``torch`` and ``numpy`` random generators.
        no_cuda: if true, run on CPU even when CUDA is available.
        colorama: if true, print single-token bag-of-words hits in red. When
            the ``colorama`` package cannot be imported, the text is printed
            without colours instead of being dropped.
        num_samples, bag_of_words, class_label, length, stepsize,
        temperature, top_k, sample, num_iterations, grad_length,
        horizon_length, window_length, decay, gamma, gm_scale, kl_scale:
            forwarded unchanged to ``full_text_generation``.

    Returns:
        A list with one ``(tokenized_cond_text, pert_gen_tok_text,
        unpert_gen_tok_text)`` tuple per perturbed sample.
    """
    # set Random seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set the device
    device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"

    if discrim == 'generic':
        set_generic_model_params(discrim_weights, discrim_meta)

    if discrim is not None:
        # A generic discriminator's meta may not name a base model; in that
        # case keep the caller's choice instead of failing with KeyError.
        discrim_model = DISCRIMINATOR_MODELS_PARAMS[discrim].get(
            "pretrained_model"
        )
        if discrim_model is not None:
            pretrained_model = discrim_model
            print("discrim = {}, setting pretrained_model "
                  "to discriminator's = {}".format(discrim, pretrained_model))

    # load pretrained model
    model = GPT2LMHeadModel.from_pretrained(
        pretrained_model,
        output_hidden_states=True
    )
    model.to(device)
    model.eval()

    # load tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)

    # Freeze GPT-2 weights
    for param in model.parameters():
        param.requires_grad = False

    # figure out conditioning text
    if uncond:
        tokenized_cond_text = tokenizer.encode(
            [tokenizer.bos_token]
        )
    else:
        raw_text = cond_text
        while not raw_text:
            print("Did you forget to add `--cond_text`? ")
            raw_text = input("Model prompt >>> ")
        tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + raw_text)

    print("= Prefix of sentence =")
    print(tokenizer.decode(tokenized_cond_text))
    print()

    # generate unperturbed and perturbed texts

    # full_text_generation returns:
    # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
    unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
        model=model,
        tokenizer=tokenizer,
        context=tokenized_cond_text,
        device=device,
        num_samples=num_samples,
        bag_of_words=bag_of_words,
        discrim=discrim,
        class_label=class_label,
        length=length,
        stepsize=stepsize,
        temperature=temperature,
        top_k=top_k,
        sample=sample,
        num_iterations=num_iterations,
        grad_length=grad_length,
        horizon_length=horizon_length,
        window_length=window_length,
        decay=decay,
        gamma=gamma,
        gm_scale=gm_scale,
        kl_scale=kl_scale,
    )

    # untokenize unperturbed text
    unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0])

    print("=" * 80)
    print("= Unperturbed generated text =")
    print(unpert_gen_text)
    print()

    # Import colorama once, under a separate name so the boolean `colorama`
    # parameter is not rebound to the module.
    color_module = None
    if colorama:
        try:
            import colorama as color_module
        except ImportError:
            print("colorama is not installed; printing without highlighting.")

    bow_word_ids = set()
    if bag_of_words and color_module is not None:
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
                                               tokenizer)
        # only words encoded as a single token can be highlighted per token
        bow_word_ids = _single_token_ids(bow_indices)

    generated_texts = []

    # iterate through the perturbed texts
    for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
        try:
            # untokenize perturbed text
            token_ids = pert_gen_tok_text.tolist()[0]
            if color_module is not None:
                pert_gen_text = _decode_highlighted(
                    tokenizer, token_ids, bow_word_ids, color_module
                )
            else:
                pert_gen_text = tokenizer.decode(token_ids)
        except Exception as err:
            # Printing is best-effort: report the failure and carry on with
            # the remaining samples rather than hiding it or aborting.
            print("Could not decode perturbed sample {}: {}".format(
                i + 1, err))
        else:
            print("= Perturbed generated text {} =".format(i + 1))
            print(pert_gen_text)
            print()

        # keep the prefix, perturbed seq, original seq for each index
        generated_texts.append(
            (tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
        )

    return generated_texts
torch.manual_seed(args.seed) - np.random.seed(args.seed) - - # set the device - device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" - - if args.discrim == 'generic': - set_generic_model_params(args.discrim_weights, args.discrim_meta) - - # load pretrained model - model = GPT2LMHeadModel.from_pretrained( - args.model_path, - output_hidden_states=True - ) - model.to(device) - model.eval() - - # Freeze GPT-2 weights - for param in model.parameters(): - param.requires_grad = False - - # figure out conditioning text - if args.uncond: - tokenized_cond_text = TOKENIZER.encode( - [TOKENIZER.bos_token] - ) - else: - raw_text = args.cond_text - while not raw_text: - print("Did you forget to add `--cond_text`? ") - raw_text = input("Model prompt >>> ") - tokenized_cond_text = TOKENIZER.encode(TOKENIZER.bos_token + raw_text) - - print("= Prefix of sentence =") - print(TOKENIZER.decode(tokenized_cond_text)) - print() - - # generate unperturbed and perturbed texts - - # full_text_generation returns: - # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time - unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation( - model=model, context=tokenized_cond_text, device=device, **vars(args) - ) - - # untokenize unperturbed text - unpert_gen_text = TOKENIZER.decode(unpert_gen_tok_text.tolist()[0]) - - print("=" * 80) - print("= Unperturbed generated text =") - print(unpert_gen_text) - print() - - generated_texts = [] - - bow_word_ids = set() - if args.bag_of_words and args.colorama: - bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";")) - for single_bow_list in bow_indices: - # filtering all words in the list composed of more than 1 token - filtered = list(filter(lambda x: len(x) <= 1, single_bow_list)) - # w[0] because we are sure w has only 1 item because previous fitler - bow_word_ids.update(w[0] for w in filtered) - - # iterate through the perturbed texts - for i, pert_gen_tok_text in 
enumerate(pert_gen_tok_texts): - try: - # untokenize unperturbed text - if args.colorama: - import colorama - - pert_gen_text = '' - for word_id in pert_gen_tok_text.tolist()[0]: - if word_id in bow_word_ids: - pert_gen_text += '{}{}{}'.format( - colorama.Fore.RED, - TOKENIZER.decode([word_id]), - colorama.Style.RESET_ALL - ) - else: - pert_gen_text += TOKENIZER.decode([word_id]) - else: - pert_gen_text = TOKENIZER.decode(pert_gen_tok_text.tolist()[0]) - - print("= Perturbed generated text {} =".format(i + 1)) - print(pert_gen_text) - print() - except: - pass - - # keep the prefix, perturbed seq, original seq for each index - generated_texts.append( - (tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text) - ) - - return - - -if __name__ == '__main__': - run_model() + run_pplm_example(**vars(args))