From f176e707234ab4d6f2a44179066d71658cc40056 Mon Sep 17 00:00:00 2001 From: Kelvin Date: Mon, 12 Oct 2020 12:44:02 +0100 Subject: [PATCH] The input training data files (multiple files in glob format). (#7717) Very often splitting large files to smaller files can prevent tokenizer going out of memory in environment like Colab that does not have swap memory --- .../run_language_modeling.py | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/examples/language-modeling/run_language_modeling.py b/examples/language-modeling/run_language_modeling.py index eb30258595..283ade210c 100644 --- a/examples/language-modeling/run_language_modeling.py +++ b/examples/language-modeling/run_language_modeling.py @@ -24,8 +24,11 @@ import logging import math import os from dataclasses import dataclass, field +from glob import glob from typing import Optional +from torch.utils.data import ConcatDataset + from transformers import ( CONFIG_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, @@ -87,6 +90,12 @@ class DataTrainingArguments: train_data_file: Optional[str] = field( default=None, metadata={"help": "The input training data file (a text file)."} ) + train_data_files: Optional[str] = field( + default=None, metadata={ + "help": "The input training data files (multiple files in glob format). " + "Very often splitting large files to smaller files can prevent tokenizer going out of memory" + } + ) eval_data_file: Optional[str] = field( default=None, metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, @@ -131,17 +140,24 @@ def get_dataset( evaluate: bool = False, cache_dir: Optional[str] = None, ): - file_path = args.eval_data_file if evaluate else args.train_data_file - if args.line_by_line: - return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size) + def _dataset(file_path): + if args.line_by_line: + return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size) + else: + return TextDataset( + tokenizer=tokenizer, + file_path=file_path, + block_size=args.block_size, + overwrite_cache=args.overwrite_cache, + cache_dir=cache_dir, + ) + + if evaluate: + return _dataset(args.eval_data_file) + elif args.train_data_files: + return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)]) else: - return TextDataset( - tokenizer=tokenizer, - file_path=file_path, - block_size=args.block_size, - overwrite_cache=args.overwrite_cache, - cache_dir=cache_dir, - ) + return _dataset(args.train_data_file) def main():