Distributed training + tokenizer agnostic mask token
This commit is contained in:
@@ -39,12 +39,10 @@ from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT
|
|||||||
BertConfig, BertForMaskedLM, BertTokenizer, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BertConfig, BertForMaskedLM, BertTokenizer, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from utils_lm import WikiTextDataset
|
from utils_lm import WikiTextDataset
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config,)), ())
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
||||||
@@ -68,10 +66,7 @@ def mask_tokens(inputs, tokenizer, args):
|
|||||||
labels[~masked_indices.bool()] = -1 # We only compute loss on masked tokens
|
labels[~masked_indices.bool()] = -1 # We only compute loss on masked tokens
|
||||||
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices
|
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices
|
||||||
|
|
||||||
if args.model_name == "bert":
|
inputs[indices_replaced.bool()] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) # 80% of the time, replace masked input tokens with [MASK]
|
||||||
inputs[indices_replaced.bool()] = tokenizer.vocab["[MASK]"] # 80% of the time, replace masked input tokens with [MASK]
|
|
||||||
elif args.model_name == "roberta":
|
|
||||||
inputs[indices_replaced.bool()] = tokenizer.encoder["<mask>"] # 80% of the time, replace masked input tokens with <mask>
|
|
||||||
indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced).bool()
|
indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced).bool()
|
||||||
random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long)
|
random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long)
|
||||||
inputs[indices_random] = random_words[
|
inputs[indices_random] = random_words[
|
||||||
@@ -246,10 +241,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
||||||
if args.local_rank not in [-1, 0]:
|
dataset = WikiTextDataset(args, tokenizer, file="test" if evaluate else "train", directory=args.data_dir)
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
|
||||||
|
|
||||||
dataset = WikiTextDataset(tokenizer, file="test" if evaluate else "train", directory=args.data_dir)
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,27 @@ import os
|
|||||||
import random
|
import random
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import logging
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WikiTextDataset(Dataset):
|
class WikiTextDataset(Dataset):
|
||||||
def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=1024):
|
def __init__(self, args, tokenizer, file='train', directory='wikitext', max_context_length=512, cache=None):
|
||||||
|
if args.local_rank not in [-1, 0]:
|
||||||
|
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||||
|
|
||||||
|
|
||||||
|
cached_features_file = os.path.join(args.data_dir, f'cached_lm_{file}_{args.max_seq_length}')
|
||||||
|
|
||||||
|
if os.path.exists(cached_features_file):
|
||||||
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
|
with open(cached_features_file, 'rb') as handle:
|
||||||
|
self.examples = pickle.load(handle)
|
||||||
|
else:
|
||||||
|
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||||
|
|
||||||
self.max_context_length = max_context_length
|
self.max_context_length = max_context_length
|
||||||
|
|
||||||
self.examples = []
|
self.examples = []
|
||||||
@@ -19,6 +36,14 @@ class WikiTextDataset(Dataset):
|
|||||||
self.examples.append(tokenized_text[:max_context_length])
|
self.examples.append(tokenized_text[:max_context_length])
|
||||||
tokenized_text = tokenized_text[max_context_length:]
|
tokenized_text = tokenized_text[max_context_length:]
|
||||||
|
|
||||||
|
if args.local_rank in [-1, 0]:
|
||||||
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
|
with open(cached_features_file, 'wb') as handle:
|
||||||
|
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
|
if args.local_rank == 0:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.examples)
|
return len(self.examples)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user