removed deprecared use of Variable api from pplm example
This commit is contained in:
committed by
Julien Chaumond
parent
12d0eb5f3e
commit
48a05026de
@@ -31,7 +31,6 @@ from typing import List, Optional, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.autograd import Variable
|
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
from pplm_classification_head import ClassificationHead
|
from pplm_classification_head import ClassificationHead
|
||||||
@@ -76,14 +75,6 @@ DISCRIMINATOR_MODELS_PARAMS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def to_var(x, requires_grad=False, volatile=False, device="cuda"):
|
|
||||||
if torch.cuda.is_available() and device == "cuda":
|
|
||||||
x = x.cuda()
|
|
||||||
elif device != "cuda":
|
|
||||||
x = x.to(device)
|
|
||||||
return Variable(x, requires_grad=requires_grad, volatile=volatile)
|
|
||||||
|
|
||||||
|
|
||||||
def top_k_filter(logits, k, probs=False):
|
def top_k_filter(logits, k, probs=False):
|
||||||
"""
|
"""
|
||||||
Masks everything but the k top entries as -infinity (1e10).
|
Masks everything but the k top entries as -infinity (1e10).
|
||||||
@@ -156,9 +147,7 @@ def perturb_past(
|
|||||||
new_accumulated_hidden = None
|
new_accumulated_hidden = None
|
||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
print("Iteration ", i + 1)
|
print("Iteration ", i + 1)
|
||||||
curr_perturbation = [
|
curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
|
||||||
to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator
|
|
||||||
]
|
|
||||||
|
|
||||||
# Compute hidden using perturbed past
|
# Compute hidden using perturbed past
|
||||||
perturbed_past = list(map(add, past, curr_perturbation))
|
perturbed_past = list(map(add, past, curr_perturbation))
|
||||||
@@ -247,7 +236,7 @@ def perturb_past(
|
|||||||
past = new_past
|
past = new_past
|
||||||
|
|
||||||
# apply the accumulated perturbations to the past
|
# apply the accumulated perturbations to the past
|
||||||
grad_accumulator = [to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator]
|
grad_accumulator = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
|
||||||
pert_past = list(map(add, past, grad_accumulator))
|
pert_past = list(map(add, past, grad_accumulator))
|
||||||
|
|
||||||
return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
|
return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
|
||||||
@@ -266,7 +255,7 @@ def get_classifier(
|
|||||||
elif "path" in params:
|
elif "path" in params:
|
||||||
resolved_archive_file = params["path"]
|
resolved_archive_file = params["path"]
|
||||||
else:
|
else:
|
||||||
raise ValueError("Either url or path have to be specified " "in the discriminator model parameters")
|
raise ValueError("Either url or path have to be specified in the discriminator model parameters")
|
||||||
classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
|
classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
|
||||||
classifier.eval()
|
classifier.eval()
|
||||||
|
|
||||||
@@ -569,9 +558,9 @@ def generate_text_pplm(
|
|||||||
|
|
||||||
def set_generic_model_params(discrim_weights, discrim_meta):
|
def set_generic_model_params(discrim_weights, discrim_meta):
|
||||||
if discrim_weights is None:
|
if discrim_weights is None:
|
||||||
raise ValueError("When using a generic discriminator, " "discrim_weights need to be specified")
|
raise ValueError("When using a generic discriminator, discrim_weights need to be specified")
|
||||||
if discrim_meta is None:
|
if discrim_meta is None:
|
||||||
raise ValueError("When using a generic discriminator, " "discrim_meta need to be specified")
|
raise ValueError("When using a generic discriminator, discrim_meta need to be specified")
|
||||||
|
|
||||||
with open(discrim_meta, "r") as discrim_meta_file:
|
with open(discrim_meta, "r") as discrim_meta_file:
|
||||||
meta = json.load(discrim_meta_file)
|
meta = json.load(discrim_meta_file)
|
||||||
@@ -619,7 +608,7 @@ def run_pplm_example(
|
|||||||
|
|
||||||
if discrim is not None:
|
if discrim is not None:
|
||||||
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
|
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
|
||||||
print("discrim = {}, pretrained_model set " "to discriminator's = {}".format(discrim, pretrained_model))
|
print("discrim = {}, pretrained_model set to discriminator's = {}".format(discrim, pretrained_model))
|
||||||
|
|
||||||
# load pretrained model
|
# load pretrained model
|
||||||
model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
|
model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
|
||||||
@@ -706,7 +695,7 @@ def run_pplm_example(
|
|||||||
for word_id in pert_gen_tok_text.tolist()[0]:
|
for word_id in pert_gen_tok_text.tolist()[0]:
|
||||||
if word_id in bow_word_ids:
|
if word_id in bow_word_ids:
|
||||||
pert_gen_text += "{}{}{}".format(
|
pert_gen_text += "{}{}{}".format(
|
||||||
colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL
|
colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
pert_gen_text += tokenizer.decode([word_id])
|
pert_gen_text += tokenizer.decode([word_id])
|
||||||
@@ -744,9 +733,11 @@ if __name__ == "__main__":
|
|||||||
"-B",
|
"-B",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Bags of words used for PPLM-BoW. "
|
help=(
|
||||||
|
"Bags of words used for PPLM-BoW. "
|
||||||
"Either a BOW id (see list in code) or a filepath. "
|
"Either a BOW id (see list in code) or a filepath. "
|
||||||
"Multiple BoWs separated by ;",
|
"Multiple BoWs separated by ;"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--discrim",
|
"--discrim",
|
||||||
@@ -756,9 +747,11 @@ if __name__ == "__main__":
|
|||||||
choices=("clickbait", "sentiment", "toxicity", "generic"),
|
choices=("clickbait", "sentiment", "toxicity", "generic"),
|
||||||
help="Discriminator to use",
|
help="Discriminator to use",
|
||||||
)
|
)
|
||||||
parser.add_argument("--discrim_weights", type=str, default=None, help="Weights for the generic discriminator")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator"
|
"--discrim_weights", type=str, default=None, help="Weights for the generic discriminator",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--class_label", type=int, default=-1, help="Class label used for the discriminator",
|
"--class_label", type=int, default=-1, help="Class label used for the discriminator",
|
||||||
@@ -774,7 +767,7 @@ if __name__ == "__main__":
|
|||||||
"--window_length",
|
"--window_length",
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
default=0,
|
||||||
help="Length of past which is being optimized; " "0 corresponds to infinite window length",
|
help="Length of past which is being optimized; 0 corresponds to infinite window length",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--horizon_length", type=int, default=1, help="Length of future to optimize over",
|
"--horizon_length", type=int, default=1, help="Length of future to optimize over",
|
||||||
|
|||||||
Reference in New Issue
Block a user