From 5652f54ac26f3233f4dcbfd9a2f6879e94a0bc59 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 16 Aug 2019 13:49:56 -0400 Subject: [PATCH] Simplified data generator + better perplexity calculator GPT-2 now obtains ~20 perplexity on WikiText-2 --- examples/run_generative_finetuning.py | 9 +++++---- examples/utils_lm.py | 23 +++++------------------ 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/examples/run_generative_finetuning.py b/examples/run_generative_finetuning.py index ecbf44d8de..bb6aee6f07 100644 --- a/examples/run_generative_finetuning.py +++ b/examples/run_generative_finetuning.py @@ -85,7 +85,7 @@ def train(args, train_dataset, model, tokenizer): args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) train_sampler = SequentialSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) - train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=WikiTextDataset.collate) + train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) if args.max_steps > 0: t_total = args.max_steps @@ -209,7 +209,7 @@ def evaluate(args, model, tokenizer, prefix=""): args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) # Note that DistributedSampler samples randomly eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) - eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=WikiTextDataset.collate) + eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) # Eval! logger.info("***** Running evaluation {} *****".format(prefix)) @@ -217,12 +217,13 @@ def evaluate(args, model, tokenizer, prefix=""): logger.info(" Batch size = %d", args.eval_batch_size) eval_loss = 0.0 nb_eval_steps = 0 + model.eval() + for batch in tqdm(eval_dataloader, desc="Evaluating"): - model.eval() batch = batch.to(args.device) with torch.no_grad(): - outputs = model(batch) + outputs = model(batch, masked_lm_labels=batch) if args.mlm else model(batch, labels=batch) lm_loss = outputs[0] eval_loss += lm_loss.mean().item() nb_eval_steps += 1 diff --git a/examples/utils_lm.py b/examples/utils_lm.py index 68a1ca2cce..5f22e10a76 100644 --- a/examples/utils_lm.py +++ b/examples/utils_lm.py @@ -6,34 +6,21 @@ import torch.nn.functional as F class WikiTextDataset(Dataset): - def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=512): + def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=1024): self.max_context_length = max_context_length self.examples = [] with open(os.path.join(directory, f"wiki.{file}.raw"), encoding="utf-8") as f: text = f.read() - spans = list(filter(lambda item: len(item) > 120, text.split("\n"))) + tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) - for span in spans: - span = tokenizer.encode(span) - while len(span) > 0: - self.examples.append(span[:max_context_length]) - span = span[max_context_length:] - - # Randomly shuffle the examples array - random.shuffle(self.examples) - - # Sort the array by example length. - self.examples.sort(key=len) + while len(tokenized_text) > max_context_length: + self.examples.append(tokenized_text[:max_context_length]) + tokenized_text = tokenized_text[max_context_length:] def __len__(self): return len(self.examples) def __getitem__(self, item): return torch.tensor(self.examples[item]) - - @staticmethod - def collate(values): - stack = torch.stack([F.pad(value, (len(values[-1]) - value.size(0), 0), "constant", 0) for value in values]) - return stack