Renamed SmallConst to SMALL_CONST and introduced BIG_CONST (replacing the hard-coded 1e10 in top_k_filter). Output is identical to before.
This commit is contained in:
@@ -43,7 +43,7 @@ PPLM_BOW = 1
|
||||
PPLM_DISCRIM = 2
|
||||
PPLM_BOW_DISCRIM = 3
|
||||
SMALL_CONST = 1e-15
|
||||
SmallConst = 1e-15
|
||||
BIG_CONST = 1e10
|
||||
TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")
|
||||
|
||||
BAG_OF_WORDS_ARCHIVE_MAP = {
|
||||
@@ -104,7 +104,8 @@ def top_k_filter(logits, k, probs=False):
|
||||
if probs:
|
||||
return torch.where(logits < batch_mins,
|
||||
torch.ones_like(logits) * 0.0, logits)
|
||||
return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10,
|
||||
return torch.where(logits < batch_mins,
|
||||
torch.ones_like(logits) * -BIG_CONST,
|
||||
logits)
|
||||
|
||||
|
||||
@@ -137,7 +138,7 @@ def perturb_past(
|
||||
accumulated_hidden = 0
|
||||
|
||||
if decay:
|
||||
decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[
|
||||
decay_mask = torch.arange(0., 1.0 + SMALL_CONST, 1.0 / (window_length))[
|
||||
1:]
|
||||
else:
|
||||
decay_mask = 1.0
|
||||
@@ -233,9 +234,9 @@ def perturb_past(
|
||||
kl_loss = 0.0
|
||||
if kl_scale > 0.0:
|
||||
p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
|
||||
p = p + SmallConst * (p <= SmallConst).type(
|
||||
p = p + SMALL_CONST * (p <= SMALL_CONST).type(
|
||||
torch.FloatTensor).cuda().detach()
|
||||
correction = SmallConst * (probabs <= SmallConst).type(
|
||||
correction = SMALL_CONST * (probabs <= SMALL_CONST).type(
|
||||
torch.FloatTensor).cuda().detach()
|
||||
corrected_probabs = probabs + correction.detach()
|
||||
kl_loss = kl_scale * (
|
||||
@@ -254,7 +255,7 @@ def perturb_past(
|
||||
for index, p_ in
|
||||
enumerate(past_perturb)]
|
||||
else:
|
||||
grad_norms = [(torch.norm(p_.grad * window_mask) + SmallConst) for
|
||||
grad_norms = [(torch.norm(p_.grad * window_mask) + SMALL_CONST) for
|
||||
index, p_ in enumerate(past_perturb)]
|
||||
|
||||
grad = [
|
||||
@@ -560,31 +561,31 @@ def generate_text_pplm(
|
||||
# Piero modified model call
|
||||
# hidden = model.hidden_states # update hidden
|
||||
# logits = model.forward_hidden(hidden)
|
||||
logits = logits[:, -1, :] / temperature # + SmallConst
|
||||
logits = logits[:, -1, :] / temperature # + SMALL_CONST
|
||||
|
||||
# logits = top_k_filter(logits, k=args.top_k) # + SmallConst
|
||||
# logits = top_k_filter(logits, k=args.top_k) # + SMALL_CONST
|
||||
|
||||
log_probs = F.softmax(logits, dim=-1)
|
||||
|
||||
# Fuse the modified model and original model
|
||||
if perturb:
|
||||
|
||||
# original_probs = top_k_filter(original_probs[:, -1, :]) #+ SmallConst
|
||||
# original_probs = top_k_filter(original_probs[:, -1, :]) #+ SMALL_CONST
|
||||
unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
# likelywords = torch.topk(original_probs, k=10, dim=-1)
|
||||
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
|
||||
|
||||
log_probs = ((log_probs ** gm_scale) * (
|
||||
unpert_logits ** (1 - gm_scale))) # + SmallConst
|
||||
unpert_logits ** (1 - gm_scale))) # + SMALL_CONST
|
||||
|
||||
log_probs = top_k_filter(log_probs, k=top_k,
|
||||
probs=True) # + SmallConst
|
||||
probs=True) # + SMALL_CONST
|
||||
|
||||
if torch.sum(log_probs) <= 1:
|
||||
log_probs = log_probs / torch.sum(log_probs)
|
||||
|
||||
else:
|
||||
logits = top_k_filter(logits, k=top_k) # + SmallConst
|
||||
logits = top_k_filter(logits, k=top_k) # + SMALL_CONST
|
||||
log_probs = F.softmax(logits, dim=-1)
|
||||
|
||||
if sample:
|
||||
|
||||
Reference in New Issue
Block a user