removed deprecated use of Variable api from pplm example
This commit is contained in:
committed by
Julien Chaumond
parent
12d0eb5f3e
commit
48a05026de
@@ -31,7 +31,6 @@ from typing import List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
from tqdm import trange
|
||||
|
||||
from pplm_classification_head import ClassificationHead
|
||||
@@ -76,14 +75,6 @@ DISCRIMINATOR_MODELS_PARAMS = {
|
||||
}
|
||||
|
||||
|
||||
def to_var(x, requires_grad=False, volatile=False, device="cuda"):
|
||||
if torch.cuda.is_available() and device == "cuda":
|
||||
x = x.cuda()
|
||||
elif device != "cuda":
|
||||
x = x.to(device)
|
||||
return Variable(x, requires_grad=requires_grad, volatile=volatile)
|
||||
|
||||
|
||||
def top_k_filter(logits, k, probs=False):
|
||||
"""
|
||||
Masks everything but the k top entries as -infinity (1e10).
|
||||
@@ -156,9 +147,7 @@ def perturb_past(
|
||||
new_accumulated_hidden = None
|
||||
for i in range(num_iterations):
|
||||
print("Iteration ", i + 1)
|
||||
curr_perturbation = [
|
||||
to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator
|
||||
]
|
||||
curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
|
||||
|
||||
# Compute hidden using perturbed past
|
||||
perturbed_past = list(map(add, past, curr_perturbation))
|
||||
@@ -247,7 +236,7 @@ def perturb_past(
|
||||
past = new_past
|
||||
|
||||
# apply the accumulated perturbations to the past
|
||||
grad_accumulator = [to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator]
|
||||
grad_accumulator = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
|
||||
pert_past = list(map(add, past, grad_accumulator))
|
||||
|
||||
return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
|
||||
@@ -266,7 +255,7 @@ def get_classifier(
|
||||
elif "path" in params:
|
||||
resolved_archive_file = params["path"]
|
||||
else:
|
||||
raise ValueError("Either url or path have to be specified " "in the discriminator model parameters")
|
||||
raise ValueError("Either url or path have to be specified in the discriminator model parameters")
|
||||
classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
|
||||
classifier.eval()
|
||||
|
||||
@@ -569,9 +558,9 @@ def generate_text_pplm(
|
||||
|
||||
def set_generic_model_params(discrim_weights, discrim_meta):
|
||||
if discrim_weights is None:
|
||||
raise ValueError("When using a generic discriminator, " "discrim_weights need to be specified")
|
||||
raise ValueError("When using a generic discriminator, discrim_weights need to be specified")
|
||||
if discrim_meta is None:
|
||||
raise ValueError("When using a generic discriminator, " "discrim_meta need to be specified")
|
||||
raise ValueError("When using a generic discriminator, discrim_meta need to be specified")
|
||||
|
||||
with open(discrim_meta, "r") as discrim_meta_file:
|
||||
meta = json.load(discrim_meta_file)
|
||||
@@ -619,7 +608,7 @@ def run_pplm_example(
|
||||
|
||||
if discrim is not None:
|
||||
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
|
||||
print("discrim = {}, pretrained_model set " "to discriminator's = {}".format(discrim, pretrained_model))
|
||||
print("discrim = {}, pretrained_model set to discriminator's = {}".format(discrim, pretrained_model))
|
||||
|
||||
# load pretrained model
|
||||
model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
|
||||
@@ -706,7 +695,7 @@ def run_pplm_example(
|
||||
for word_id in pert_gen_tok_text.tolist()[0]:
|
||||
if word_id in bow_word_ids:
|
||||
pert_gen_text += "{}{}{}".format(
|
||||
colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL
|
||||
colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL,
|
||||
)
|
||||
else:
|
||||
pert_gen_text += tokenizer.decode([word_id])
|
||||
@@ -744,9 +733,11 @@ if __name__ == "__main__":
|
||||
"-B",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Bags of words used for PPLM-BoW. "
|
||||
help=(
|
||||
"Bags of words used for PPLM-BoW. "
|
||||
"Either a BOW id (see list in code) or a filepath. "
|
||||
"Multiple BoWs separated by ;",
|
||||
"Multiple BoWs separated by ;"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrim",
|
||||
@@ -756,9 +747,11 @@ if __name__ == "__main__":
|
||||
choices=("clickbait", "sentiment", "toxicity", "generic"),
|
||||
help="Discriminator to use",
|
||||
)
|
||||
parser.add_argument("--discrim_weights", type=str, default=None, help="Weights for the generic discriminator")
|
||||
parser.add_argument(
|
||||
"--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator"
|
||||
"--discrim_weights", type=str, default=None, help="Weights for the generic discriminator",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_label", type=int, default=-1, help="Class label used for the discriminator",
|
||||
@@ -774,7 +767,7 @@ if __name__ == "__main__":
|
||||
"--window_length",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Length of past which is being optimized; " "0 corresponds to infinite window length",
|
||||
help="Length of past which is being optimized; 0 corresponds to infinite window length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--horizon_length", type=int, default=1, help="Length of future to optimize over",
|
||||
|
||||
Reference in New Issue
Block a user