From 715534800a2a809dbfc66bd17acb36ed30999b0d Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 14 Aug 2019 09:52:57 -0400 Subject: [PATCH] BERT + RoBERTa masking tokens handling + GPU device update. --- examples/run_generative_finetuning.py | 27 ++++++++++++++++----------- examples/utils_lm.py | 5 ++--- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/examples/run_generative_finetuning.py b/examples/run_generative_finetuning.py index 44daa3d266..ecbf44d8de 100644 --- a/examples/run_generative_finetuning.py +++ b/examples/run_generative_finetuning.py @@ -65,11 +65,15 @@ def set_seed(args): def mask_tokens(inputs, tokenizer, args): labels = inputs.clone() masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).byte() - labels[~masked_indices] = -1 # We only compute loss on masked tokens + labels[~masked_indices.bool()] = -1 # We only compute loss on masked tokens indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices - inputs[indices_replaced] = tokenizer.vocab["[MASK]"] # 80% of the time, replace masked input tokens with [MASK] - indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced - random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long, device=args.device) + + if args.model_name == "bert": + inputs[indices_replaced.bool()] = tokenizer.vocab["[MASK]"] # 80% of the time, replace masked input tokens with [MASK] + elif args.model_name == "roberta": + inputs[indices_replaced.bool()] = tokenizer.encoder[""] # 80% of the time, replace masked input tokens with + indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced).bool() + random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long) inputs[indices_random] = random_words[ indices_random] # 10% of the time, replace masked input tokens with random word return inputs, labels @@ -132,14 +136,15 @@ def train(args, train_dataset, model, tokenizer): for _ in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) for step, batch in enumerate(epoch_iterator): - batch.to(args.device) - model.train() inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) + inputs = inputs.to(args.device) + labels = labels.to(args.device) + model.train() outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels) loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc) if args.n_gpu > 1: - loss = loss.mean() # mean() to average on multi-gpu parallel training + loss = loss.mean() # mean() to average on multi-gpu parallel training if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps @@ -214,7 +219,7 @@ def evaluate(args, model, tokenizer, prefix=""): nb_eval_steps = 0 for batch in tqdm(eval_dataloader, desc="Evaluating"): model.eval() - batch.to(args.device) + batch = batch.to(args.device) with torch.no_grad(): outputs = model(batch) @@ -285,9 +290,9 @@ def main(): parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") - parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, + parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.") - parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, + parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation.") parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.") @@ -299,7 +304,7 @@ def main(): help="Epsilon for Adam optimizer.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--num_train_epochs", default=3.0, type=float, + parser.add_argument("--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform.") parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.") diff --git a/examples/utils_lm.py b/examples/utils_lm.py index 2944cdc9ea..68a1ca2cce 100644 --- a/examples/utils_lm.py +++ b/examples/utils_lm.py @@ -6,8 +6,7 @@ import torch.nn.functional as F class WikiTextDataset(Dataset): - def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=512, device='cpu'): - self.device = device + def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=512): self.max_context_length = max_context_length self.examples = [] @@ -32,7 +31,7 @@ class WikiTextDataset(Dataset): return len(self.examples) def __getitem__(self, item): - return torch.tensor(self.examples[item], device=self.device) + return torch.tensor(self.examples[item]) @staticmethod def collate(values):