BERT + RoBERTa masking tokens handling + GPU device update.
This commit is contained in:
@@ -65,11 +65,15 @@ def set_seed(args):
|
|||||||
def mask_tokens(inputs, tokenizer, args):
|
def mask_tokens(inputs, tokenizer, args):
|
||||||
labels = inputs.clone()
|
labels = inputs.clone()
|
||||||
masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).byte()
|
masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).byte()
|
||||||
labels[~masked_indices] = -1 # We only compute loss on masked tokens
|
labels[~masked_indices.bool()] = -1 # We only compute loss on masked tokens
|
||||||
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices
|
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices
|
||||||
inputs[indices_replaced] = tokenizer.vocab["[MASK]"] # 80% of the time, replace masked input tokens with [MASK]
|
|
||||||
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced
|
if args.model_name == "bert":
|
||||||
random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long, device=args.device)
|
inputs[indices_replaced.bool()] = tokenizer.vocab["[MASK]"] # 80% of the time, replace masked input tokens with [MASK]
|
||||||
|
elif args.model_name == "roberta":
|
||||||
|
inputs[indices_replaced.bool()] = tokenizer.encoder["<mask>"] # 80% of the time, replace masked input tokens with <mask>
|
||||||
|
indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced).bool()
|
||||||
|
random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long)
|
||||||
inputs[indices_random] = random_words[
|
inputs[indices_random] = random_words[
|
||||||
indices_random] # 10% of the time, replace masked input tokens with random word
|
indices_random] # 10% of the time, replace masked input tokens with random word
|
||||||
return inputs, labels
|
return inputs, labels
|
||||||
@@ -132,9 +136,10 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
for _ in train_iterator:
|
for _ in train_iterator:
|
||||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||||
for step, batch in enumerate(epoch_iterator):
|
for step, batch in enumerate(epoch_iterator):
|
||||||
batch.to(args.device)
|
|
||||||
model.train()
|
|
||||||
inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
|
inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
|
||||||
|
inputs = inputs.to(args.device)
|
||||||
|
labels = labels.to(args.device)
|
||||||
|
model.train()
|
||||||
outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
|
outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
|
||||||
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||||
|
|
||||||
@@ -214,7 +219,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
nb_eval_steps = 0
|
nb_eval_steps = 0
|
||||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||||
model.eval()
|
model.eval()
|
||||||
batch.to(args.device)
|
batch = batch.to(args.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(batch)
|
outputs = model(batch)
|
||||||
@@ -285,9 +290,9 @@ def main():
|
|||||||
parser.add_argument("--do_lower_case", action='store_true',
|
parser.add_argument("--do_lower_case", action='store_true',
|
||||||
help="Set this flag if you are using an uncased model.")
|
help="Set this flag if you are using an uncased model.")
|
||||||
|
|
||||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
|
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
|
||||||
help="Batch size per GPU/CPU for training.")
|
help="Batch size per GPU/CPU for training.")
|
||||||
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
|
parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int,
|
||||||
help="Batch size per GPU/CPU for evaluation.")
|
help="Batch size per GPU/CPU for evaluation.")
|
||||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||||
@@ -299,7 +304,7 @@ def main():
|
|||||||
help="Epsilon for Adam optimizer.")
|
help="Epsilon for Adam optimizer.")
|
||||||
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
parser.add_argument("--max_grad_norm", default=1.0, type=float,
|
||||||
help="Max gradient norm.")
|
help="Max gradient norm.")
|
||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument("--num_train_epochs", default=1.0, type=float,
|
||||||
help="Total number of training epochs to perform.")
|
help="Total number of training epochs to perform.")
|
||||||
parser.add_argument("--max_steps", default=-1, type=int,
|
parser.add_argument("--max_steps", default=-1, type=int,
|
||||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||||
|
|||||||
@@ -6,8 +6,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
|
|
||||||
class WikiTextDataset(Dataset):
|
class WikiTextDataset(Dataset):
|
||||||
def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=512, device='cpu'):
|
def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=512):
|
||||||
self.device = device
|
|
||||||
self.max_context_length = max_context_length
|
self.max_context_length = max_context_length
|
||||||
|
|
||||||
self.examples = []
|
self.examples = []
|
||||||
@@ -32,7 +31,7 @@ class WikiTextDataset(Dataset):
|
|||||||
return len(self.examples)
|
return len(self.examples)
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
return torch.tensor(self.examples[item], device=self.device)
|
return torch.tensor(self.examples[item])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def collate(values):
|
def collate(values):
|
||||||
|
|||||||
Reference in New Issue
Block a user