The input training data files (multiple files in glob format). (#7717)

Very often splitting large files to smaller files can prevent tokenizer going out of memory in environment like Colab that does not have swap memory
This commit is contained in:
Kelvin
2020-10-12 12:44:02 +01:00
committed by GitHub
parent 34fcfb44e3
commit f176e70723

View File

@@ -24,8 +24,11 @@ import logging
import math import math
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from glob import glob
from typing import Optional from typing import Optional
from torch.utils.data import ConcatDataset
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING, MODEL_WITH_LM_HEAD_MAPPING,
@@ -87,6 +90,12 @@ class DataTrainingArguments:
train_data_file: Optional[str] = field( train_data_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a text file)."} default=None, metadata={"help": "The input training data file (a text file)."}
) )
train_data_files: Optional[str] = field(
default=None, metadata={
"help": "The input training data files (multiple files in glob format). "
"Very often splitting large files to smaller files can prevent tokenizer going out of memory"
}
)
eval_data_file: Optional[str] = field( eval_data_file: Optional[str] = field(
default=None, default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
@@ -131,7 +140,7 @@ def get_dataset(
evaluate: bool = False, evaluate: bool = False,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
): ):
file_path = args.eval_data_file if evaluate else args.train_data_file def _dataset(file_path):
if args.line_by_line: if args.line_by_line:
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size) return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
else: else:
@@ -143,6 +152,13 @@ def get_dataset(
cache_dir=cache_dir, cache_dir=cache_dir,
) )
if evaluate:
return _dataset(args.eval_data_file)
elif args.train_data_files:
return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)])
else:
return _dataset(args.train_data_file)
def main(): def main():
# See all possible arguments in src/transformers/training_args.py # See all possible arguments in src/transformers/training_args.py