Add cache_dir to save features TextDataset (#6879)

* Add cache_dir to save features TextDataset

This is in case the dataset is in a RO filesystem, for which is the case
in tests (GKE TPU tests).

* style
This commit is contained in:
Jin Young (Daniel) Sohn
2020-09-01 08:42:17 -07:00
committed by GitHub
parent 1461aac8d7
commit 21d719238c
2 changed files with 22 additions and 5 deletions

View File

@@ -1,6 +1,7 @@
import os
import pickle
import time
from typing import Optional
import torch
from torch.utils.data.dataset import Dataset
@@ -26,6 +27,7 @@ class TextDataset(Dataset):
file_path: str,
block_size: int,
overwrite_cache=False,
cache_dir: Optional[str] = None,
):
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
@@ -33,7 +35,7 @@ class TextDataset(Dataset):
directory, filename = os.path.split(file_path)
cached_features_file = os.path.join(
directory,
cache_dir if cache_dir is not None else directory,
"cached_lm_{}_{}_{}".format(
tokenizer.__class__.__name__,
str(block_size),