Add cache_dir to save features TextDataset (#6879)

* Add cache_dir to save features TextDataset

This is in case the dataset is in a RO filesystem, for which is the case
in tests (GKE TPU tests).

* style
This commit is contained in:
Jin Young (Daniel) Sohn
2020-09-01 08:42:17 -07:00
committed by GitHub
parent 1461aac8d7
commit 21d719238c
2 changed files with 22 additions and 5 deletions

View File

@@ -125,13 +125,22 @@ class DataTrainingArguments:
)
def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
def get_dataset(
args: DataTrainingArguments,
tokenizer: PreTrainedTokenizer,
evaluate: bool = False,
cache_dir: Optional[str] = None,
):
file_path = args.eval_data_file if evaluate else args.train_data_file
if args.line_by_line:
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
else:
return TextDataset(
tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, overwrite_cache=args.overwrite_cache
tokenizer=tokenizer,
file_path=file_path,
block_size=args.block_size,
overwrite_cache=args.overwrite_cache,
cache_dir=cache_dir,
)
@@ -229,8 +238,14 @@ def main():
# Get datasets
train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
train_dataset = (
get_dataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
)
eval_dataset = (
get_dataset(data_args, tokenizer=tokenizer, evaluate=True, cache_dir=model_args.cache_dir)
if training_args.do_eval
else None
)
if config.model_type == "xlnet":
data_collator = DataCollatorForPermutationLanguageModeling(
tokenizer=tokenizer,