Imrpovements: model_path renamed pretrained_model, tokenizer loaded from pretrained_model, pretrained_model set to discriminator's when discrim is specified, sample = False by default but cli parameter introduced. To obtain identical samples call the cli with --sample
This commit is contained in:
committed by
Julien Chaumond
parent
75904dae66
commit
f10b925015
@@ -43,7 +43,6 @@ PPLM_DISCRIM = 2
|
|||||||
PPLM_BOW_DISCRIM = 3
|
PPLM_BOW_DISCRIM = 3
|
||||||
SMALL_CONST = 1e-15
|
SMALL_CONST = 1e-15
|
||||||
BIG_CONST = 1e10
|
BIG_CONST = 1e10
|
||||||
TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")
|
|
||||||
|
|
||||||
BAG_OF_WORDS_ARCHIVE_MAP = {
|
BAG_OF_WORDS_ARCHIVE_MAP = {
|
||||||
'kitchen': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/kitchen.txt",
|
'kitchen': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/kitchen.txt",
|
||||||
@@ -65,6 +64,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
|
|||||||
"embed_size": 1024,
|
"embed_size": 1024,
|
||||||
"class_vocab": {"non_clickbait": 0, "clickbait": 1},
|
"class_vocab": {"non_clickbait": 0, "clickbait": 1},
|
||||||
"default_class": 1,
|
"default_class": 1,
|
||||||
|
"pretrained_model": "gpt2-medium",
|
||||||
},
|
},
|
||||||
"sentiment": {
|
"sentiment": {
|
||||||
"url": "http://s.yosinski.com/SST_classifier_head.pt",
|
"url": "http://s.yosinski.com/SST_classifier_head.pt",
|
||||||
@@ -72,6 +72,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
|
|||||||
"embed_size": 1024,
|
"embed_size": 1024,
|
||||||
"class_vocab": {"very_positive": 2, "very_negative": 3},
|
"class_vocab": {"very_positive": 2, "very_negative": 3},
|
||||||
"default_class": 3,
|
"default_class": 3,
|
||||||
|
"pretrained_model": "gpt2-medium",
|
||||||
},
|
},
|
||||||
"toxicity": {
|
"toxicity": {
|
||||||
"url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/toxicity_classifierhead.pt",
|
"url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/toxicity_classifierhead.pt",
|
||||||
@@ -79,6 +80,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
|
|||||||
"embed_size": 1024,
|
"embed_size": 1024,
|
||||||
"class_vocab": {"non_toxic": 0, "toxic": 1},
|
"class_vocab": {"non_toxic": 0, "toxic": 1},
|
||||||
"default_class": 0,
|
"default_class": 0,
|
||||||
|
"pretrained_model": "gpt2-medium",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -345,8 +347,9 @@ def get_classifier(
|
|||||||
return classifier, label_id
|
return classifier, label_id
|
||||||
|
|
||||||
|
|
||||||
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
|
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
|
||||||
List[List[int]]]:
|
List[
|
||||||
|
List[List[int]]]:
|
||||||
bow_indices = []
|
bow_indices = []
|
||||||
for id_or_path in bag_of_words_ids_or_paths:
|
for id_or_path in bag_of_words_ids_or_paths:
|
||||||
if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
|
if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
|
||||||
@@ -356,12 +359,12 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
|
|||||||
with open(filepath, "r") as f:
|
with open(filepath, "r") as f:
|
||||||
words = f.read().strip().split("\n")
|
words = f.read().strip().split("\n")
|
||||||
bow_indices.append(
|
bow_indices.append(
|
||||||
[TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in
|
[tokenizer.encode(word.strip(), add_prefix_space=True) for word in
|
||||||
words])
|
words])
|
||||||
return bow_indices
|
return bow_indices
|
||||||
|
|
||||||
|
|
||||||
def build_bows_one_hot_vectors(bow_indices, device='cuda'):
|
def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'):
|
||||||
if bow_indices is None:
|
if bow_indices is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -370,7 +373,7 @@ def build_bows_one_hot_vectors(bow_indices, device='cuda'):
|
|||||||
single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
|
single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
|
||||||
single_bow = torch.tensor(single_bow).to(device)
|
single_bow = torch.tensor(single_bow).to(device)
|
||||||
num_words = single_bow.shape[0]
|
num_words = single_bow.shape[0]
|
||||||
one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).to(device)
|
one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device)
|
||||||
one_hot_bow.scatter_(1, single_bow, 1)
|
one_hot_bow.scatter_(1, single_bow, 1)
|
||||||
one_hot_bows_vectors.append(one_hot_bow)
|
one_hot_bows_vectors.append(one_hot_bow)
|
||||||
return one_hot_bows_vectors
|
return one_hot_bows_vectors
|
||||||
@@ -378,10 +381,11 @@ def build_bows_one_hot_vectors(bow_indices, device='cuda'):
|
|||||||
|
|
||||||
def full_text_generation(
|
def full_text_generation(
|
||||||
model,
|
model,
|
||||||
|
tokenizer,
|
||||||
context=None,
|
context=None,
|
||||||
num_samples=1,
|
num_samples=1,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
sample=True,
|
sample=False,
|
||||||
discrim=None,
|
discrim=None,
|
||||||
class_label=None,
|
class_label=None,
|
||||||
bag_of_words=None,
|
bag_of_words=None,
|
||||||
@@ -407,7 +411,8 @@ def full_text_generation(
|
|||||||
|
|
||||||
bow_indices = []
|
bow_indices = []
|
||||||
if bag_of_words:
|
if bag_of_words:
|
||||||
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
|
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
|
||||||
|
tokenizer)
|
||||||
|
|
||||||
if bag_of_words and classifier:
|
if bag_of_words and classifier:
|
||||||
print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
|
print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
|
||||||
@@ -426,9 +431,11 @@ def full_text_generation(
|
|||||||
|
|
||||||
unpert_gen_tok_text, _, _ = generate_text_pplm(
|
unpert_gen_tok_text, _, _ = generate_text_pplm(
|
||||||
model=model,
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
context=context,
|
context=context,
|
||||||
device=device,
|
device=device,
|
||||||
length=length,
|
length=length,
|
||||||
|
sample=sample,
|
||||||
perturb=False
|
perturb=False
|
||||||
)
|
)
|
||||||
if device == 'cuda':
|
if device == 'cuda':
|
||||||
@@ -441,6 +448,7 @@ def full_text_generation(
|
|||||||
for i in range(num_samples):
|
for i in range(num_samples):
|
||||||
pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
|
pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
|
||||||
model=model,
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
context=context,
|
context=context,
|
||||||
device=device,
|
device=device,
|
||||||
sample=sample,
|
sample=sample,
|
||||||
@@ -475,10 +483,11 @@ def full_text_generation(
|
|||||||
|
|
||||||
def generate_text_pplm(
|
def generate_text_pplm(
|
||||||
model,
|
model,
|
||||||
|
tokenizer,
|
||||||
context=None,
|
context=None,
|
||||||
past=None,
|
past=None,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
sample=True,
|
sample=False,
|
||||||
perturb=True,
|
perturb=True,
|
||||||
classifier=None,
|
classifier=None,
|
||||||
class_label=None,
|
class_label=None,
|
||||||
@@ -504,7 +513,8 @@ def generate_text_pplm(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# collect one hot vectors for bags of words
|
# collect one hot vectors for bags of words
|
||||||
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, device)
|
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
|
||||||
|
device)
|
||||||
|
|
||||||
grad_norms = None
|
grad_norms = None
|
||||||
last = None
|
last = None
|
||||||
@@ -612,7 +622,7 @@ def generate_text_pplm(
|
|||||||
else torch.cat((output_so_far, last), dim=1)
|
else torch.cat((output_so_far, last), dim=1)
|
||||||
)
|
)
|
||||||
|
|
||||||
print(TOKENIZER.decode(output_so_far.tolist()[0]))
|
print(tokenizer.decode(output_so_far.tolist()[0]))
|
||||||
|
|
||||||
return output_so_far, unpert_discrim_loss, loss_in_time
|
return output_so_far, unpert_discrim_loss, loss_in_time
|
||||||
|
|
||||||
@@ -631,10 +641,167 @@ def set_generic_model_params(discrim_weights, discrim_meta):
|
|||||||
DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
|
DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
|
||||||
|
|
||||||
|
|
||||||
def run_model():
|
def run_pplm_example(
|
||||||
|
pretrained_model="gpt2-medium",
|
||||||
|
cond_text="",
|
||||||
|
uncond=False,
|
||||||
|
num_samples=1,
|
||||||
|
bag_of_words=None,
|
||||||
|
discrim=None,
|
||||||
|
discrim_weights=None,
|
||||||
|
discrim_meta=None,
|
||||||
|
class_label=-1,
|
||||||
|
length=100,
|
||||||
|
stepsize=0.02,
|
||||||
|
temperature=1.0,
|
||||||
|
top_k=10,
|
||||||
|
sample=False,
|
||||||
|
num_iterations=3,
|
||||||
|
grad_length=10000,
|
||||||
|
horizon_length=1,
|
||||||
|
window_length=0,
|
||||||
|
decay=False,
|
||||||
|
gamma=1.5,
|
||||||
|
gm_scale=0.9,
|
||||||
|
kl_scale=0.01,
|
||||||
|
seed=0,
|
||||||
|
no_cuda=False,
|
||||||
|
colorama=False
|
||||||
|
):
|
||||||
|
# set Random seed
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
# set the device
|
||||||
|
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
||||||
|
|
||||||
|
if discrim == 'generic':
|
||||||
|
set_generic_model_params(discrim_weights, discrim_meta)
|
||||||
|
|
||||||
|
if discrim is not None:
|
||||||
|
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][
|
||||||
|
"pretrained_model"
|
||||||
|
]
|
||||||
|
print("discrim = {}, setting pretrained_model "
|
||||||
|
"to discriminator's = {}".format(discrim, pretrained_model))
|
||||||
|
|
||||||
|
# load pretrained model
|
||||||
|
model = GPT2LMHeadModel.from_pretrained(
|
||||||
|
pretrained_model,
|
||||||
|
output_hidden_states=True
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# load tokenizer
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
|
||||||
|
|
||||||
|
# Freeze GPT-2 weights
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# figure out conditioning text
|
||||||
|
if uncond:
|
||||||
|
tokenized_cond_text = tokenizer.encode(
|
||||||
|
[tokenizer.bos_token]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raw_text = cond_text
|
||||||
|
while not raw_text:
|
||||||
|
print("Did you forget to add `--cond_text`? ")
|
||||||
|
raw_text = input("Model prompt >>> ")
|
||||||
|
tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + raw_text)
|
||||||
|
|
||||||
|
print("= Prefix of sentence =")
|
||||||
|
print(tokenizer.decode(tokenized_cond_text))
|
||||||
|
print()
|
||||||
|
|
||||||
|
# generate unperturbed and perturbed texts
|
||||||
|
|
||||||
|
# full_text_generation returns:
|
||||||
|
# unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
|
||||||
|
unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
context=tokenized_cond_text,
|
||||||
|
device=device,
|
||||||
|
num_samples=num_samples,
|
||||||
|
bag_of_words=bag_of_words,
|
||||||
|
discrim=discrim,
|
||||||
|
class_label=class_label,
|
||||||
|
length=length,
|
||||||
|
stepsize=stepsize,
|
||||||
|
temperature=temperature,
|
||||||
|
top_k=top_k,
|
||||||
|
sample=sample,
|
||||||
|
num_iterations=num_iterations,
|
||||||
|
grad_length=grad_length,
|
||||||
|
horizon_length=horizon_length,
|
||||||
|
window_length=window_length,
|
||||||
|
decay=decay,
|
||||||
|
gamma=gamma,
|
||||||
|
gm_scale=gm_scale,
|
||||||
|
kl_scale=kl_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# untokenize unperturbed text
|
||||||
|
unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0])
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
print("= Unperturbed generated text =")
|
||||||
|
print(unpert_gen_text)
|
||||||
|
print()
|
||||||
|
|
||||||
|
generated_texts = []
|
||||||
|
|
||||||
|
bow_word_ids = set()
|
||||||
|
if bag_of_words and colorama:
|
||||||
|
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
|
||||||
|
tokenizer)
|
||||||
|
for single_bow_list in bow_indices:
|
||||||
|
# filtering all words in the list composed of more than 1 token
|
||||||
|
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
|
||||||
|
# w[0] because we are sure w has only 1 item because previous fitler
|
||||||
|
bow_word_ids.update(w[0] for w in filtered)
|
||||||
|
|
||||||
|
# iterate through the perturbed texts
|
||||||
|
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
|
||||||
|
try:
|
||||||
|
# untokenize unperturbed text
|
||||||
|
if colorama:
|
||||||
|
import colorama
|
||||||
|
|
||||||
|
pert_gen_text = ''
|
||||||
|
for word_id in pert_gen_tok_text.tolist()[0]:
|
||||||
|
if word_id in bow_word_ids:
|
||||||
|
pert_gen_text += '{}{}{}'.format(
|
||||||
|
colorama.Fore.RED,
|
||||||
|
tokenizer.decode([word_id]),
|
||||||
|
colorama.Style.RESET_ALL
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pert_gen_text += tokenizer.decode([word_id])
|
||||||
|
else:
|
||||||
|
pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])
|
||||||
|
|
||||||
|
print("= Perturbed generated text {} =".format(i + 1))
|
||||||
|
print(pert_gen_text)
|
||||||
|
print()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# keep the prefix, perturbed seq, original seq for each index
|
||||||
|
generated_texts.append(
|
||||||
|
(tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
|
||||||
|
)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_path",
|
"--pretrained_model",
|
||||||
"-M",
|
"-M",
|
||||||
type=str,
|
type=str,
|
||||||
default="gpt2-medium",
|
default="gpt2-medium",
|
||||||
@@ -675,6 +842,10 @@ def run_model():
|
|||||||
parser.add_argument("--gm_scale", type=float, default=0.9)
|
parser.add_argument("--gm_scale", type=float, default=0.9)
|
||||||
parser.add_argument("--kl_scale", type=float, default=0.01)
|
parser.add_argument("--kl_scale", type=float, default=0.01)
|
||||||
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
|
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample", action="store_true",
|
||||||
|
help="Generate from end-of-text as prefix"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--uncond", action="store_true",
|
"--uncond", action="store_true",
|
||||||
help="Generate from end-of-text as prefix"
|
help="Generate from end-of-text as prefix"
|
||||||
@@ -711,105 +882,4 @@ def run_model():
|
|||||||
help="colors keywords")
|
help="colors keywords")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
run_pplm_example(**vars(args))
|
||||||
# set Random seed
|
|
||||||
torch.manual_seed(args.seed)
|
|
||||||
np.random.seed(args.seed)
|
|
||||||
|
|
||||||
# set the device
|
|
||||||
device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
|
|
||||||
|
|
||||||
if args.discrim == 'generic':
|
|
||||||
set_generic_model_params(args.discrim_weights, args.discrim_meta)
|
|
||||||
|
|
||||||
# load pretrained model
|
|
||||||
model = GPT2LMHeadModel.from_pretrained(
|
|
||||||
args.model_path,
|
|
||||||
output_hidden_states=True
|
|
||||||
)
|
|
||||||
model.to(device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
# Freeze GPT-2 weights
|
|
||||||
for param in model.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
# figure out conditioning text
|
|
||||||
if args.uncond:
|
|
||||||
tokenized_cond_text = TOKENIZER.encode(
|
|
||||||
[TOKENIZER.bos_token]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raw_text = args.cond_text
|
|
||||||
while not raw_text:
|
|
||||||
print("Did you forget to add `--cond_text`? ")
|
|
||||||
raw_text = input("Model prompt >>> ")
|
|
||||||
tokenized_cond_text = TOKENIZER.encode(TOKENIZER.bos_token + raw_text)
|
|
||||||
|
|
||||||
print("= Prefix of sentence =")
|
|
||||||
print(TOKENIZER.decode(tokenized_cond_text))
|
|
||||||
print()
|
|
||||||
|
|
||||||
# generate unperturbed and perturbed texts
|
|
||||||
|
|
||||||
# full_text_generation returns:
|
|
||||||
# unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
|
|
||||||
unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
|
|
||||||
model=model, context=tokenized_cond_text, device=device, **vars(args)
|
|
||||||
)
|
|
||||||
|
|
||||||
# untokenize unperturbed text
|
|
||||||
unpert_gen_text = TOKENIZER.decode(unpert_gen_tok_text.tolist()[0])
|
|
||||||
|
|
||||||
print("=" * 80)
|
|
||||||
print("= Unperturbed generated text =")
|
|
||||||
print(unpert_gen_text)
|
|
||||||
print()
|
|
||||||
|
|
||||||
generated_texts = []
|
|
||||||
|
|
||||||
bow_word_ids = set()
|
|
||||||
if args.bag_of_words and args.colorama:
|
|
||||||
bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
|
|
||||||
for single_bow_list in bow_indices:
|
|
||||||
# filtering all words in the list composed of more than 1 token
|
|
||||||
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
|
|
||||||
# w[0] because we are sure w has only 1 item because previous fitler
|
|
||||||
bow_word_ids.update(w[0] for w in filtered)
|
|
||||||
|
|
||||||
# iterate through the perturbed texts
|
|
||||||
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
|
|
||||||
try:
|
|
||||||
# untokenize unperturbed text
|
|
||||||
if args.colorama:
|
|
||||||
import colorama
|
|
||||||
|
|
||||||
pert_gen_text = ''
|
|
||||||
for word_id in pert_gen_tok_text.tolist()[0]:
|
|
||||||
if word_id in bow_word_ids:
|
|
||||||
pert_gen_text += '{}{}{}'.format(
|
|
||||||
colorama.Fore.RED,
|
|
||||||
TOKENIZER.decode([word_id]),
|
|
||||||
colorama.Style.RESET_ALL
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
pert_gen_text += TOKENIZER.decode([word_id])
|
|
||||||
else:
|
|
||||||
pert_gen_text = TOKENIZER.decode(pert_gen_tok_text.tolist()[0])
|
|
||||||
|
|
||||||
print("= Perturbed generated text {} =".format(i + 1))
|
|
||||||
print(pert_gen_text)
|
|
||||||
print()
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# keep the prefix, perturbed seq, original seq for each index
|
|
||||||
generated_texts.append(
|
|
||||||
(tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
|
|
||||||
)
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
run_model()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user