Renamed SmallConst to SMALL_CONST and introduced BIG_CONST. Output is identical to before.
This commit is contained in:
@@ -43,7 +43,7 @@ PPLM_BOW = 1
|
|||||||
PPLM_DISCRIM = 2
|
PPLM_DISCRIM = 2
|
||||||
PPLM_BOW_DISCRIM = 3
|
PPLM_BOW_DISCRIM = 3
|
||||||
SMALL_CONST = 1e-15
|
SMALL_CONST = 1e-15
|
||||||
SmallConst = 1e-15
|
BIG_CONST = 1e10
|
||||||
TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")
|
TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")
|
||||||
|
|
||||||
BAG_OF_WORDS_ARCHIVE_MAP = {
|
BAG_OF_WORDS_ARCHIVE_MAP = {
|
||||||
@@ -104,7 +104,8 @@ def top_k_filter(logits, k, probs=False):
|
|||||||
if probs:
|
if probs:
|
||||||
return torch.where(logits < batch_mins,
|
return torch.where(logits < batch_mins,
|
||||||
torch.ones_like(logits) * 0.0, logits)
|
torch.ones_like(logits) * 0.0, logits)
|
||||||
return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10,
|
return torch.where(logits < batch_mins,
|
||||||
|
torch.ones_like(logits) * -BIG_CONST,
|
||||||
logits)
|
logits)
|
||||||
|
|
||||||
|
|
||||||
@@ -137,7 +138,7 @@ def perturb_past(
|
|||||||
accumulated_hidden = 0
|
accumulated_hidden = 0
|
||||||
|
|
||||||
if decay:
|
if decay:
|
||||||
decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[
|
decay_mask = torch.arange(0., 1.0 + SMALL_CONST, 1.0 / (window_length))[
|
||||||
1:]
|
1:]
|
||||||
else:
|
else:
|
||||||
decay_mask = 1.0
|
decay_mask = 1.0
|
||||||
@@ -233,9 +234,9 @@ def perturb_past(
|
|||||||
kl_loss = 0.0
|
kl_loss = 0.0
|
||||||
if kl_scale > 0.0:
|
if kl_scale > 0.0:
|
||||||
p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
|
p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
|
||||||
p = p + SmallConst * (p <= SmallConst).type(
|
p = p + SMALL_CONST * (p <= SMALL_CONST).type(
|
||||||
torch.FloatTensor).cuda().detach()
|
torch.FloatTensor).cuda().detach()
|
||||||
correction = SmallConst * (probabs <= SmallConst).type(
|
correction = SMALL_CONST * (probabs <= SMALL_CONST).type(
|
||||||
torch.FloatTensor).cuda().detach()
|
torch.FloatTensor).cuda().detach()
|
||||||
corrected_probabs = probabs + correction.detach()
|
corrected_probabs = probabs + correction.detach()
|
||||||
kl_loss = kl_scale * (
|
kl_loss = kl_scale * (
|
||||||
@@ -254,7 +255,7 @@ def perturb_past(
|
|||||||
for index, p_ in
|
for index, p_ in
|
||||||
enumerate(past_perturb)]
|
enumerate(past_perturb)]
|
||||||
else:
|
else:
|
||||||
grad_norms = [(torch.norm(p_.grad * window_mask) + SmallConst) for
|
grad_norms = [(torch.norm(p_.grad * window_mask) + SMALL_CONST) for
|
||||||
index, p_ in enumerate(past_perturb)]
|
index, p_ in enumerate(past_perturb)]
|
||||||
|
|
||||||
grad = [
|
grad = [
|
||||||
@@ -560,31 +561,31 @@ def generate_text_pplm(
|
|||||||
# Piero modified model call
|
# Piero modified model call
|
||||||
# hidden = model.hidden_states # update hidden
|
# hidden = model.hidden_states # update hidden
|
||||||
# logits = model.forward_hidden(hidden)
|
# logits = model.forward_hidden(hidden)
|
||||||
logits = logits[:, -1, :] / temperature # + SmallConst
|
logits = logits[:, -1, :] / temperature # + SMALL_CONST
|
||||||
|
|
||||||
# logits = top_k_filter(logits, k=args.top_k) # + SmallConst
|
# logits = top_k_filter(logits, k=args.top_k) # + SMALL_CONST
|
||||||
|
|
||||||
log_probs = F.softmax(logits, dim=-1)
|
log_probs = F.softmax(logits, dim=-1)
|
||||||
|
|
||||||
# Fuse the modified model and original model
|
# Fuse the modified model and original model
|
||||||
if perturb:
|
if perturb:
|
||||||
|
|
||||||
# original_probs = top_k_filter(original_probs[:, -1, :]) #+ SmallConst
|
# original_probs = top_k_filter(original_probs[:, -1, :]) #+ SMALL_CONST
|
||||||
unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||||
# likelywords = torch.topk(original_probs, k=10, dim=-1)
|
# likelywords = torch.topk(original_probs, k=10, dim=-1)
|
||||||
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
|
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
|
||||||
|
|
||||||
log_probs = ((log_probs ** gm_scale) * (
|
log_probs = ((log_probs ** gm_scale) * (
|
||||||
unpert_logits ** (1 - gm_scale))) # + SmallConst
|
unpert_logits ** (1 - gm_scale))) # + SMALL_CONST
|
||||||
|
|
||||||
log_probs = top_k_filter(log_probs, k=top_k,
|
log_probs = top_k_filter(log_probs, k=top_k,
|
||||||
probs=True) # + SmallConst
|
probs=True) # + SMALL_CONST
|
||||||
|
|
||||||
if torch.sum(log_probs) <= 1:
|
if torch.sum(log_probs) <= 1:
|
||||||
log_probs = log_probs / torch.sum(log_probs)
|
log_probs = log_probs / torch.sum(log_probs)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logits = top_k_filter(logits, k=top_k) # + SmallConst
|
logits = top_k_filter(logits, k=top_k) # + SMALL_CONST
|
||||||
log_probs = F.softmax(logits, dim=-1)
|
log_probs = F.softmax(logits, dim=-1)
|
||||||
|
|
||||||
if sample:
|
if sample:
|
||||||
|
|||||||
Reference in New Issue
Block a user