[style] consistent nn. and nn.functional: part 4 examples (#12156)
* consistent nn. and nn.functional: p4 examples * restore
This commit is contained in:
@@ -30,7 +30,7 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from tqdm import trange
|
||||
|
||||
from pplm_classification_head import ClassificationHead
|
||||
@@ -160,7 +160,7 @@ def perturb_past(
|
||||
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
|
||||
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
||||
logits = all_logits[:, -1, :]
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
probs = nn.functional.softmax(logits, dim=-1)
|
||||
|
||||
loss = 0.0
|
||||
loss_list = []
|
||||
@@ -173,7 +173,7 @@ def perturb_past(
|
||||
print(" pplm_bow_loss:", loss.data.cpu().numpy())
|
||||
|
||||
if loss_type == 2 or loss_type == 3:
|
||||
ce_loss = torch.nn.CrossEntropyLoss()
|
||||
ce_loss = nn.CrossEntropyLoss()
|
||||
# TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
|
||||
curr_unpert_past = unpert_past
|
||||
curr_probs = torch.unsqueeze(probs, dim=1)
|
||||
@@ -195,7 +195,7 @@ def perturb_past(
|
||||
|
||||
kl_loss = 0.0
|
||||
if kl_scale > 0.0:
|
||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
|
||||
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
|
||||
corrected_probs = probs + correction.detach()
|
||||
@@ -527,10 +527,10 @@ def generate_text_pplm(
|
||||
else:
|
||||
pert_logits[0, token_idx] /= repetition_penalty
|
||||
|
||||
pert_probs = F.softmax(pert_logits, dim=-1)
|
||||
pert_probs = nn.functional.softmax(pert_logits, dim=-1)
|
||||
|
||||
if classifier is not None:
|
||||
ce_loss = torch.nn.CrossEntropyLoss()
|
||||
ce_loss = nn.CrossEntropyLoss()
|
||||
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
||||
label = torch.tensor([class_label], device=device, dtype=torch.long)
|
||||
unpert_discrim_loss = ce_loss(prediction, label)
|
||||
@@ -541,7 +541,7 @@ def generate_text_pplm(
|
||||
# Fuse the modified model and original model
|
||||
if perturb:
|
||||
|
||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
|
||||
pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
|
||||
pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST
|
||||
@@ -552,7 +552,7 @@ def generate_text_pplm(
|
||||
|
||||
else:
|
||||
pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST
|
||||
pert_probs = F.softmax(pert_logits, dim=-1)
|
||||
pert_probs = nn.functional.softmax(pert_logits, dim=-1)
|
||||
|
||||
# sample or greedy
|
||||
if sample:
|
||||
|
||||
Reference in New Issue
Block a user