Added script for training a discriminator for pplm to use
This commit is contained in:
@@ -34,6 +34,7 @@ import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
from tqdm import trange
|
||||
|
||||
from examples.run_pplm_discrim_train import ClassificationHead
|
||||
from transformers import GPT2Tokenizer
|
||||
from transformers.file_utils import cached_path
|
||||
from transformers.modeling_gpt2 import GPT2LMHeadModel
|
||||
@@ -108,24 +109,6 @@ def top_k_filter(logits, k, probs=False):
|
||||
logits)
|
||||
|
||||
|
||||
class ClassificationHead(torch.nn.Module):
|
||||
""" Classification Head for the transformer """
|
||||
|
||||
def __init__(self, class_size=5, embed_size=2048):
|
||||
super(ClassificationHead, self).__init__()
|
||||
self.class_size = class_size
|
||||
self.embed_size = embed_size
|
||||
# self.mlp1 = torch.nn.Linear(embed_size, embed_size)
|
||||
# self.mlp2 = (torch.nn.Linear(embed_size, class_size))
|
||||
self.mlp = torch.nn.Linear(embed_size, class_size)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
# hidden_state = F.relu(self.mlp1(hidden_state))
|
||||
# hidden_state = self.mlp2(hidden_state)
|
||||
logits = self.mlp(hidden_state)
|
||||
return logits
|
||||
|
||||
|
||||
def perturb_past(past, model, prev, args, classifier, good_index=None,
|
||||
stepsize=0.01, vocab_size=50257,
|
||||
original_probs=None, accumulated_hidden=None, true_past=None,
|
||||
|
||||
Reference in New Issue
Block a user