The input training data files (multiple files in glob format). (#7717)
Very often, splitting large files into smaller files can prevent the tokenizer from running out of memory in environments like Colab that do not have swap memory.
This commit is contained in:
@@ -24,8 +24,11 @@ import logging
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from glob import glob
|
||||
from typing import Optional
|
||||
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
@@ -87,6 +90,12 @@ class DataTrainingArguments:
|
||||
train_data_file: Optional[str] = field(
|
||||
default=None, metadata={"help": "The input training data file (a text file)."}
|
||||
)
|
||||
train_data_files: Optional[str] = field(
|
||||
default=None, metadata={
|
||||
"help": "The input training data files (multiple files in glob format). "
|
||||
"Very often splitting large files to smaller files can prevent tokenizer going out of memory"
|
||||
}
|
||||
)
|
||||
eval_data_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||
@@ -131,17 +140,24 @@ def get_dataset(
|
||||
evaluate: bool = False,
|
||||
cache_dir: Optional[str] = None,
|
||||
):
|
||||
file_path = args.eval_data_file if evaluate else args.train_data_file
|
||||
if args.line_by_line:
|
||||
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
||||
def _dataset(file_path):
|
||||
if args.line_by_line:
|
||||
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
||||
else:
|
||||
return TextDataset(
|
||||
tokenizer=tokenizer,
|
||||
file_path=file_path,
|
||||
block_size=args.block_size,
|
||||
overwrite_cache=args.overwrite_cache,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
if evaluate:
|
||||
return _dataset(args.eval_data_file)
|
||||
elif args.train_data_files:
|
||||
return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)])
|
||||
else:
|
||||
return TextDataset(
|
||||
tokenizer=tokenizer,
|
||||
file_path=file_path,
|
||||
block_size=args.block_size,
|
||||
overwrite_cache=args.overwrite_cache,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
return _dataset(args.train_data_file)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
Reference in New Issue
Block a user