Cleaned full_text_generation. Identical output as before.
This commit is contained in:
@@ -401,74 +401,6 @@ def full_text_generation(
|
|||||||
device
|
device
|
||||||
)
|
)
|
||||||
|
|
||||||
# if args.discrim == 'clickbait':
|
|
||||||
# classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
|
|
||||||
# classifier.load_state_dict(torch.load("discrim_models/clickbait_classifierhead.pt"))
|
|
||||||
# classifier.eval()
|
|
||||||
# args.label_class = 1 # clickbaity
|
|
||||||
#
|
|
||||||
# elif args.discrim == 'sentiment':
|
|
||||||
# classifier = ClassificationHead(class_size=5, embed_size=1024).to(device)
|
|
||||||
# #classifier.load_state_dict(torch.load("discrim_models/sentiment_classifierhead.pt"))
|
|
||||||
# classifier.load_state_dict(torch.load("discrim_models/SST_classifier_head_epoch_16.pt"))
|
|
||||||
# classifier.eval()
|
|
||||||
# if args.label_class < 0:
|
|
||||||
# raise Exception('Wrong class for sentiment, use --label-class 2 for *very positive*, 3 for *very negative*')
|
|
||||||
# #args.label_class = 2 # very pos
|
|
||||||
# #args.label_class = 3 # very neg
|
|
||||||
#
|
|
||||||
# elif args.discrim == 'toxicity':
|
|
||||||
# classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
|
|
||||||
# classifier.load_state_dict(torch.load("discrim_models/toxicity_classifierhead.pt"))
|
|
||||||
# classifier.eval()
|
|
||||||
# args.label_class = 0 # not toxic
|
|
||||||
#
|
|
||||||
# elif args.discrim == 'generic':
|
|
||||||
# if args.discrim_weights is None:
|
|
||||||
# raise ValueError('When using a generic discriminator, '
|
|
||||||
# 'discrim_weights need to be specified')
|
|
||||||
# if args.discrim_meta is None:
|
|
||||||
# raise ValueError('When using a generic discriminator, '
|
|
||||||
# 'discrim_meta need to be specified')
|
|
||||||
#
|
|
||||||
# with open(args.discrim_meta, 'r') as discrim_meta_file:
|
|
||||||
# meta = json.load(discrim_meta_file)
|
|
||||||
#
|
|
||||||
# classifier = ClassificationHead(
|
|
||||||
# class_size=meta['class_size'],
|
|
||||||
# embed_size=meta['embed_size'],
|
|
||||||
# # todo add tokenizer from meta
|
|
||||||
# ).to(device)
|
|
||||||
# classifier.load_state_dict(torch.load(args.discrim_weights))
|
|
||||||
# classifier.eval()
|
|
||||||
# if args.label_class == -1:
|
|
||||||
# args.label_class = meta['default_class']
|
|
||||||
#
|
|
||||||
# else:
|
|
||||||
# classifier = None
|
|
||||||
|
|
||||||
# Get tokens for the list of positive words
|
|
||||||
def list_tokens(word_list):
|
|
||||||
token_list = [TOKENIZER.encode(word, add_prefix_space=True) for word in
|
|
||||||
word_list]
|
|
||||||
# token_list = []
|
|
||||||
# for word in word_list:
|
|
||||||
# token_list.append(TOKENIZER.encode(" " + word))
|
|
||||||
return token_list
|
|
||||||
|
|
||||||
# good_index = []
|
|
||||||
# if args.bag_of_words:
|
|
||||||
# bags_of_words = args.bag_of_words.split(";")
|
|
||||||
# for wordlist in bags_of_words:
|
|
||||||
# with open(wordlist, "r") as f:
|
|
||||||
# words = f.read().strip()
|
|
||||||
# words = words.split('\n')
|
|
||||||
# good_index.append(list_tokens(words))
|
|
||||||
#
|
|
||||||
# for good_list in good_index:
|
|
||||||
# good_list = list(filter(lambda x: len(x) <= 1, good_list))
|
|
||||||
# actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]
|
|
||||||
|
|
||||||
bow_indices = []
|
bow_indices = []
|
||||||
if bag_of_words:
|
if bag_of_words:
|
||||||
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
|
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
|
||||||
@@ -486,9 +418,9 @@ def full_text_generation(
|
|||||||
print("Using PPLM-Discrim")
|
print("Using PPLM-Discrim")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)")
|
raise Exception("Specify either a bag of words or a discriminator")
|
||||||
|
|
||||||
original, _, _ = generate_text_pplm(
|
unpert_gen_tok_text, _, _ = generate_text_pplm(
|
||||||
model=model,
|
model=model,
|
||||||
context=context,
|
context=context,
|
||||||
device=device,
|
device=device,
|
||||||
@@ -497,12 +429,12 @@ def full_text_generation(
|
|||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
perturbed_list = []
|
pert_gen_tok_texts = []
|
||||||
discrim_loss_list = []
|
discrim_losses = []
|
||||||
loss_in_time_list = []
|
losses_in_time = []
|
||||||
|
|
||||||
for i in range(num_samples):
|
for i in range(num_samples):
|
||||||
perturbed, discrim_loss, loss_in_time = generate_text_pplm(
|
pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
|
||||||
model=model,
|
model=model,
|
||||||
context=context,
|
context=context,
|
||||||
device=device,
|
device=device,
|
||||||
@@ -525,14 +457,14 @@ def full_text_generation(
|
|||||||
decay=decay,
|
decay=decay,
|
||||||
gamma=gamma,
|
gamma=gamma,
|
||||||
)
|
)
|
||||||
perturbed_list.append(perturbed)
|
pert_gen_tok_texts.append(pert_gen_tok_text)
|
||||||
if classifier is not None:
|
if classifier is not None:
|
||||||
discrim_loss_list.append(discrim_loss.data.cpu().numpy())
|
discrim_losses.append(discrim_loss.data.cpu().numpy())
|
||||||
loss_in_time_list.append(loss_in_time)
|
losses_in_time.append(loss_in_time)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
return original, perturbed_list, discrim_loss_list, loss_in_time_list
|
return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
|
||||||
|
|
||||||
|
|
||||||
def generate_text_pplm(
|
def generate_text_pplm(
|
||||||
@@ -821,11 +753,14 @@ def run_model():
|
|||||||
|
|
||||||
generated_texts = []
|
generated_texts = []
|
||||||
|
|
||||||
bow_words = set()
|
bow_word_ids = set()
|
||||||
bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
|
if args.bag_of_words and args.colorama:
|
||||||
for bow_list in bow_indices:
|
bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
|
||||||
filtered = list(filter(lambda x: len(x) <= 1, bow_list))
|
for single_bow_list in bow_indices:
|
||||||
bow_words.update(w[0] for w in filtered)
|
# filtering all words in the list composed of more than 1 token
|
||||||
|
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
|
||||||
|
# w[0] because we are sure w has only 1 item because previous fitler
|
||||||
|
bow_word_ids.update(w[0] for w in filtered)
|
||||||
|
|
||||||
# iterate through the perturbed texts
|
# iterate through the perturbed texts
|
||||||
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
|
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
|
||||||
@@ -836,7 +771,7 @@ def run_model():
|
|||||||
|
|
||||||
pert_gen_text = ''
|
pert_gen_text = ''
|
||||||
for word_id in pert_gen_tok_text.tolist()[0]:
|
for word_id in pert_gen_tok_text.tolist()[0]:
|
||||||
if word_id in bow_words:
|
if word_id in bow_word_ids:
|
||||||
pert_gen_text += '{}{}{}'.format(
|
pert_gen_text += '{}{}{}'.format(
|
||||||
colorama.Fore.RED,
|
colorama.Fore.RED,
|
||||||
TOKENIZER.decode([word_id]),
|
TOKENIZER.decode([word_id]),
|
||||||
|
|||||||
Reference in New Issue
Block a user