Fix the `loss.backward()` issue in `perturb_past`
(cherry picked from commit 566468cc984c6ec7e10dfc62b5b4191781a99cd2)
This commit is contained in:
committed by
Julien Chaumond
parent
572c24cfa2
commit
83b1e6ac9e
@@ -36,6 +36,7 @@ from tqdm import trange
|
|||||||
from transformers import GPT2Tokenizer
|
from transformers import GPT2Tokenizer
|
||||||
from transformers.file_utils import cached_path
|
from transformers.file_utils import cached_path
|
||||||
from transformers.modeling_gpt2 import GPT2LMHeadModel
|
from transformers.modeling_gpt2 import GPT2LMHeadModel
|
||||||
|
from IPython import embed
|
||||||
|
|
||||||
PPLM_BOW = 1
|
PPLM_BOW = 1
|
||||||
PPLM_DISCRIM = 2
|
PPLM_DISCRIM = 2
|
||||||
@@ -246,8 +247,8 @@ def perturb_past(
|
|||||||
inputs_embeds=inputs_embeds
|
inputs_embeds=inputs_embeds
|
||||||
)
|
)
|
||||||
# get expected hidden states
|
# get expected hidden states
|
||||||
unpert_hidden = curr_all_hidden[1]
|
unpert_hidden = curr_all_hidden[-1]
|
||||||
accumulated_hidden += torch.sum(unpert_hidden, dim=1)
|
accumulated_hidden += torch.sum(unpert_hidden, dim=1).detach()
|
||||||
|
|
||||||
prediction = classifier(
|
prediction = classifier(
|
||||||
accumulated_hidden / (curr_length + 1 + horizon_length)
|
accumulated_hidden / (curr_length + 1 + horizon_length)
|
||||||
@@ -257,7 +258,7 @@ def perturb_past(
|
|||||||
discrim_loss += ce_loss(prediction, label)
|
discrim_loss += ce_loss(prediction, label)
|
||||||
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
|
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
|
||||||
|
|
||||||
if kl_scale > 0.0:
|
if kl_scale >= 0.0:
|
||||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||||
unpert_probs = (
|
unpert_probs = (
|
||||||
unpert_probs + SMALL_CONST *
|
unpert_probs + SMALL_CONST *
|
||||||
@@ -270,7 +271,7 @@ def perturb_past(
|
|||||||
torch.FloatTensor
|
torch.FloatTensor
|
||||||
).cuda().detach()
|
).cuda().detach()
|
||||||
corrected_probs = probs + correction.detach()
|
corrected_probs = probs + correction.detach()
|
||||||
kl_loss += kl_scale * (
|
kl_loss = kl_scale * (
|
||||||
(corrected_probs * (corrected_probs / unpert_probs).log()).sum()
|
(corrected_probs * (corrected_probs / unpert_probs).log()).sum()
|
||||||
)
|
)
|
||||||
print(' kl_loss', (kl_loss).data.cpu().numpy())
|
print(' kl_loss', (kl_loss).data.cpu().numpy())
|
||||||
@@ -280,7 +281,7 @@ def perturb_past(
|
|||||||
print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
|
print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
|
||||||
|
|
||||||
# compute gradients
|
# compute gradients
|
||||||
loss.backward(retain_graph=True)
|
loss.backward()
|
||||||
|
|
||||||
# calculate gradient norms
|
# calculate gradient norms
|
||||||
if grad_norms is not None and loss_type == PPLM_BOW:
|
if grad_norms is not None and loss_type == PPLM_BOW:
|
||||||
|
|||||||
Reference in New Issue
Block a user