diff --git a/examples/run_pplm.py b/examples/run_pplm.py index 27ead3c3c5..b85998d706 100644 --- a/examples/run_pplm.py +++ b/examples/run_pplm.py @@ -401,74 +401,6 @@ def full_text_generation( device ) - # if args.discrim == 'clickbait': - # classifier = ClassificationHead(class_size=2, embed_size=1024).to(device) - # classifier.load_state_dict(torch.load("discrim_models/clickbait_classifierhead.pt")) - # classifier.eval() - # args.label_class = 1 # clickbaity - # - # elif args.discrim == 'sentiment': - # classifier = ClassificationHead(class_size=5, embed_size=1024).to(device) - # #classifier.load_state_dict(torch.load("discrim_models/sentiment_classifierhead.pt")) - # classifier.load_state_dict(torch.load("discrim_models/SST_classifier_head_epoch_16.pt")) - # classifier.eval() - # if args.label_class < 0: - # raise Exception('Wrong class for sentiment, use --label-class 2 for *very positive*, 3 for *very negative*') - # #args.label_class = 2 # very pos - # #args.label_class = 3 # very neg - # - # elif args.discrim == 'toxicity': - # classifier = ClassificationHead(class_size=2, embed_size=1024).to(device) - # classifier.load_state_dict(torch.load("discrim_models/toxicity_classifierhead.pt")) - # classifier.eval() - # args.label_class = 0 # not toxic - # - # elif args.discrim == 'generic': - # if args.discrim_weights is None: - # raise ValueError('When using a generic discriminator, ' - # 'discrim_weights need to be specified') - # if args.discrim_meta is None: - # raise ValueError('When using a generic discriminator, ' - # 'discrim_meta need to be specified') - # - # with open(args.discrim_meta, 'r') as discrim_meta_file: - # meta = json.load(discrim_meta_file) - # - # classifier = ClassificationHead( - # class_size=meta['class_size'], - # embed_size=meta['embed_size'], - # # todo add tokenizer from meta - # ).to(device) - # classifier.load_state_dict(torch.load(args.discrim_weights)) - # classifier.eval() - # if args.label_class == -1: - # args.label_class = 
meta['default_class'] - # - # else: - # classifier = None - - # Get tokens for the list of positive words - def list_tokens(word_list): - token_list = [TOKENIZER.encode(word, add_prefix_space=True) for word in - word_list] - # token_list = [] - # for word in word_list: - # token_list.append(TOKENIZER.encode(" " + word)) - return token_list - - # good_index = [] - # if args.bag_of_words: - # bags_of_words = args.bag_of_words.split(";") - # for wordlist in bags_of_words: - # with open(wordlist, "r") as f: - # words = f.read().strip() - # words = words.split('\n') - # good_index.append(list_tokens(words)) - # - # for good_list in good_index: - # good_list = list(filter(lambda x: len(x) <= 1, good_list)) - # actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list] - bow_indices = [] if bag_of_words: bow_indices = get_bag_of_words_indices(bag_of_words.split(";")) @@ -486,9 +418,9 @@ def full_text_generation( print("Using PPLM-Discrim") else: - raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)") + raise Exception("Specify either a bag of words or a discriminator") - original, _, _ = generate_text_pplm( + unpert_gen_tok_text, _, _ = generate_text_pplm( model=model, context=context, device=device, @@ -497,12 +429,12 @@ def full_text_generation( ) torch.cuda.empty_cache() - perturbed_list = [] - discrim_loss_list = [] - loss_in_time_list = [] + pert_gen_tok_texts = [] + discrim_losses = [] + losses_in_time = [] for i in range(num_samples): - perturbed, discrim_loss, loss_in_time = generate_text_pplm( + pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm( model=model, context=context, device=device, @@ -525,14 +457,14 @@ def full_text_generation( decay=decay, gamma=gamma, ) - perturbed_list.append(perturbed) + pert_gen_tok_texts.append(pert_gen_tok_text) if classifier is not None: - discrim_loss_list.append(discrim_loss.data.cpu().numpy()) - loss_in_time_list.append(loss_in_time) + 
discrim_losses.append(discrim_loss.data.cpu().numpy()) + losses_in_time.append(loss_in_time) torch.cuda.empty_cache() - return original, perturbed_list, discrim_loss_list, loss_in_time_list + return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time def generate_text_pplm( @@ -821,11 +753,14 @@ def run_model(): generated_texts = [] - bow_words = set() - bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";")) - for bow_list in bow_indices: - filtered = list(filter(lambda x: len(x) <= 1, bow_list)) - bow_words.update(w[0] for w in filtered) + bow_word_ids = set() + if args.bag_of_words and args.colorama: + bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";")) + for single_bow_list in bow_indices: + # filtering all words in the list composed of more than 1 token + filtered = list(filter(lambda x: len(x) <= 1, single_bow_list)) + # w[0] because we are sure w has only 1 item because of the previous filter + bow_word_ids.update(w[0] for w in filtered) # iterate through the perturbed texts for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts): @@ -836,7 +771,7 @@ def run_model(): pert_gen_text = '' for word_id in pert_gen_tok_text.tolist()[0]: - if word_id in bow_words: + if word_id in bow_word_ids: pert_gen_text += '{}{}{}'.format( colorama.Fore.RED, TOKENIZER.decode([word_id]),