[style] consistent nn. and nn.functional: part 4 examples (#12156)

* consistent nn. and nn.functional: p4 examples

* restore
This commit is contained in:
Stas Bekman
2021-06-14 12:28:24 -07:00
committed by GitHub
parent 372ab9cd6d
commit 88e84186e5
26 changed files with 130 additions and 126 deletions

View File

@@ -1,19 +1,19 @@
import torch
from torch import nn
class ClassificationHead(torch.nn.Module):
class ClassificationHead(nn.Module):
"""Classification Head for transformer encoders"""
def __init__(self, class_size, embed_size):
super().__init__()
self.class_size = class_size
self.embed_size = embed_size
# self.mlp1 = torch.nn.Linear(embed_size, embed_size)
# self.mlp2 = (torch.nn.Linear(embed_size, class_size))
self.mlp = torch.nn.Linear(embed_size, class_size)
# self.mlp1 = nn.Linear(embed_size, embed_size)
# self.mlp2 = (nn.Linear(embed_size, class_size))
self.mlp = nn.Linear(embed_size, class_size)
def forward(self, hidden_state):
# hidden_state = F.relu(self.mlp1(hidden_state))
# hidden_state = nn.functional.relu(self.mlp1(hidden_state))
# hidden_state = self.mlp2(hidden_state)
logits = self.mlp(hidden_state)
return logits

View File

@@ -30,7 +30,7 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import trange
from pplm_classification_head import ClassificationHead
@@ -160,7 +160,7 @@ def perturb_past(
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
logits = all_logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
probs = nn.functional.softmax(logits, dim=-1)
loss = 0.0
loss_list = []
@@ -173,7 +173,7 @@ def perturb_past(
print(" pplm_bow_loss:", loss.data.cpu().numpy())
if loss_type == 2 or loss_type == 3:
ce_loss = torch.nn.CrossEntropyLoss()
ce_loss = nn.CrossEntropyLoss()
# TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
curr_unpert_past = unpert_past
curr_probs = torch.unsqueeze(probs, dim=1)
@@ -195,7 +195,7 @@ def perturb_past(
kl_loss = 0.0
if kl_scale > 0.0:
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
corrected_probs = probs + correction.detach()
@@ -527,10 +527,10 @@ def generate_text_pplm(
else:
pert_logits[0, token_idx] /= repetition_penalty
pert_probs = F.softmax(pert_logits, dim=-1)
pert_probs = nn.functional.softmax(pert_logits, dim=-1)
if classifier is not None:
ce_loss = torch.nn.CrossEntropyLoss()
ce_loss = nn.CrossEntropyLoss()
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
label = torch.tensor([class_label], device=device, dtype=torch.long)
unpert_discrim_loss = ce_loss(prediction, label)
@@ -541,7 +541,7 @@ def generate_text_pplm(
# Fuse the modified model and original model
if perturb:
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST
@@ -552,7 +552,7 @@ def generate_text_pplm(
else:
pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST
pert_probs = F.softmax(pert_logits, dim=-1)
pert_probs = nn.functional.softmax(pert_logits, dim=-1)
# sample or greedy
if sample:

View File

@@ -23,10 +23,10 @@ import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from nltk.tokenize.treebank import TreebankWordDetokenizer
from torch import nn
from torchtext import data as torchtext_data
from torchtext import datasets
from tqdm import tqdm, trange
@@ -42,7 +42,7 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
max_length_seq = 100
class Discriminator(torch.nn.Module):
class Discriminator(nn.Module):
"""Transformer encoder followed by a Classification Head"""
def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):
@@ -76,7 +76,7 @@ class Discriminator(torch.nn.Module):
avg_hidden = self.avg_representation(x.to(self.device))
logits = self.classifier_head(avg_hidden)
probs = F.log_softmax(logits, dim=-1)
probs = nn.functional.log_softmax(logits, dim=-1)
return probs
@@ -140,7 +140,7 @@ def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10,
optimizer.zero_grad()
output_t = discriminator(input_t)
loss = F.nll_loss(output_t, target_t)
loss = nn.functional.nll_loss(output_t, target_t)
loss.backward(retain_graph=True)
optimizer.step()
@@ -167,7 +167,7 @@ def evaluate_performance(data_loader, discriminator, device="cpu"):
input_t, target_t = input_t.to(device), target_t.to(device)
output_t = discriminator(input_t)
# sum up batch loss
test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
test_loss += nn.functional.nll_loss(output_t, target_t, reduction="sum").item()
# get the index of the max log-probability
pred_t = output_t.argmax(dim=1, keepdim=True)
correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()