From 21d719238c68154798bff21581b82410f303e9ba Mon Sep 17 00:00:00 2001 From: "Jin Young (Daniel) Sohn" Date: Tue, 1 Sep 2020 08:42:17 -0700 Subject: [PATCH] Add cache_dir to save features TextDataset (#6879) * Add cache_dir to save features TextDataset This is in case the dataset is on a read-only (RO) filesystem, which is the case in tests (GKE TPU tests). * style --- .../run_language_modeling.py | 23 +++++++++++++++---- .../data/datasets/language_modeling.py | 4 +++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/examples/language-modeling/run_language_modeling.py b/examples/language-modeling/run_language_modeling.py index 3377b9d9cb..eb30258595 100644 --- a/examples/language-modeling/run_language_modeling.py +++ b/examples/language-modeling/run_language_modeling.py @@ -125,13 +125,22 @@ class DataTrainingArguments: ) -def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False): +def get_dataset( + args: DataTrainingArguments, + tokenizer: PreTrainedTokenizer, + evaluate: bool = False, + cache_dir: Optional[str] = None, +): file_path = args.eval_data_file if evaluate else args.train_data_file if args.line_by_line: return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size) else: return TextDataset( - tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, overwrite_cache=args.overwrite_cache + tokenizer=tokenizer, + file_path=file_path, + block_size=args.block_size, + overwrite_cache=args.overwrite_cache, + cache_dir=cache_dir, ) @@ -229,8 +238,14 @@ def main(): # Get datasets - train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None - eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None + train_dataset = ( + get_dataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None + ) + eval_dataset = ( + get_dataset(data_args, 
tokenizer=tokenizer, evaluate=True, cache_dir=model_args.cache_dir) + if training_args.do_eval + else None + ) if config.model_type == "xlnet": data_collator = DataCollatorForPermutationLanguageModeling( tokenizer=tokenizer, diff --git a/src/transformers/data/datasets/language_modeling.py b/src/transformers/data/datasets/language_modeling.py index 1a377a60b1..6b0f287126 100644 --- a/src/transformers/data/datasets/language_modeling.py +++ b/src/transformers/data/datasets/language_modeling.py @@ -1,6 +1,7 @@ import os import pickle import time +from typing import Optional import torch from torch.utils.data.dataset import Dataset @@ -26,6 +27,7 @@ class TextDataset(Dataset): file_path: str, block_size: int, overwrite_cache=False, + cache_dir: Optional[str] = None, ): assert os.path.isfile(file_path), f"Input file path {file_path} not found" @@ -33,7 +35,7 @@ class TextDataset(Dataset): directory, filename = os.path.split(file_path) cached_features_file = os.path.join( - directory, + cache_dir if cache_dir is not None else directory, "cached_lm_{}_{}_{}".format( tokenizer.__class__.__name__, str(block_size),