[style] consistent nn. and nn.functional: part 4 examples (#12156)
* consistent nn. and nn.functional: p4 examples * restore
This commit is contained in:
@@ -1,19 +1,19 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ClassificationHead(nn.Module):
    """Classification Head for transformer encoders.

    A single linear layer that maps a hidden state whose last dimension is
    ``embed_size`` to ``class_size`` unnormalized logits.

    Args:
        class_size: Number of output features (one logit per class).
        embed_size: Number of input features of the hidden state.
    """

    def __init__(self, class_size, embed_size):
        super().__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        # Disabled two-layer variant, kept for reference:
        # self.mlp1 = nn.Linear(embed_size, embed_size)
        # self.mlp2 = (nn.Linear(embed_size, class_size))
        self.mlp = nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        """Apply the linear layer to ``hidden_state`` and return the logits."""
        # Disabled two-layer variant, kept for reference:
        # hidden_state = nn.functional.relu(self.mlp1(hidden_state))
        # hidden_state = self.mlp2(hidden_state)
        logits = self.mlp(hidden_state)
        return logits
|
||||
|
||||
@@ -30,7 +30,7 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from tqdm import trange
|
||||
|
||||
from pplm_classification_head import ClassificationHead
|
||||
@@ -160,7 +160,7 @@ def perturb_past(
|
||||
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
|
||||
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
||||
logits = all_logits[:, -1, :]
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
probs = nn.functional.softmax(logits, dim=-1)
|
||||
|
||||
loss = 0.0
|
||||
loss_list = []
|
||||
@@ -173,7 +173,7 @@ def perturb_past(
|
||||
print(" pplm_bow_loss:", loss.data.cpu().numpy())
|
||||
|
||||
if loss_type == 2 or loss_type == 3:
|
||||
ce_loss = torch.nn.CrossEntropyLoss()
|
||||
ce_loss = nn.CrossEntropyLoss()
|
||||
# TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
|
||||
curr_unpert_past = unpert_past
|
||||
curr_probs = torch.unsqueeze(probs, dim=1)
|
||||
@@ -195,7 +195,7 @@ def perturb_past(
|
||||
|
||||
kl_loss = 0.0
|
||||
if kl_scale > 0.0:
|
||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
|
||||
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
|
||||
corrected_probs = probs + correction.detach()
|
||||
@@ -527,10 +527,10 @@ def generate_text_pplm(
|
||||
else:
|
||||
pert_logits[0, token_idx] /= repetition_penalty
|
||||
|
||||
pert_probs = F.softmax(pert_logits, dim=-1)
|
||||
pert_probs = nn.functional.softmax(pert_logits, dim=-1)
|
||||
|
||||
if classifier is not None:
|
||||
ce_loss = torch.nn.CrossEntropyLoss()
|
||||
ce_loss = nn.CrossEntropyLoss()
|
||||
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
||||
label = torch.tensor([class_label], device=device, dtype=torch.long)
|
||||
unpert_discrim_loss = ce_loss(prediction, label)
|
||||
@@ -541,7 +541,7 @@ def generate_text_pplm(
|
||||
# Fuse the modified model and original model
|
||||
if perturb:
|
||||
|
||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
|
||||
pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
|
||||
pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST
|
||||
@@ -552,7 +552,7 @@ def generate_text_pplm(
|
||||
|
||||
else:
|
||||
pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST
|
||||
pert_probs = F.softmax(pert_logits, dim=-1)
|
||||
pert_probs = nn.functional.softmax(pert_logits, dim=-1)
|
||||
|
||||
# sample or greedy
|
||||
if sample:
|
||||
|
||||
@@ -23,10 +23,10 @@ import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torch.utils.data as data
|
||||
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
||||
from torch import nn
|
||||
from torchtext import data as torchtext_data
|
||||
from torchtext import datasets
|
||||
from tqdm import tqdm, trange
|
||||
@@ -42,7 +42,7 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
|
||||
max_length_seq = 100
|
||||
|
||||
|
||||
class Discriminator(torch.nn.Module):
|
||||
class Discriminator(nn.Module):
|
||||
"""Transformer encoder followed by a Classification Head"""
|
||||
|
||||
def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):
|
||||
@@ -76,7 +76,7 @@ class Discriminator(torch.nn.Module):
|
||||
avg_hidden = self.avg_representation(x.to(self.device))
|
||||
|
||||
logits = self.classifier_head(avg_hidden)
|
||||
probs = F.log_softmax(logits, dim=-1)
|
||||
probs = nn.functional.log_softmax(logits, dim=-1)
|
||||
|
||||
return probs
|
||||
|
||||
@@ -140,7 +140,7 @@ def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10,
|
||||
optimizer.zero_grad()
|
||||
|
||||
output_t = discriminator(input_t)
|
||||
loss = F.nll_loss(output_t, target_t)
|
||||
loss = nn.functional.nll_loss(output_t, target_t)
|
||||
loss.backward(retain_graph=True)
|
||||
optimizer.step()
|
||||
|
||||
@@ -167,7 +167,7 @@ def evaluate_performance(data_loader, discriminator, device="cpu"):
|
||||
input_t, target_t = input_t.to(device), target_t.to(device)
|
||||
output_t = discriminator(input_t)
|
||||
# sum up batch loss
|
||||
test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
|
||||
test_loss += nn.functional.nll_loss(output_t, target_t, reduction="sum").item()
|
||||
# get the index of the max log-probability
|
||||
pred_t = output_t.argmax(dim=1, keepdim=True)
|
||||
correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
|
||||
|
||||
Reference in New Issue
Block a user