From f94f1c6016414e059fa4d8ef61ee194fdc891046 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 19 Aug 2019 14:58:50 -0400 Subject: [PATCH] Distributed training + tokenizer agnostic mask token --- examples/run_generative_finetuning.py | 14 +++----------- examples/utils_lm.py | 27 ++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/examples/run_generative_finetuning.py b/examples/run_generative_finetuning.py index bb6aee6f07..8501364ae4 100644 --- a/examples/run_generative_finetuning.py +++ b/examples/run_generative_finetuning.py @@ -39,12 +39,10 @@ from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT BertConfig, BertForMaskedLM, BertTokenizer, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) from pytorch_transformers import AdamW, WarmupLinearSchedule +logger = logging.getLogger(__name__) from utils_lm import WikiTextDataset -logger = logging.getLogger(__name__) - -ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config,)), ()) MODEL_CLASSES = { 'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer), @@ -68,10 +66,7 @@ def mask_tokens(inputs, tokenizer, args): labels[~masked_indices.bool()] = -1 # We only compute loss on masked tokens indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices - if args.model_name == "bert": - inputs[indices_replaced.bool()] = tokenizer.vocab["[MASK]"] # 80% of the time, replace masked input tokens with [MASK] - elif args.model_name == "roberta": - inputs[indices_replaced.bool()] = tokenizer.encoder[""] # 80% of the time, replace masked input tokens with + inputs[indices_replaced.bool()] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) # 80% of the time, replace masked input tokens with [MASK] indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced).bool() random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long) inputs[indices_random] = random_words[ @@ -246,10 +241,7 @@ def evaluate(args, model, tokenizer, prefix=""): def load_and_cache_examples(args, tokenizer, evaluate=False): - if args.local_rank not in [-1, 0]: - torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache - - dataset = WikiTextDataset(tokenizer, file="test" if evaluate else "train", directory=args.data_dir) + dataset = WikiTextDataset(args, tokenizer, file="test" if evaluate else "train", directory=args.data_dir) return dataset diff --git a/examples/utils_lm.py b/examples/utils_lm.py index 5f22e10a76..251aea90e1 100644 --- a/examples/utils_lm.py +++ b/examples/utils_lm.py @@ -3,10 +3,27 @@ import os import random import torch import torch.nn.functional as F +import logging +import pickle + +logger = logging.getLogger(__name__) class WikiTextDataset(Dataset): - def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=1024): + def __init__(self, args, tokenizer, file='train', directory='wikitext', max_context_length=512, cache=None): + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache + + + cached_features_file = os.path.join(args.data_dir, f'cached_lm_{file}_{args.max_seq_length}') + + if os.path.exists(cached_features_file): + logger.info("Loading features from cached file %s", cached_features_file) + with open(cached_features_file, 'rb') as handle: + self.examples = pickle.load(handle) + else: + logger.info("Creating features from dataset file at %s", args.data_dir) + self.max_context_length = max_context_length self.examples = [] @@ -18,6 +35,14 @@ class WikiTextDataset(Dataset): while len(tokenized_text) > max_context_length: self.examples.append(tokenized_text[:max_context_length]) tokenized_text = tokenized_text[max_context_length:] + + if args.local_rank in [-1, 0]: + logger.info("Saving features into cached file %s", cached_features_file) + with open(cached_features_file, 'wb') as handle: + pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL) + + if args.local_rank == 0: + torch.distributed.barrier() def __len__(self): return len(self.examples)