Add cache_dir to save features TextDataset (#6879)
* Add cache_dir to save features TextDataset This is in case the dataset is in a RO filesystem, for which is the case in tests (GKE TPU tests). * style
This commit is contained in:
committed by
GitHub
parent
1461aac8d7
commit
21d719238c
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.utils.data.dataset import Dataset
|
||||
@@ -26,6 +27,7 @@ class TextDataset(Dataset):
|
||||
file_path: str,
|
||||
block_size: int,
|
||||
overwrite_cache=False,
|
||||
cache_dir: Optional[str] = None,
|
||||
):
|
||||
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
|
||||
|
||||
@@ -33,7 +35,7 @@ class TextDataset(Dataset):
|
||||
|
||||
directory, filename = os.path.split(file_path)
|
||||
cached_features_file = os.path.join(
|
||||
directory,
|
||||
cache_dir if cache_dir is not None else directory,
|
||||
"cached_lm_{}_{}_{}".format(
|
||||
tokenizer.__class__.__name__,
|
||||
str(block_size),
|
||||
|
||||
Reference in New Issue
Block a user