[style] consistent nn. and nn.functional: part 4 examples (#12156)

* consistent nn. and nn.functional: p4 examples

* restore
This commit is contained in:
Stas Bekman
2021-06-14 12:28:24 -07:00
committed by GitHub
parent 372ab9cd6d
commit 88e84186e5
26 changed files with 130 additions and 126 deletions

View File

@@ -30,7 +30,7 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import trange
from pplm_classification_head import ClassificationHead
@@ -160,7 +160,7 @@ def perturb_past(
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
logits = all_logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
probs = nn.functional.softmax(logits, dim=-1)
loss = 0.0
loss_list = []
@@ -173,7 +173,7 @@ def perturb_past(
print(" pplm_bow_loss:", loss.data.cpu().numpy())
if loss_type == 2 or loss_type == 3:
ce_loss = torch.nn.CrossEntropyLoss()
ce_loss = nn.CrossEntropyLoss()
# TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
curr_unpert_past = unpert_past
curr_probs = torch.unsqueeze(probs, dim=1)
@@ -195,7 +195,7 @@ def perturb_past(
kl_loss = 0.0
if kl_scale > 0.0:
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
corrected_probs = probs + correction.detach()
@@ -527,10 +527,10 @@ def generate_text_pplm(
else:
pert_logits[0, token_idx] /= repetition_penalty
pert_probs = F.softmax(pert_logits, dim=-1)
pert_probs = nn.functional.softmax(pert_logits, dim=-1)
if classifier is not None:
ce_loss = torch.nn.CrossEntropyLoss()
ce_loss = nn.CrossEntropyLoss()
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
label = torch.tensor([class_label], device=device, dtype=torch.long)
unpert_discrim_loss = ce_loss(prediction, label)
@@ -541,7 +541,7 @@ def generate_text_pplm(
# Fuse the modified model and original model
if perturb:
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST
@@ -552,7 +552,7 @@ def generate_text_pplm(
else:
pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST
pert_probs = F.softmax(pert_logits, dim=-1)
pert_probs = nn.functional.softmax(pert_logits, dim=-1)
# sample or greedy
if sample: