From 7469d03b1ca3c1c920e4cadc8a007609d17aff50 Mon Sep 17 00:00:00 2001 From: w4nderlust Date: Tue, 26 Nov 2019 21:30:57 -0800 Subject: [PATCH] Fixed minor bug when running training on cuda --- examples/run_pplm_discrim_train.py | 49 ++++++++++++++++-------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/examples/run_pplm_discrim_train.py b/examples/run_pplm_discrim_train.py index cc52234281..7f10e861a8 100644 --- a/examples/run_pplm_discrim_train.py +++ b/examples/run_pplm_discrim_train.py @@ -18,6 +18,7 @@ import torch.utils.data as data from nltk.tokenize.treebank import TreebankWordDetokenizer from torchtext import data as torchtext_data from torchtext import datasets + from transformers import GPT2Tokenizer, GPT2LMHeadModel torch.manual_seed(0) @@ -89,7 +90,7 @@ class Discriminator(torch.nn.Module): if self.cached_mode: avg_hidden = x.to(device) else: - avg_hidden = self.avg_representation(x) + avg_hidden = self.avg_representation(x.to(device)) logits = self.classifier_head(avg_hidden) probs = F.log_softmax(logits, dim=-1) @@ -203,7 +204,7 @@ def evaluate_performance(data_loader, discriminator): def predict(input_sentence, model, classes, cached=False): input_t = model.tokenizer.encode(input_sentence) - input_t = torch.tensor([input_t], dtype=torch.long) + input_t = torch.tensor([input_t], dtype=torch.long, device=device) if cached: input_t = model.avg_representation(input_t) @@ -428,7 +429,8 @@ def train_discriminator( with open(dataset_fp) as f: csv_reader = csv.reader(f, delimiter='\t') for row in csv_reader: - classes.add(row[0]) + if row: + classes.add(row[0]) idx2class = sorted(classes) class2idx = {c: i for i, c in enumerate(idx2class)} @@ -444,30 +446,31 @@ def train_discriminator( with open(dataset_fp) as f: csv_reader = csv.reader(f, delimiter='\t') for i, row in enumerate(csv_reader): - label = row[0] - text = row[1] + if row: + label = row[0] + text = row[1] - try: - seq = discriminator.tokenizer.encode(text) - if (len(seq) < max_length_seq): - seq = torch.tensor( - [50256] + seq, - device=device, - dtype=torch.long - ) + try: + seq = discriminator.tokenizer.encode(text) + if (len(seq) < max_length_seq): + seq = torch.tensor( + [50256] + seq, + device=device, + dtype=torch.long + ) - else: - print("Line {} is longer than maximum length {}".format( - i, max_length_seq - )) - continue + else: + print("Line {} is longer than maximum length {}".format( + i, max_length_seq + )) + continue - x.append(seq) - y.append(class2idx[label]) + x.append(seq) + y.append(class2idx[label]) - except: - print("Error tokenizing line {}, skipping it".format(i)) - pass + except: + print("Error tokenizing line {}, skipping it".format(i)) + pass full_dataset = Dataset(x, y) train_size = int(0.9 * len(full_dataset))