diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index e7507b7d02..e028900d30 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -32,6 +32,7 @@ from typing import Dict, List, Tuple import numpy as np import torch +from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange @@ -83,7 +84,7 @@ MODEL_CLASSES = { class TextDataset(Dataset): - def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path="train", block_size=512): + def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512): assert os.path.isfile(file_path) directory, filename = os.path.split(file_path) cached_features_file = os.path.join( @@ -120,13 +121,32 @@ class TextDataset(Dataset): return torch.tensor(self.examples[item]) +class LineByLineTextDataset(Dataset): + def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512): + assert os.path.isfile(file_path) + # Here, we do not cache the features, operating under the assumption + # that we will soon use fast multithreaded tokenizers from the + # `tokenizers` repo everywhere =) + logger.info("Creating features from dataset file at %s", file_path) + + with open(file_path, encoding="utf-8") as f: + lines = [line for line in f.read().splitlines() if len(line) > 0] + + self.examples = tokenizer.batch_encode_plus(lines, max_length=block_size)["input_ids"] + + def __len__(self): + return len(self.examples) + + def __getitem__(self, i): + return torch.tensor(self.examples[i]) + + def load_and_cache_examples(args, tokenizer, evaluate=False): - return TextDataset( - tokenizer, - args, - file_path=args.eval_data_file if evaluate else args.train_data_file, - block_size=args.block_size, - ) + file_path = args.eval_data_file if evaluate else args.train_data_file + if args.line_by_line: + return LineByLineTextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size) + else: + return TextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size) def set_seed(args): @@ -182,6 +202,8 @@ def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> T tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() ] probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) + padding_mask = labels.eq(tokenizer.pad_token_id) + probability_matrix.masked_fill_(padding_mask, value=0.0) masked_indices = torch.bernoulli(probability_matrix).bool() labels[~masked_indices] = -100 # We only compute loss on masked tokens @@ -204,8 +226,14 @@ def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedToke tb_writer = SummaryWriter() args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) + + def collate(examples: List[torch.Tensor]): + return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id) + train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) - train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) + train_dataloader = DataLoader( + train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate + ) if args.max_steps > 0: t_total = args.max_steps @@ -391,8 +419,14 @@ def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefi args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) # Note that DistributedSampler samples randomly + + def collate(examples: List[torch.Tensor]): + return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id) + eval_sampler = SequentialSampler(eval_dataset) - eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + eval_dataloader = DataLoader( + eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate + ) # multi-gpu evaluate if args.n_gpu > 1: @@ -456,11 +490,14 @@ def main(): type=str, help="An optional input evaluation data file to evaluate the perplexity on (a text file).", ) - + parser.add_argument( + "--line_by_line", + action="store_true", + help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.", + ) parser.add_argument( "--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir" ) - parser.add_argument( "--model_name_or_path", default=None,