Add cache_dir to save features TextDataset (#6879)
* Add cache_dir to save features TextDataset. This is for the case where the dataset is on a read-only (RO) filesystem, which is the case in tests (GKE TPU tests). * style
This commit is contained in:
committed by
GitHub
parent
1461aac8d7
commit
21d719238c
@@ -125,13 +125,22 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
|
def get_dataset(
|
||||||
|
args: DataTrainingArguments,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
evaluate: bool = False,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
):
|
||||||
file_path = args.eval_data_file if evaluate else args.train_data_file
|
file_path = args.eval_data_file if evaluate else args.train_data_file
|
||||||
if args.line_by_line:
|
if args.line_by_line:
|
||||||
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
||||||
else:
|
else:
|
||||||
return TextDataset(
|
return TextDataset(
|
||||||
tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, overwrite_cache=args.overwrite_cache
|
tokenizer=tokenizer,
|
||||||
|
file_path=file_path,
|
||||||
|
block_size=args.block_size,
|
||||||
|
overwrite_cache=args.overwrite_cache,
|
||||||
|
cache_dir=cache_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -229,8 +238,14 @@ def main():
|
|||||||
|
|
||||||
# Get datasets
|
# Get datasets
|
||||||
|
|
||||||
train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
|
train_dataset = (
|
||||||
eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
|
get_dataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
|
||||||
|
)
|
||||||
|
eval_dataset = (
|
||||||
|
get_dataset(data_args, tokenizer=tokenizer, evaluate=True, cache_dir=model_args.cache_dir)
|
||||||
|
if training_args.do_eval
|
||||||
|
else None
|
||||||
|
)
|
||||||
if config.model_type == "xlnet":
|
if config.model_type == "xlnet":
|
||||||
data_collator = DataCollatorForPermutationLanguageModeling(
|
data_collator = DataCollatorForPermutationLanguageModeling(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data.dataset import Dataset
|
||||||
@@ -26,6 +27,7 @@ class TextDataset(Dataset):
|
|||||||
file_path: str,
|
file_path: str,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
overwrite_cache=False,
|
overwrite_cache=False,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
):
|
):
|
||||||
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
|
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
|
||||||
|
|
||||||
@@ -33,7 +35,7 @@ class TextDataset(Dataset):
|
|||||||
|
|
||||||
directory, filename = os.path.split(file_path)
|
directory, filename = os.path.split(file_path)
|
||||||
cached_features_file = os.path.join(
|
cached_features_file = os.path.join(
|
||||||
directory,
|
cache_dir if cache_dir is not None else directory,
|
||||||
"cached_lm_{}_{}_{}".format(
|
"cached_lm_{}_{}_{}".format(
|
||||||
tokenizer.__class__.__name__,
|
tokenizer.__class__.__name__,
|
||||||
str(block_size),
|
str(block_size),
|
||||||
|
|||||||
Reference in New Issue
Block a user