Simplified data generator + better perplexity calculator
GPT-2 now obtains ~20 perplexity on WikiText-2
This commit is contained in:
@@ -85,7 +85,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = SequentialSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=WikiTextDataset.collate)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
@@ -209,7 +209,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# Note that DistributedSampler samples randomly
|
||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=WikiTextDataset.collate)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
@@ -217,12 +217,13 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
eval_loss = 0.0
|
||||
nb_eval_steps = 0
|
||||
model.eval()
|
||||
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = batch.to(args.device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(batch)
|
||||
outputs = model(batch, masked_lm_labels=batch) if args.mlm else model(batch, labels=batch)
|
||||
lm_loss = outputs[0]
|
||||
eval_loss += lm_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
|
||||
@@ -6,34 +6,21 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
class WikiTextDataset(Dataset):
|
||||
def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=512):
|
||||
def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=1024):
|
||||
self.max_context_length = max_context_length
|
||||
|
||||
self.examples = []
|
||||
|
||||
with open(os.path.join(directory, f"wiki.{file}.raw"), encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
spans = list(filter(lambda item: len(item) > 120, text.split("\n")))
|
||||
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
|
||||
|
||||
for span in spans:
|
||||
span = tokenizer.encode(span)
|
||||
while len(span) > 0:
|
||||
self.examples.append(span[:max_context_length])
|
||||
span = span[max_context_length:]
|
||||
|
||||
# Randomly shuffle the examples array
|
||||
random.shuffle(self.examples)
|
||||
|
||||
# Sort the array by example length.
|
||||
self.examples.sort(key=len)
|
||||
while len(tokenized_text) > max_context_length:
|
||||
self.examples.append(tokenized_text[:max_context_length])
|
||||
tokenized_text = tokenized_text[max_context_length:]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return torch.tensor(self.examples[item])
|
||||
|
||||
@staticmethod
|
||||
def collate(values):
|
||||
stack = torch.stack([F.pad(value, (len(values[-1]) - value.size(0), 0), "constant", 0) for value in values])
|
||||
return stack
|
||||
|
||||
Reference in New Issue
Block a user