diff --git a/examples/run_pplm.py b/examples/run_pplm.py index b18e97b5a8..4d335a9241 100644 --- a/examples/run_pplm.py +++ b/examples/run_pplm.py @@ -43,7 +43,7 @@ PPLM_BOW = 1 PPLM_DISCRIM = 2 PPLM_BOW_DISCRIM = 3 SMALL_CONST = 1e-15 -SmallConst = 1e-15 +BIG_CONST = 1e10 TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium") BAG_OF_WORDS_ARCHIVE_MAP = { @@ -104,7 +104,8 @@ def top_k_filter(logits, k, probs=False): if probs: return torch.where(logits < batch_mins, torch.ones_like(logits) * 0.0, logits) - return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, + return torch.where(logits < batch_mins, + torch.ones_like(logits) * -BIG_CONST, logits) @@ -137,7 +138,7 @@ def perturb_past( accumulated_hidden = 0 if decay: - decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[ + decay_mask = torch.arange(0., 1.0 + SMALL_CONST, 1.0 / (window_length))[ 1:] else: decay_mask = 1.0 @@ -233,9 +234,9 @@ def perturb_past( kl_loss = 0.0 if kl_scale > 0.0: p = (F.softmax(unpert_logits[:, -1, :], dim=-1)) - p = p + SmallConst * (p <= SmallConst).type( + p = p + SMALL_CONST * (p <= SMALL_CONST).type( torch.FloatTensor).cuda().detach() - correction = SmallConst * (probabs <= SmallConst).type( + correction = SMALL_CONST * (probabs <= SMALL_CONST).type( torch.FloatTensor).cuda().detach() corrected_probabs = probabs + correction.detach() kl_loss = kl_scale * ( @@ -254,7 +255,7 @@ def perturb_past( for index, p_ in enumerate(past_perturb)] else: - grad_norms = [(torch.norm(p_.grad * window_mask) + SmallConst) for + grad_norms = [(torch.norm(p_.grad * window_mask) + SMALL_CONST) for index, p_ in enumerate(past_perturb)] grad = [ @@ -560,31 +561,31 @@ def generate_text_pplm( # Piero modified model call # hidden = model.hidden_states # update hidden # logits = model.forward_hidden(hidden) - logits = logits[:, -1, :] / temperature # + SmallConst + logits = logits[:, -1, :] / temperature # + SMALL_CONST - # logits = top_k_filter(logits, k=args.top_k) # + SmallConst + # logits = top_k_filter(logits, k=args.top_k) # + SMALL_CONST log_probs = F.softmax(logits, dim=-1) # Fuse the modified model and original model if perturb: - # original_probs = top_k_filter(original_probs[:, -1, :]) #+ SmallConst + # original_probs = top_k_filter(original_probs[:, -1, :]) #+ SMALL_CONST unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1) # likelywords = torch.topk(original_probs, k=10, dim=-1) # print(TOKENIZER.decode(likelywords[1].tolist()[0])) log_probs = ((log_probs ** gm_scale) * ( - unpert_logits ** (1 - gm_scale))) # + SmallConst + unpert_logits ** (1 - gm_scale))) # + SMALL_CONST log_probs = top_k_filter(log_probs, k=top_k, - probs=True) # + SmallConst + probs=True) # + SMALL_CONST if torch.sum(log_probs) <= 1: log_probs = log_probs / torch.sum(log_probs) else: - logits = top_k_filter(logits, k=top_k) # + SmallConst + logits = top_k_filter(logits, k=top_k) # + SMALL_CONST log_probs = F.softmax(logits, dim=-1) if sample: