Add cache_dir to save features TextDataset (#6879)
* Add cache_dir to save features TextDataset This is in case the dataset is in a RO filesystem, for which is the case in tests (GKE TPU tests). * style
This commit is contained in:
committed by
GitHub
parent
1461aac8d7
commit
21d719238c
@@ -125,13 +125,22 @@ class DataTrainingArguments:
|
||||
)
|
||||
|
||||
|
||||
def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
|
||||
def get_dataset(
|
||||
args: DataTrainingArguments,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
evaluate: bool = False,
|
||||
cache_dir: Optional[str] = None,
|
||||
):
|
||||
file_path = args.eval_data_file if evaluate else args.train_data_file
|
||||
if args.line_by_line:
|
||||
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
||||
else:
|
||||
return TextDataset(
|
||||
tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, overwrite_cache=args.overwrite_cache
|
||||
tokenizer=tokenizer,
|
||||
file_path=file_path,
|
||||
block_size=args.block_size,
|
||||
overwrite_cache=args.overwrite_cache,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
|
||||
@@ -229,8 +238,14 @@ def main():
|
||||
|
||||
# Get datasets
|
||||
|
||||
train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
|
||||
eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
|
||||
train_dataset = (
|
||||
get_dataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
|
||||
)
|
||||
eval_dataset = (
|
||||
get_dataset(data_args, tokenizer=tokenizer, evaluate=True, cache_dir=model_args.cache_dir)
|
||||
if training_args.do_eval
|
||||
else None
|
||||
)
|
||||
if config.model_type == "xlnet":
|
||||
data_collator = DataCollatorForPermutationLanguageModeling(
|
||||
tokenizer=tokenizer,
|
||||
|
||||
Reference in New Issue
Block a user