The input training data files (multiple files in glob format). (#7717)
Very often, splitting large files into smaller files can prevent the tokenizer from running out of memory in environments like Colab that do not have swap memory.
This commit is contained in:
@@ -24,8 +24,11 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from glob import glob
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from torch.utils.data import ConcatDataset
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
@@ -87,6 +90,12 @@ class DataTrainingArguments:
|
|||||||
train_data_file: Optional[str] = field(
|
train_data_file: Optional[str] = field(
|
||||||
default=None, metadata={"help": "The input training data file (a text file)."}
|
default=None, metadata={"help": "The input training data file (a text file)."}
|
||||||
)
|
)
|
||||||
|
train_data_files: Optional[str] = field(
|
||||||
|
default=None, metadata={
|
||||||
|
"help": "The input training data files (multiple files in glob format). "
|
||||||
|
"Very often splitting large files to smaller files can prevent tokenizer going out of memory"
|
||||||
|
}
|
||||||
|
)
|
||||||
eval_data_file: Optional[str] = field(
|
eval_data_file: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||||
@@ -131,7 +140,7 @@ def get_dataset(
|
|||||||
evaluate: bool = False,
|
evaluate: bool = False,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
):
|
):
|
||||||
file_path = args.eval_data_file if evaluate else args.train_data_file
|
def _dataset(file_path):
|
||||||
if args.line_by_line:
|
if args.line_by_line:
|
||||||
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
||||||
else:
|
else:
|
||||||
@@ -143,6 +152,13 @@ def get_dataset(
|
|||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if evaluate:
|
||||||
|
return _dataset(args.eval_data_file)
|
||||||
|
elif args.train_data_files:
|
||||||
|
return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)])
|
||||||
|
else:
|
||||||
|
return _dataset(args.train_data_file)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# See all possible arguments in src/transformers/training_args.py
|
# See all possible arguments in src/transformers/training_args.py
|
||||||
|
|||||||
Reference in New Issue
Block a user