The input training data files (multiple files in glob format). (#7717)

Very often splitting large files to smaller files can prevent tokenizer going out of memory in environment like Colab that does not have swap memory
This commit is contained in:
Kelvin
2020-10-12 12:44:02 +01:00
committed by GitHub
parent 34fcfb44e3
commit f176e70723

View File

@@ -24,8 +24,11 @@ import logging
import math
import os
from dataclasses import dataclass, field
from glob import glob
from typing import Optional
from torch.utils.data import ConcatDataset
from transformers import (
CONFIG_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
@@ -87,6 +90,12 @@ class DataTrainingArguments:
train_data_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a text file)."}
)
train_data_files: Optional[str] = field(
default=None, metadata={
"help": "The input training data files (multiple files in glob format). "
"Very often splitting large files to smaller files can prevent tokenizer going out of memory"
}
)
eval_data_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
@@ -131,17 +140,24 @@ def get_dataset(
evaluate: bool = False,
cache_dir: Optional[str] = None,
):
file_path = args.eval_data_file if evaluate else args.train_data_file
if args.line_by_line:
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
def _dataset(file_path):
if args.line_by_line:
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
else:
return TextDataset(
tokenizer=tokenizer,
file_path=file_path,
block_size=args.block_size,
overwrite_cache=args.overwrite_cache,
cache_dir=cache_dir,
)
if evaluate:
return _dataset(args.eval_data_file)
elif args.train_data_files:
return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)])
else:
return TextDataset(
tokenizer=tokenizer,
file_path=file_path,
block_size=args.block_size,
overwrite_cache=args.overwrite_cache,
cache_dir=cache_dir,
)
return _dataset(args.train_data_file)
def main():