From 48a05026def1e94ae08037a252472c030409857e Mon Sep 17 00:00:00 2001 From: prajjwal1 Date: Thu, 28 May 2020 00:09:25 +0900 Subject: [PATCH] removed deprecared use of Variable api from pplm example --- examples/text-generation/pplm/run_pplm.py | 41 ++++++++++------------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/examples/text-generation/pplm/run_pplm.py b/examples/text-generation/pplm/run_pplm.py index 73f2c3a6f6..abdd28c9af 100644 --- a/examples/text-generation/pplm/run_pplm.py +++ b/examples/text-generation/pplm/run_pplm.py @@ -31,7 +31,6 @@ from typing import List, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F -from torch.autograd import Variable from tqdm import trange from pplm_classification_head import ClassificationHead @@ -76,14 +75,6 @@ DISCRIMINATOR_MODELS_PARAMS = { } -def to_var(x, requires_grad=False, volatile=False, device="cuda"): - if torch.cuda.is_available() and device == "cuda": - x = x.cuda() - elif device != "cuda": - x = x.to(device) - return Variable(x, requires_grad=requires_grad, volatile=volatile) - - def top_k_filter(logits, k, probs=False): """ Masks everything but the k top entries as -infinity (1e10). @@ -156,9 +147,7 @@ def perturb_past( new_accumulated_hidden = None for i in range(num_iterations): print("Iteration ", i + 1) - curr_perturbation = [ - to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator - ] + curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator] # Compute hidden using perturbed past perturbed_past = list(map(add, past, curr_perturbation)) @@ -247,7 +236,7 @@ def perturb_past( past = new_past # apply the accumulated perturbations to the past - grad_accumulator = [to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator] + grad_accumulator = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator] pert_past = list(map(add, past, grad_accumulator)) return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter @@ -266,7 +255,7 @@ def get_classifier( elif "path" in params: resolved_archive_file = params["path"] else: - raise ValueError("Either url or path have to be specified " "in the discriminator model parameters") + raise ValueError("Either url or path have to be specified in the discriminator model parameters") classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device)) classifier.eval() @@ -569,9 +558,9 @@ def generate_text_pplm( def set_generic_model_params(discrim_weights, discrim_meta): if discrim_weights is None: - raise ValueError("When using a generic discriminator, " "discrim_weights need to be specified") + raise ValueError("When using a generic discriminator, discrim_weights need to be specified") if discrim_meta is None: - raise ValueError("When using a generic discriminator, " "discrim_meta need to be specified") + raise ValueError("When using a generic discriminator, discrim_meta need to be specified") with open(discrim_meta, "r") as discrim_meta_file: meta = json.load(discrim_meta_file) @@ -619,7 +608,7 @@ def run_pplm_example( if discrim is not None: pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"] - print("discrim = {}, pretrained_model set " "to discriminator's = {}".format(discrim, pretrained_model)) + print("discrim = {}, pretrained_model set to discriminator's = {}".format(discrim, pretrained_model)) # load pretrained model model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True) @@ -706,7 +695,7 @@ def run_pplm_example( for word_id in pert_gen_tok_text.tolist()[0]: if word_id in bow_word_ids: pert_gen_text += "{}{}{}".format( - colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL + colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL, ) else: pert_gen_text += tokenizer.decode([word_id]) @@ -744,9 +733,11 @@ if __name__ == "__main__": "-B", type=str, default=None, - help="Bags of words used for PPLM-BoW. " - "Either a BOW id (see list in code) or a filepath. " - "Multiple BoWs separated by ;", + help=( + "Bags of words used for PPLM-BoW. " + "Either a BOW id (see list in code) or a filepath. " + "Multiple BoWs separated by ;" + ), ) parser.add_argument( "--discrim", @@ -756,9 +747,11 @@ if __name__ == "__main__": choices=("clickbait", "sentiment", "toxicity", "generic"), help="Discriminator to use", ) - parser.add_argument("--discrim_weights", type=str, default=None, help="Weights for the generic discriminator") parser.add_argument( - "--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator" + "--discrim_weights", type=str, default=None, help="Weights for the generic discriminator", + ) + parser.add_argument( + "--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator", ) parser.add_argument( "--class_label", type=int, default=-1, help="Class label used for the discriminator", @@ -774,7 +767,7 @@ if __name__ == "__main__": "--window_length", type=int, default=0, - help="Length of past which is being optimized; " "0 corresponds to infinite window length", + help="Length of past which is being optimized; 0 corresponds to infinite window length", ) parser.add_argument( "--horizon_length", type=int, default=1, help="Length of future to optimize over",