From 611961ade71042ff759712e6f680544ec5ff68b9 Mon Sep 17 00:00:00 2001 From: piero Date: Wed, 27 Nov 2019 21:34:49 -0800 Subject: [PATCH] Added tqdm to preprocessing --- examples/run_pplm_discrim_train.py | 206 ++++++++++++++--------------- 1 file changed, 102 insertions(+), 104 deletions(-) diff --git a/examples/run_pplm_discrim_train.py b/examples/run_pplm_discrim_train.py index 519e2de29a..5291ad4b51 100644 --- a/examples/run_pplm_discrim_train.py +++ b/examples/run_pplm_discrim_train.py @@ -18,13 +18,14 @@ import torch.utils.data as data from nltk.tokenize.treebank import TreebankWordDetokenizer from torchtext import data as torchtext_data from torchtext import datasets +from tqdm import tqdm, trange from transformers import GPT2Tokenizer, GPT2LMHeadModel torch.manual_seed(0) np.random.seed(0) EPSILON = 1e-10 -device = 'cpu' +device = "cpu" example_sentence = "This is incredible! I love it, this is the best chicken I have ever had." max_length_seq = 100 @@ -109,8 +110,8 @@ class Dataset(data.Dataset): def __getitem__(self, index): """Returns one data pair (source and target).""" data = {} - data['X'] = self.X[index] - data['y'] = self.y[index] + data["X"] = self.X[index] + data["y"] = self.y[index] return data @@ -133,8 +134,8 @@ def collate_fn(data): for key in data[0].keys(): item_info[key] = [d[key] for d in data] - x_batch, _ = pad_sequences(item_info['X']) - y_batch = torch.tensor(item_info['y'], dtype=torch.long) + x_batch, _ = pad_sequences(item_info["X"]) + y_batch = torch.tensor(item_info["y"], dtype=torch.long) return x_batch, y_batch @@ -144,8 +145,8 @@ def cached_collate_fn(data): for key in data[0].keys(): item_info[key] = [d[key] for d in data] - x_batch = torch.cat(item_info['X'], 0) - y_batch = torch.tensor(item_info['y'], dtype=torch.long) + x_batch = torch.cat(item_info["X"], 0) + y_batch = torch.tensor(item_info["y"], dtype=torch.long) return x_batch, y_batch @@ -168,7 +169,7 @@ def train_epoch(data_loader, discriminator, optimizer, if batch_idx % log_interval == 0: print( - 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( epoch + 1, samples_so_far, len(data_loader.dataset), 100 * samples_so_far / len(data_loader.dataset), loss.item() @@ -185,7 +186,7 @@ def evaluate_performance(data_loader, discriminator): input_t, target_t = input_t.to(device), target_t.to(device) output_t = discriminator(input_t) # sum up batch loss - test_loss += F.nll_loss(output_t, target_t, reduction='sum').item() + test_loss += F.nll_loss(output_t, target_t, reduction="sum").item() # get the index of the max log-probability pred_t = output_t.argmax(dim=1, keepdim=True) correct += pred_t.eq(target_t.view_as(pred_t)).sum().item() @@ -193,8 +194,8 @@ def evaluate_performance(data_loader, discriminator): test_loss /= len(data_loader.dataset) print( - 'Performance on test set: ' - 'Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format( + "Performance on test set: " + "Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format( test_loss, correct, len(data_loader.dataset), 100. * correct / len(data_loader.dataset) ) @@ -208,8 +209,8 @@ def predict(input_sentence, model, classes, cached=False): input_t = model.avg_representation(input_t) log_probs = model(input_t).data.cpu().numpy().flatten().tolist() - print('Input sentence:', input_sentence) - print('Predictions:', ", ".join( + print("Input sentence:", input_sentence) + print("Predictions:", ", ".join( "{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in zip(classes, log_probs) )) @@ -222,7 +223,7 @@ def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False): xs = [] ys = [] - for batch_idx, (x, y) in enumerate(data_loader): + for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)): with torch.no_grad(): x = x.to(device) avg_rep = discriminator.avg_representation(x).cpu().detach() @@ -240,16 +241,16 @@ def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False): def train_discriminator( - dataset, dataset_fp=None, pretrained_model='gpt2-medium', + dataset, dataset_fp=None, pretrained_model="gpt2-medium", epochs=10, batch_size=64, log_interval=10, save_model=False, cached=False, no_cuda=False): global device device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu" - print('Preprocessing {} dataset...'.format(dataset)) + print("Preprocessing {} dataset...".format(dataset)) start = time.time() - if dataset == 'SST': + if dataset == "SST": idx2class = ["positive", "negative", "very positive", "very negative", "neutral"] class2idx = {c: i for i, c in enumerate(idx2class)} @@ -271,7 +272,7 @@ def train_discriminator( x = [] y = [] - for i in range(len(train_data)): + for i in trange(len(train_data), ascii=True): seq = TreebankWordDetokenizer().detokenize( vars(train_data[i])["text"] ) @@ -283,7 +284,7 @@ def train_discriminator( test_x = [] test_y = [] - for i in range(len(test_data)): + for i in trange(len(test_data), ascii=True): seq = TreebankWordDetokenizer().detokenize( vars(test_data[i])["text"] ) @@ -301,7 +302,7 @@ def train_discriminator( "default_class": 2, } - elif dataset == 'clickbait': + elif dataset == "clickbait": idx2class = ["non_clickbait", "clickbait"] class2idx = {c: i for i, c in enumerate(idx2class)} @@ -317,31 +318,33 @@ def train_discriminator( try: data.append(eval(line)) except: - print('Error evaluating line {}: {}'.format( + print("Error evaluating line {}: {}".format( i, line )) continue x = [] y = [] - y = [] - for i, d in enumerate(data): - try: - seq = discriminator.tokenizer.encode(d["text"]) + with open("datasets/clickbait/clickbait_train_prefix.txt") as f: + for i, line in enumerate(tqdm(f, ascii=True)): + try: + d = eval(line) + seq = discriminator.tokenizer.encode(d["text"]) - if len(seq) < max_length_seq: - seq = torch.tensor( - [50256] + seq, device=device, dtype=torch.long - ) - else: - print("Line {} is longer than maximum length {}".format( - i, max_length_seq - )) - continue - x.append(seq) - y.append(d['label']) - except: - print("Error tokenizing line {}, skipping it".format(i)) - pass + if len(seq) < max_length_seq: + seq = torch.tensor( + [50256] + seq, device=device, dtype=torch.long + ) + else: + print("Line {} is longer than maximum length {}".format( + i, max_length_seq + )) + continue + x.append(seq) + y.append(d["label"]) + except: + print("Error evaluating / tokenizing" + " line {}, skipping it".format(i)) + pass full_dataset = Dataset(x, y) train_size = int(0.9 * len(full_dataset)) @@ -358,7 +361,7 @@ def train_discriminator( "default_class": 1, } - elif dataset == 'toxic': + elif dataset == "toxic": idx2class = ["non_toxic", "toxic"] class2idx = {c: i for i, c in enumerate(idx2class)} @@ -368,37 +371,29 @@ def train_discriminator( cached_mode=cached ).to(device) - with open("datasets/toxic/toxic_train.txt") as f: - data = [] - for i, line in enumerate(f): - try: - data.append(eval(line)) - except: - print('Error evaluating line {}: {}'.format( - i, line - )) - continue - x = [] y = [] - for i, d in enumerate(data): - try: - seq = discriminator.tokenizer.encode(d["text"]) + with open("datasets/toxic/toxic_train.txt") as f: + for i, line in enumerate(tqdm(f, ascii=True)): + try: + d = eval(line) + seq = discriminator.tokenizer.encode(d["text"]) - if len(seq) < max_length_seq: - seq = torch.tensor( - [50256] + seq, device=device, dtype=torch.long - ) - else: - print("Line {} is longer than maximum length {}".format( - i, max_length_seq - )) - continue - x.append(seq) - y.append(int(np.sum(d['label']) > 0)) - except: - print("Error tokenizing line {}, skipping it".format(i)) - pass + if len(seq) < max_length_seq: + seq = torch.tensor( + [50256] + seq, device=device, dtype=torch.long + ) + else: + print("Line {} is longer than maximum length {}".format( + i, max_length_seq + )) + continue + x.append(seq) + y.append(int(np.sum(d["label"]) > 0)) + except: + print("Error evaluating / tokenizing" + " line {}, skipping it".format(i)) + pass full_dataset = Dataset(x, y) train_size = int(0.9 * len(full_dataset)) @@ -415,18 +410,18 @@ def train_discriminator( "default_class": 0, } - else: # if dataset == 'generic': + else: # if dataset == "generic": # This assumes the input dataset is a TSV with the following structure: # class \t text if dataset_fp is None: - raise ValueError('When generic dataset is selected, ' - 'dataset_fp needs to be specified aswell.') + raise ValueError("When generic dataset is selected, " + "dataset_fp needs to be specified aswell.") classes = set() with open(dataset_fp) as f: - csv_reader = csv.reader(f, delimiter='\t') - for row in csv_reader: + csv_reader = csv.reader(f, delimiter="\t") + for row in tqdm(csv_reader, ascii=True): if row: classes.add(row[0]) @@ -442,8 +437,8 @@ def train_discriminator( x = [] y = [] with open(dataset_fp) as f: - csv_reader = csv.reader(f, delimiter='\t') - for i, row in enumerate(csv_reader): + csv_reader = csv.reader(f, delimiter="\t") + for i, row in enumerate(tqdm(csv_reader, ascii=True)): if row: label = row[0] text = row[1] @@ -458,9 +453,10 @@ def train_discriminator( ) else: - print("Line {} is longer than maximum length {}".format( - i, max_length_seq - )) + print( + "Line {} is longer than maximum length {}".format( + i, max_length_seq + )) continue x.append(seq) @@ -487,12 +483,14 @@ def train_discriminator( } end = time.time() - print('Preprocessed {} data points'.format( + print("Preprocessed {} data points".format( len(train_dataset) + len(test_dataset)) ) print("Data preprocessing took: {:.3f}s".format(end - start)) if cached: + print("Building representation cache...") + start = time.time() train_loader = get_cached_data_loader( @@ -524,7 +522,7 @@ def train_discriminator( for epoch in range(epochs): start = time.time() - print('\nEpoch', epoch + 1) + print("\nEpoch", epoch + 1) train_epoch( discriminator=discriminator, @@ -553,31 +551,31 @@ def train_discriminator( "{}_classifier_head_epoch_{}.pt".format(dataset, epoch)) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser( - description='Train a discriminator on top of GPT-2 representations') - parser.add_argument('--dataset', type=str, default='SST', - choices=('SST', 'clickbait', 'toxic', 'generic'), - help='dataset to train the discriminator on.' - 'In case of generic, the dataset is expected' - 'to be a TSBV file with structure: class \\t text') - parser.add_argument('--dataset_fp', type=str, default='', - help='File path of the dataset to use. ' - 'Needed only in case of generic datadset') - parser.add_argument('--pretrained_model', type=str, default='gpt2-medium', - help='Pretrained model to use as encoder') - parser.add_argument('--epochs', type=int, default=10, metavar='N', - help='Number of training epochs') - parser.add_argument('--batch_size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--log_interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') - parser.add_argument('--save_model', action='store_true', - help='whether to save the model') - parser.add_argument('--cached', action='store_true', - help='whether to cache the input representations') - parser.add_argument('--no_cuda', action='store_true', - help='use to turn off cuda') + description="Train a discriminator on top of GPT-2 representations") + parser.add_argument("--dataset", type=str, default="SST", + choices=("SST", "clickbait", "toxic", "generic"), + help="dataset to train the discriminator on." + "In case of generic, the dataset is expected" + "to be a TSBV file with structure: class \\t text") + parser.add_argument("--dataset_fp", type=str, default="", + help="File path of the dataset to use. " + "Needed only in case of generic datadset") + parser.add_argument("--pretrained_model", type=str, default="gpt2-medium", + help="Pretrained model to use as encoder") + parser.add_argument("--epochs", type=int, default=10, metavar="N", + help="Number of training epochs") + parser.add_argument("--batch_size", type=int, default=64, metavar="N", + help="input batch size for training (default: 64)") + parser.add_argument("--log_interval", type=int, default=10, metavar="N", + help="how many batches to wait before logging training status") + parser.add_argument("--save_model", action="store_true", + help="whether to save the model") + parser.add_argument("--cached", action="store_true", + help="whether to cache the input representations") + parser.add_argument("--no_cuda", action="store_true", + help="use to turn off cuda") args = parser.parse_args() train_discriminator(**(vars(args)))