diff --git a/examples/run_pplm.py b/examples/run_pplm.py index 4b1a6a2b6f..2f853d15c1 100644 --- a/examples/run_pplm.py +++ b/examples/run_pplm.py @@ -1,3 +1,4 @@ +#! /usr/bin/env python3 # coding=utf-8 # Copyright 2018 The Uber AI Team Authors. # @@ -37,10 +38,12 @@ from transformers import GPT2Tokenizer from transformers.file_utils import cached_path from transformers.modeling_gpt2 import GPT2LMHeadModel + PPLM_BOW = 1 PPLM_DISCRIM = 2 PPLM_BOW_DISCRIM = 3 SMALL_CONST = 1e-15 +SmallConst = 1e-15 TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium") BAG_OF_WORDS_ARCHIVE_MAP = { @@ -65,7 +68,7 @@ DISCRIMINATOR_MODELS_PARAMS = { "default_class": 1, }, "sentiment": { - "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/sentiment_classifierhead.pt", + "url": "http://s.yosinski.com/SST_classifier_head.pt", "class_size": 5, "embed_size": 1024, "class_vocab": {"very_positive": 2, "very_negative": 3}, @@ -81,6 +84,30 @@ DISCRIMINATOR_MODELS_PARAMS = { } +def to_var(x, requires_grad=False, volatile=False): + if torch.cuda.is_available(): + x = x.cuda() + return Variable(x, requires_grad=requires_grad, volatile=volatile) + + +def top_k_filter(logits, k, probs=False): + """ + Masks everything but the k top entries as -infinity (1e10). + Used to mask logits such that e^-infinity -> 0 won't contribute to the + sum of the denominator. + """ + if k == 0: + return logits + else: + values = torch.topk(logits, k)[0] + batch_mins = values[:, -1].view(-1, 1).expand_as(logits) + if probs: + return torch.where(logits < batch_mins, + torch.ones_like(logits) * 0.0, logits) + return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, + logits) + + class ClassificationHead(torch.nn.Module): """ Classification Head for the transformer """ @@ -99,234 +126,175 @@ class ClassificationHead(torch.nn.Module): return logits -def to_var(x, requires_grad=False, volatile=False): - if torch.cuda.is_available(): - x = x.cuda() - return Variable(x, requires_grad=requires_grad, volatile=volatile) +def perturb_past(past, model, prev, args, classifier, good_index=None, + stepsize=0.01, vocab_size=50257, + original_probs=None, accumulated_hidden=None, true_past=None, + grad_norms=None): + window_length = args.window_length + gm_scale, kl_scale = args.gm_scale, args.kl_scale + one_hot_vectors = [] + for good_list in good_index: + good_list = list(filter(lambda x: len(x) <= 1, good_list)) + good_list = torch.tensor(good_list).cuda() + num_good = good_list.shape[0] + one_hot_good = torch.zeros(num_good, vocab_size).cuda() + one_hot_good.scatter_(1, good_list, 1) + one_hot_vectors.append(one_hot_good) - -def top_k_filter(logits, k, probs=False): - """ - Masks everything but the k top entries as -infinity (1e10). - Used to mask logits such that e^-infinity -> 0 won't contribute to the - sum of the denominator. - """ - if k <= 0: - return logits - - else: - values = torch.topk(logits, k)[0] - batch_mins = values[:, -1].view(-1, 1).expand_as(logits) - - if probs: - return torch.where( - logits < batch_mins, - torch.ones_like(logits) * 0.0, - logits - ) - - return torch.where( - logits < batch_mins, - torch.ones_like(logits) * -1e10, - logits - ) - - -def perturb_past( - past, - model, - last, - unpert_past=None, - unpert_logits=None, - accumulated_hidden=None, - grad_norms=None, - stepsize=0.01, - classifier=None, - label_class=None, - one_hot_bows_vectors=None, - loss_type=0, - num_iterations=3, - kl_scale=0.01, - window_length=0, - horizon_length=1, - decay=False, - gamma=1.5, -): - # initializie perturbation accumulator - grad_accumulator = [ - (np.zeros(p.shape).astype("float32")) - for p in past - ] + # Generate inital perturbed past + past_perturb_orig = [ + (np.random.uniform(0.0, 0.0, p.shape).astype('float32')) + for p in past] if accumulated_hidden is None: accumulated_hidden = 0 - if decay: - decay_mask = torch.arange( - 0.0, - 1.0 + SMALL_CONST, - 1.0 / (window_length) - )[1:] + if args.decay: + decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[ + 1:] else: decay_mask = 1.0 - # TODO fix this comment (SUMANTH) - # generate a mask if perturbated gradient is based on a past window - _, _, _, curr_length, _ = past[0].shape - if curr_length > window_length and window_length > 0: - ones_key_val_shape = ( - tuple(past[0].shape[:-2]) - + tuple([window_length]) - + tuple(past[0].shape[-1:]) - ) + # Generate a mask is gradient perturbated is based on a past window + _, _, _, current_length, _ = past[0].shape - zeros_key_val_shape = ( - tuple(past[0].shape[:-2]) - + tuple([curr_length - window_length]) - + tuple(past[0].shape[-1:]) - ) + if current_length > window_length and window_length > 0: + ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple( + [window_length]) + tuple( + past[0].shape[-1:]) + + zeros_key_val_shape = tuple(past[0].shape[:-2]) + tuple( + [current_length - window_length]) + tuple( + past[0].shape[-1:]) ones_mask = torch.ones(ones_key_val_shape) ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3) ones_mask = ones_mask.permute(0, 1, 2, 4, 3) - window_mask = torch.cat( - (ones_mask, torch.zeros(zeros_key_val_shape)), - dim=-2 - ).cuda() - + window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)), + dim=-2).cuda() else: window_mask = torch.ones_like(past[0]).cuda() - # accumulate perturbations for num_iterations loss_per_iter = [] - for i in range(num_iterations): + for i in range(args.num_iterations): print("Iteration ", i + 1) + past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig] + past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb] - curr_perturbation = [ - to_var(torch.from_numpy(p_), requires_grad=True) - for p_ in grad_accumulator - ] + perturbed_past = list(map(add, past, past_perturb)) - # Compute hidden using perturbed past - curr_pert_past = list(map(add, past, curr_perturbation)) - all_logits, _, all_hidden = model(last, past=curr_pert_past) + _, _, _, current_length, _ = past_perturb[0].shape + + # _, future_past = model(prev, past=perturbed_past) + # hidden = model.hidden_states + + # Piero modified model call + logits, _, all_hidden = model(prev, past=perturbed_past) hidden = all_hidden[-1] - accumulated_hidden += torch.sum(hidden, dim=1).detach() - logits = all_logits[:, -1, :] - probs = F.softmax(logits, dim=-1) + new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, + dim=1).detach() - # compute loss - bow_loss = 0.0 - discrim_loss = 0.0 - kl_loss = 0.0 + # TODO: Check the layer-norm consistency of this with trained discriminator + logits = logits[:, -1, :] + probabs = F.softmax(logits, dim=-1) + loss = 0.0 + loss_list = [] + if args.loss_type == 1 or args.loss_type == 3: + for one_hot_good in one_hot_vectors: + good_logits = torch.mm(probabs, torch.t(one_hot_good)) + loss_word = good_logits + loss_word = torch.sum(loss_word) + loss_word = -torch.log(loss_word) + # loss_word = torch.sum(loss_word) /torch.sum(one_hot_good) + loss += loss_word + loss_list.append(loss_word) + print(" pplm_bow_loss:", loss.data.cpu().numpy()) - if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM: - for one_hot_bow in one_hot_bows_vectors: - bow_logits = torch.mm(probs, torch.t(one_hot_bow)) - bow_loss += -torch.log(torch.sum(bow_logits)) - print(" pplm_bow_loss:", bow_loss.data.cpu().numpy()) - - if loss_type == PPLM_DISCRIM or loss_type == PPLM_BOW_DISCRIM: + if args.loss_type == 2 or args.loss_type == 3: ce_loss = torch.nn.CrossEntropyLoss() - # TODO all there are for (SUMANTH) - # TODO why we need to do this assignment and not just using unpert_past? - curr_unpert_past = unpert_past - # Get the model's token embeddings in order to compute our own embeds from curr_probs: - wte = model.resize_token_embeddings() - # TODO i is never used, why do we need to do this i times instead multiplying - # torch.sum(unpert_hidden, dim=1) * horizon_length? - for i in range(horizon_length): - # TODO the next two lines can be done only one time, and why not using probs instead as they do not change at each iteration? - curr_probs = F.softmax(logits, dim=-1) # get softmax - curr_probs = torch.unsqueeze(curr_probs, dim=1) - inputs_embeds = torch.matmul(curr_probs, wte.weight.data) - _, curr_unpert_past, curr_all_hidden = model( - past=curr_unpert_past, + new_true_past = true_past + for i in range(args.horizon_length): + future_probabs = F.softmax(logits, dim=-1) # Get softmax + future_probabs = torch.unsqueeze(future_probabs, dim=1) + + # _, new_true_past = model(future_probabs, past=new_true_past) + # future_hidden = model.hidden_states # Get expected hidden states + + # Piero modified model call + wte = model.resize_token_embeddings() + inputs_embeds = torch.matmul(future_probabs, wte.weight.data) + _, new_true_past, future_hidden = model( + past=new_true_past, inputs_embeds=inputs_embeds ) - # get expected hidden states - unpert_hidden = curr_all_hidden[-1] - accumulated_hidden += torch.sum(unpert_hidden, dim=1).detach() + future_hidden = future_hidden[-1] - prediction = classifier( - accumulated_hidden / (curr_length + 1 + horizon_length) - ) + new_accumulated_hidden = new_accumulated_hidden + torch.sum( + future_hidden, dim=1) - label = torch.tensor([label_class], device="cuda", dtype=torch.long) - discrim_loss += ce_loss(prediction, label) + predicted_sentiment = classifier(new_accumulated_hidden / ( + current_length + 1 + args.horizon_length)) + + label = torch.tensor([args.label_class], device='cuda', + dtype=torch.long) + discrim_loss = ce_loss(predicted_sentiment, label) print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy()) + loss += discrim_loss + loss_list.append(discrim_loss) - if kl_scale >= 0.0: - unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1) - unpert_probs = ( - unpert_probs + SMALL_CONST * - (unpert_probs <= SMALL_CONST).type( - torch.FloatTensor - ).cuda().detach() - ) - - correction = SMALL_CONST * (probs <= SMALL_CONST).type( - torch.FloatTensor - ).cuda().detach() - corrected_probs = probs + correction.detach() + kl_loss = 0.0 + if kl_scale > 0.0: + p = (F.softmax(original_probs[:, -1, :], dim=-1)) + p = p + SmallConst * (p <= SmallConst).type( + torch.FloatTensor).cuda().detach() + correction = SmallConst * (probabs <= SmallConst).type( + torch.FloatTensor).cuda().detach() + corrected_probabs = probabs + correction.detach() kl_loss = kl_scale * ( - (corrected_probs * (corrected_probs / unpert_probs).log()).sum() - ) + (corrected_probabs * (corrected_probabs / p).log()).sum()) print(' kl_loss', (kl_loss).data.cpu().numpy()) + loss += kl_loss # + discrim_loss - loss = bow_loss + discrim_loss + kl_loss loss_per_iter.append(loss.data.cpu().numpy()) + print(' pplm_loss', (loss - kl_loss).data.cpu().numpy()) - # compute gradients loss.backward() - - # calculate gradient norms - if grad_norms is not None and loss_type == PPLM_BOW: + if grad_norms is not None and args.loss_type == 1: grad_norms = [ torch.max(grad_norms[index], torch.norm(p_.grad * window_mask)) - for index, p_ in enumerate(curr_perturbation) - ] + for index, p_ in + enumerate(past_perturb)] else: - grad_norms = [ - (torch.norm(p_.grad * window_mask) + SMALL_CONST) - for index, p_ in enumerate(curr_perturbation) - ] + grad_norms = [(torch.norm(p_.grad * window_mask) + SmallConst) for + index, p_ in enumerate(past_perturb)] - # normalize gradients grad = [ - -stepsize - * (p_.grad * window_mask / grad_norms[ - index] ** gamma).data.cpu().numpy() - for index, p_ in enumerate(curr_perturbation) - ] + -stepsize * (p_.grad * window_mask / grad_norms[ + index] ** args.gamma).data.cpu().numpy() + for index, p_ in enumerate(past_perturb)] + past_perturb_orig = list(map(add, grad, past_perturb_orig)) - # accumulate gradients - grad_accumulator = list(map(add, grad, grad_accumulator)) - - # reset gradients, just to make sure - for p_ in curr_perturbation: + for p_ in past_perturb: p_.grad.data.zero_() - # removing past from the graph new_past = [] - for p_ in past: - new_past.append(p_.detach()) + for p in past: + new_past.append(p.detach()) + past = new_past - # apply the accumulated perturbations to the past - grad_accumulator = [ - to_var(torch.from_numpy(p_), requires_grad=True) - for p_ in grad_accumulator - ] - pert_past = list(map(add, past, grad_accumulator)) + past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig] + past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb] + perturbed_past = list(map(add, past, past_perturb)) - return pert_past, accumulated_hidden, grad_norms, loss_per_iter + return perturbed_past, new_accumulated_hidden, grad_norms, loss_per_iter def get_classifier( - name: Optional[str], label_class: Union[str, int], device: Union[str, torch.device] + name: Optional[str], label_class: Union[str, int], + device: Union[str, torch.device] ) -> Tuple[Optional[ClassificationHead], Optional[int]]: if name is None: return None, None @@ -337,7 +305,8 @@ def get_classifier( embed_size=params['embed_size'] ).to(device) resolved_archive_file = cached_path(params["url"]) - classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device)) + classifier.load_state_dict( + torch.load(resolved_archive_file, map_location=device)) classifier.eval() if isinstance(label_class, str): @@ -364,7 +333,8 @@ def get_classifier( return classifier, label_id -def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[List[List[int]]]: +def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[ + List[List[int]]]: bow_indices = [] for id_or_path in bag_of_words_ids_or_paths: if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP: @@ -372,8 +342,10 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[List[ else: filepath = id_or_path with open(filepath, "r") as f: - words = f.read().split("\n") - bow_indices.append([TOKENIZER.encode(word, add_prefix_space=True) for word in words]) + words = f.read().strip().split("\n") + bow_indices.append( + [TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in + words]) return bow_indices @@ -392,327 +364,308 @@ def build_bows_one_hot_vectors(bow_indices): return one_hot_bows_vectors -def full_text_generation( - model, - context=None, - num_samples=1, - device="cuda", - sample=True, - discrim=None, - label_class=None, - bag_of_words=None, - length=100, - grad_length=10000, - stepsize=0.02, - num_iterations=3, - temperature=1.0, - gm_scale=0.9, - kl_scale=0.01, - top_k=10, - window_length=0, - horizon_length=1, - decay=False, - gamma=1.5, - **kwargs -): +def latent_perturb(model, args, context=None, sample=True, device='cuda'): classifier, class_id = get_classifier( - discrim, - label_class, + args.discrim, + args.label_class, device ) - bow_indices = [] - if bag_of_words: - bow_indices = get_bag_of_words_indices(bag_of_words.split(";")) + # if args.discrim == 'clickbait': + # classifier = ClassificationHead(class_size=2, embed_size=1024).to(device) + # classifier.load_state_dict(torch.load("discrim_models/clickbait_classifierhead.pt")) + # classifier.eval() + # args.label_class = 1 # clickbaity + # + # elif args.discrim == 'sentiment': + # classifier = ClassificationHead(class_size=5, embed_size=1024).to(device) + # #classifier.load_state_dict(torch.load("discrim_models/sentiment_classifierhead.pt")) + # classifier.load_state_dict(torch.load("discrim_models/SST_classifier_head_epoch_16.pt")) + # classifier.eval() + # if args.label_class < 0: + # raise Exception('Wrong class for sentiment, use --label-class 2 for *very positive*, 3 for *very negative*') + # #args.label_class = 2 # very pos + # #args.label_class = 3 # very neg + # + # elif args.discrim == 'toxicity': + # classifier = ClassificationHead(class_size=2, embed_size=1024).to(device) + # classifier.load_state_dict(torch.load("discrim_models/toxicity_classifierhead.pt")) + # classifier.eval() + # args.label_class = 0 # not toxic + # + # elif args.discrim == 'generic': + # if args.discrim_weights is None: + # raise ValueError('When using a generic discriminator, ' + # 'discrim_weights need to be specified') + # if args.discrim_meta is None: + # raise ValueError('When using a generic discriminator, ' + # 'discrim_meta need to be specified') + # + # with open(args.discrim_meta, 'r') as discrim_meta_file: + # meta = json.load(discrim_meta_file) + # + # classifier = ClassificationHead( + # class_size=meta['class_size'], + # embed_size=meta['embed_size'], + # # todo add tokenizer from meta + # ).to(device) + # classifier.load_state_dict(torch.load(args.discrim_weights)) + # classifier.eval() + # if args.label_class == -1: + # args.label_class = meta['default_class'] + # + # else: + # classifier = None - if bag_of_words and classifier: + # Get tokens for the list of positive words + def list_tokens(word_list): + token_list = [TOKENIZER.encode(word, add_prefix_space=True) for word in + word_list] + # token_list = [] + # for word in word_list: + # token_list.append(TOKENIZER.encode(" " + word)) + return token_list + + # good_index = [] + # if args.bag_of_words: + # bags_of_words = args.bag_of_words.split(";") + # for wordlist in bags_of_words: + # with open(wordlist, "r") as f: + # words = f.read().strip() + # words = words.split('\n') + # good_index.append(list_tokens(words)) + # + # for good_list in good_index: + # good_list = list(filter(lambda x: len(x) <= 1, good_list)) + # actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list] + + good_index = [] + actual_words = None + if args.bag_of_words: + good_index = get_bag_of_words_indices(args.bag_of_words.split(";")) + + for good_list in good_index: + good_list = list(filter(lambda x: len(x) <= 1, good_list)) + actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in + good_list] + + if args.bag_of_words and classifier: print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.") - loss_type = PPLM_BOW_DISCRIM + args.loss_type = PPLM_BOW_DISCRIM - elif bag_of_words: - loss_type = PPLM_BOW + elif args.bag_of_words: + args.loss_type = PPLM_BOW print("Using PPLM-BoW") elif classifier is not None: - loss_type = PPLM_DISCRIM + args.loss_type = PPLM_DISCRIM print("Using PPLM-Discrim") else: raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)") - unpert_gen_tok_text, _, _ = generate_text_pplm( - model=model, - context=context, - device=device, - length=length, - perturb=False - ) + original, _, _ = sample_from_hidden(model=model, args=args, context=context, + device=device, + perturb=False, good_index=good_index, + classifier=classifier) torch.cuda.empty_cache() - pert_gen_tok_texts = [] - discrim_losses = [] - losses_in_time = [] + perturbed_list = [] + discrim_loss_list = [] + loss_in_time_list = [] - for i in range(num_samples): - pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm( - model=model, - context=context, - device=device, - sample=sample, - perturb=True, - bow_indices=bow_indices, - classifier=classifier, - label_class=class_id, - loss_type=loss_type, - length=length, - grad_length=grad_length, - stepsize=stepsize, - num_iterations=num_iterations, - temperature=temperature, - gm_scale=gm_scale, - kl_scale=kl_scale, - top_k=top_k, - window_length=window_length, - horizon_length=horizon_length, - decay=decay, - gamma=gamma, - ) - pert_gen_tok_texts.append(pert_gen_tok_text) + for i in range(args.num_samples): + perturbed, discrim_loss, loss_in_time = sample_from_hidden(model=model, + args=args, + context=context, + device=device, + perturb=True, + good_index=good_index, + classifier=classifier) + perturbed_list.append(perturbed) if classifier is not None: - discrim_losses.append(discrim_loss.data.cpu().numpy()) - losses_in_time.append(loss_in_time) + discrim_loss_list.append(discrim_loss.data.cpu().numpy()) + loss_in_time_list.append(loss_in_time) torch.cuda.empty_cache() - return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time + return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words -def generate_text_pplm( - model, - context=None, - past=None, - device="cuda", - sample=True, - perturb=True, - classifier=None, - label_class=None, - bow_indices=None, - loss_type=0, - length=100, - grad_length=10000, - stepsize=0.02, - num_iterations=3, - temperature=1.0, - gm_scale=0.9, - kl_scale=0.01, - top_k=10, - window_length=0, - horizon_length=1, - decay=False, - gamma=1.5, -): - output_so_far = ( - torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0) - if context - else None - ) - - # collect one hot vectors for bags of words - one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices) +def sample_from_hidden(model, args, classifier, context=None, past=None, + device='cuda', + sample=True, perturb=True, good_index=None): + output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze( + 0) if context else None grad_norms = None - last = None - unpert_discrim_loss = 0 loss_in_time = [] - for i in trange(length, ascii=True): + for i in trange(args.length, ascii=True): # Get past/probs for current output, except for last word - # Note that GPT takes 2 inputs: past + current_token + # Note that GPT takes 2 inputs: past + current-token + # Therefore, use everything from before current i/p token to generate relevant past - # run model forward to obtain unperturbed - if past is None and output_so_far is not None: - last = output_so_far[:, -1:] - if output_so_far.shape[1] > 1: - _, past, _ = model(output_so_far[:, :-1]) + if past is None and output is not None: + prev = output[:, -1:] + # _, past = model(output[:, :-1]) + # original_probs, true_past = model(output) + # true_hidden = model.hidden_states - unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far) - unpert_last_hidden = unpert_all_hidden[-1] + # Piero modified model call + _, past, _ = model(output[:, :-1]) + original_probs, true_past, unpert_all_hidden = model(output) + true_hidden = unpert_all_hidden[-1] else: - unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far) - unpert_last_hidden = unpert_all_hidden[-1] + # original_probs, true_past = model(output) + # true_hidden = model.hidden_states - # check if we are abowe grad max length - if i >= grad_length: - current_stepsize = stepsize * 0 + # Piero modified model call + original_probs, true_past, unpert_all_hidden = model(output) + true_hidden = unpert_all_hidden[-1] + + # Modify the past if necessary + + if i >= args.grad_length: + current_stepsize = args.stepsize * 0 else: - current_stepsize = stepsize + current_stepsize = args.stepsize - # modify the past if necessary - if not perturb or num_iterations == 0: - pert_past = past + if not perturb or args.num_iterations == 0: + perturbed_past = past else: - accumulated_hidden = unpert_last_hidden[:, :-1, :] + # Piero modified model call + # accumulated_hidden = model.hidden_states[:, :-1, :] + accumulated_hidden = true_hidden[:, :-1, :] accumulated_hidden = torch.sum(accumulated_hidden, dim=1) - if past is not None: - pert_past, _, grad_norms, loss_this_iter = perturb_past( - past, - model, - last, - unpert_past=unpert_past, - unpert_logits=unpert_logits, - accumulated_hidden=accumulated_hidden, - grad_norms=grad_norms, - stepsize=current_stepsize, - classifier=classifier, - label_class=label_class, - one_hot_bows_vectors=one_hot_bows_vectors, - loss_type=loss_type, - num_iterations=num_iterations, - kl_scale=kl_scale, - window_length=window_length, - horizon_length=horizon_length, - decay=decay, - gamma=gamma, - ) - loss_in_time.append(loss_this_iter) - else: - pert_past = past + perturbed_past, _, grad_norms, loss_per_iter = perturb_past(past, + model, + prev, + args, + good_index=good_index, + stepsize=current_stepsize, + original_probs=original_probs, + true_past=true_past, + accumulated_hidden=accumulated_hidden, + classifier=classifier, + grad_norms=grad_norms) + loss_in_time.append(loss_per_iter) - pert_logits, past, pert_all_hidden = model(last, past=pert_past) - pert_logits = pert_logits[:, -1, :] / temperature - pert_probs = F.softmax(pert_logits, dim=-1) + # Piero modified model call + logits, past, pert_all_hidden = model(prev, past=perturbed_past) + # test_logits = F.softmax(test_logits[:, -1, :], dim=-1) + # likelywords = torch.topk(test_logits, k=10, dim=-1) + # print(TOKENIZER.decode(likelywords[1].tolist()[0])) - # compute the discriminator loss using unperturbed hidden if classifier is not None: - prediction = classifier(torch.mean(unpert_last_hidden, dim=1)) - label = torch.tensor([label_class], device="cuda", dtype=torch.long) - unpert_discrim_loss = torch.nn.CrossEntropyLoss()(prediction, label) - print( - "unperturbed discrim loss", - unpert_discrim_loss.data.cpu().numpy() - ) + ce_loss = torch.nn.CrossEntropyLoss() + predicted_sentiment = classifier(torch.mean(true_hidden, dim=1)) + label = torch.tensor([args.label_class], device='cuda', + dtype=torch.long) + true_discrim_loss = ce_loss(predicted_sentiment, label) + print("true discrim loss", true_discrim_loss.data.cpu().numpy()) else: - unpert_discrim_loss = 0 + true_discrim_loss = 0 - # Fuse the modified model and original model probabilities + # Piero modified model call + # hidden = model.hidden_states # update hidden + # logits = model.forward_hidden(hidden) + logits = logits[:, -1, :] / args.temperature # + SmallConst + + # logits = top_k_filter(logits, k=args.top_k) # + SmallConst + + log_probs = F.softmax(logits, dim=-1) + + # Fuse the modified model and original model if perturb: - unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1) - pert_probs = (pert_probs ** gm_scale) * ( - unpert_probs ** (1 - gm_scale) - ) + # original_probs = top_k_filter(original_probs[:, -1, :]) #+ SmallConst + original_probs = F.softmax(original_probs[:, -1, :], dim=-1) + # likelywords = torch.topk(original_probs, k=10, dim=-1) + # print(TOKENIZER.decode(likelywords[1].tolist()[0])) - pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) + gm_scale = args.gm_scale + log_probs = ((log_probs ** gm_scale) * ( + original_probs ** (1 - gm_scale))) # + SmallConst - # rescale - if torch.sum(pert_probs) <= 1: - pert_probs = pert_probs / torch.sum(pert_probs) + log_probs = top_k_filter(log_probs, k=args.top_k, + probs=True) # + SmallConst + + if torch.sum(log_probs) <= 1: + log_probs = log_probs / torch.sum(log_probs) else: - pert_logits = top_k_filter(pert_logits, k=top_k) - pert_probs = F.softmax(pert_logits, dim=-1) + logits = top_k_filter(logits, k=args.top_k) # + SmallConst + log_probs = F.softmax(logits, dim=-1) - # sample or greedy if sample: - last = torch.multinomial(pert_probs, num_samples=1) - + # likelywords = torch.topk(log_probs, k=args.top_k, dim=-1) + # print(TOKENIZER.decode(likelywords[1].tolist()[0])) + # print(likelywords[0].tolist()) + prev = torch.multinomial(log_probs, num_samples=1) else: - _, last = torch.topk(pert_probs, k=1, dim=-1) + _, prev = torch.topk(log_probs, k=1, dim=-1) + # if perturb: + # prev = future + output = prev if output is None else torch.cat((output, prev), + dim=1) # update output + print(TOKENIZER.decode(output.tolist()[0])) - # update context/output_so_far appending the new token - output_so_far = ( - last if output_so_far is None - else torch.cat((output_so_far, last), dim=1) - ) - print(TOKENIZER.decode(output_so_far.tolist()[0])) - - return output_so_far, unpert_discrim_loss, loss_in_time + return output, true_discrim_loss, loss_in_time def run_model(): parser = argparse.ArgumentParser() - parser.add_argument( - "--model_path", - "-M", - type=str, - default="gpt2-medium", - help="pretrained model name or path to local checkpoint", - ) - parser.add_argument( - "--bag_of_words", - "-B", - type=str, - default=None, - help="Bags of words used for PPLM-BoW. Either a BOW id (see list in code) or a filepath. Multiple BoWs separated by ;", - ) - parser.add_argument( - "--discrim", - "-D", - type=str, - default=None, - choices=("clickbait", "sentiment", "toxicity"), - help="Discriminator to use for loss-type 2", - ) - parser.add_argument( - "--label_class", - type=int, - default=-1, - help="Class label used for the discriminator", - ) - parser.add_argument("--stepsize", type=float, default=0.02) + parser.add_argument('--model_path', '-M', type=str, default='gpt2-medium', + help='pretrained model name or path to local checkpoint') + parser.add_argument('--bag-of-words', '-B', type=str, default=None, + help='Bags of words used for PPLM-BoW. Multiple BoWs separated by ;') + parser.add_argument('--discrim', '-D', type=str, default=None, + choices=( + 'clickbait', 'sentiment', 'toxicity', 'generic'), + help='Discriminator to use for loss-type 2') + parser.add_argument('--discrim_weights', type=str, default=None, + help='Weights for the generic discriminator') + parser.add_argument('--discrim_meta', type=str, default=None, + help='Meta information for the generic discriminator') + parser.add_argument('--label_class', type=int, default=-1, + help='Class label used for the discriminator') + parser.add_argument('--stepsize', type=float, default=0.02) parser.add_argument("--length", type=int, default=100) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--top_k", type=int, default=10) parser.add_argument("--gm_scale", type=float, default=0.9) parser.add_argument("--kl_scale", type=float, default=0.01) - parser.add_argument("--no_cuda", action="store_true", help="no cuda") - parser.add_argument( - "--uncond", action="store_true", - help="Generate from end-of-text as prefix" - ) - parser.add_argument( - "--cond_text", type=str, default="The lake", - help="Prefix texts to condition on" - ) - parser.add_argument("--num_iterations", type=int, default=3) - parser.add_argument("--grad_length", type=int, default=10000) - parser.add_argument( - "--num_samples", - type=int, - default=1, - help="Number of samples to generate from the modified latents", - ) - parser.add_argument( - "--horizon_length", - type=int, - default=1, - help="Length of future to optimize over", - ) - parser.add_argument( - "--window_length", - type=int, - default=0, - help="Length of past which is being optimized; " - "0 corresponds to infinite window length", - ) - parser.add_argument("--decay", action="store_true", - help="whether to decay or not") - parser.add_argument("--gamma", type=float, default=1.5) + parser.add_argument('--nocuda', action='store_true', help='no cuda') + parser.add_argument('--uncond', action='store_true', + help='Generate from end-of-text as prefix') + parser.add_argument("--cond_text", type=str, default='The lake', + help='Prefix texts to condition on') + parser.add_argument('--num_iterations', type=int, default=3) + parser.add_argument('--grad_length', type=int, default=10000) + parser.add_argument('--num_samples', type=int, default=1, + help='Number of samples to generate from the modified latents') + parser.add_argument('--horizon_length', type=int, default=1, + help='Length of future to optimize over') + # parser.add_argument('--force-token', action='store_true', help='no cuda') + parser.add_argument('--window_length', type=int, default=0, + help='Length of past which is being optimizer; 0 corresponds to infinite window length') + parser.add_argument('--decay', action='store_true', + help='whether to decay or not') + parser.add_argument('--gamma', type=float, default=1.5) + parser.add_argument('--colorama', action='store_true', help='no cuda') args = parser.parse_args() - # set Random seed torch.manual_seed(args.seed) np.random.seed(args.seed) - # set the device - device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + device = 'cpu' if args.nocuda else 'cuda' - # load pretrained model model = GPT2LMHeadModel.from_pretrained( args.model_path, output_hidden_states=True @@ -720,63 +673,82 @@ def run_model(): model.to(device) model.eval() - # freeze GPT-2 weights + # Freeze GPT-2 weights for param in model.parameters(): param.requires_grad = False + pass - # figure out conditioning text if args.uncond: - tokenized_cond_text = TOKENIZER.encode( - [TOKENIZER.bos_token] - ) + seq = [[50256, 50256]] + else: raw_text = args.cond_text while not raw_text: - print("Did you forget to add `--cond_text`? ") + print('Did you forget to add `--cond-text`? ') raw_text = input("Model prompt >>> ") - tokenized_cond_text = TOKENIZER.encode(TOKENIZER.bos_token + raw_text) + seq = [[50256] + TOKENIZER.encode(raw_text)] - print("= Prefix of sentence =") - print(TOKENIZER.decode(tokenized_cond_text)) - print() + collect_gen = dict() + current_index = 0 + for out in seq: - # generate unperturbed and perturbed texts + text = TOKENIZER.decode(out) + print("=" * 40 + " Prefix of sentence " + "=" * 40) + print(text) + print("=" * 80) - # full_text_generation returns: - # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time - unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation( - model=model, context=tokenized_cond_text, device=device, **vars(args) - ) + out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = latent_perturb( + model=model, args=args, context=out, + device=device) - # untokenize unperturbed text - unpert_gen_text = TOKENIZER.decode(unpert_gen_tok_text.tolist()[0]) + text_whole = TOKENIZER.decode(out1.tolist()[0]) - print("=" * 80) - print("= Unperturbed generated text =") - print(unpert_gen_text) - print() + print("=" * 80) + print("=" * 40 + " Whole sentence (Original)" + "=" * 40) + print(text_whole) + print("=" * 80) - generated_texts = [] + out_perturb_copy = out_perturb - # iterate through the perturbed texts - for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts): - try: - # untokenize unperturbed text - unpert_gen_text = TOKENIZER.decode(pert_gen_tok_text.tolist()[0]) + for out_perturb in out_perturb_copy: + # try: + # print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40) + # text_whole = TOKENIZER.decode(out_perturb.tolist()[0]) + # print(text_whole) + # print("=" * 80) + # except: + # pass + # collect_gen[current_index] = [out, out_perturb, out1] + ## Save the prefix, perturbed seq, original seq for each index + print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40) + keyword_tokens = [aa[-1][0] for aa in + actual_words] if actual_words else [] + output_tokens = out_perturb.tolist()[0] - print("= Perturbed generated text {} =".format(i + 1)) - print(unpert_gen_text) - print() - except: - pass + if args.colorama: + import colorama - # keep the prefix, perturbed seq, original seq for each index - generated_texts.append( - (tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text) - ) + text_whole = '' + for out in output_tokens: + if out in keyword_tokens: + text_whole += '%s%s%s' % ( + colorama.Fore.GREEN, TOKENIZER.decode([out]), + colorama.Style.RESET_ALL) + else: + text_whole += TOKENIZER.decode([out]) + else: + text_whole = TOKENIZER.decode(out_perturb.tolist()[0]) - return generated_texts + print(text_whole) + print("=" * 80) + + collect_gen[current_index] = [out, out_perturb, out1] + + current_index = current_index + 1 -if __name__ == "__main__": + return + + +if __name__ == '__main__': run_model()