[squad] make examples and dataset accessible from SquadDataset object (#6710)
* [squad] make examples and dataset accessible from SquadDataset object
* [squad] add support for legacy cache files
This commit is contained in:
@@ -103,6 +103,7 @@ class SquadDataset(Dataset):
|
|||||||
mode: Union[str, Split] = Split.train,
|
mode: Union[str, Split] = Split.train,
|
||||||
is_language_sensitive: Optional[bool] = False,
|
is_language_sensitive: Optional[bool] = False,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
|
dataset_format: Optional[str] = "pt",
|
||||||
):
|
):
|
||||||
self.args = args
|
self.args = args
|
||||||
self.is_language_sensitive = is_language_sensitive
|
self.is_language_sensitive = is_language_sensitive
|
||||||
@@ -128,28 +129,43 @@ class SquadDataset(Dataset):
|
|||||||
with FileLock(lock_path):
|
with FileLock(lock_path):
|
||||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
self.features = torch.load(cached_features_file)
|
self.old_features = torch.load(cached_features_file)
|
||||||
|
|
||||||
|
# Legacy cache files have only features, while new cache files
|
||||||
|
# will have dataset and examples also.
|
||||||
|
self.features = self.old_features["features"]
|
||||||
|
self.dataset = self.old_features.get("dataset", None)
|
||||||
|
self.examples = self.old_features.get("examples", None)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.dataset is None or self.examples is None:
|
||||||
|
logger.warn(
|
||||||
|
f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in future run"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if mode == Split.dev:
|
if mode == Split.dev:
|
||||||
examples = self.processor.get_dev_examples(args.data_dir)
|
self.examples = self.processor.get_dev_examples(args.data_dir)
|
||||||
else:
|
else:
|
||||||
examples = self.processor.get_train_examples(args.data_dir)
|
self.examples = self.processor.get_train_examples(args.data_dir)
|
||||||
|
|
||||||
self.features = squad_convert_examples_to_features(
|
self.features, self.dataset = squad_convert_examples_to_features(
|
||||||
examples=examples,
|
examples=self.examples,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_seq_length=args.max_seq_length,
|
max_seq_length=args.max_seq_length,
|
||||||
doc_stride=args.doc_stride,
|
doc_stride=args.doc_stride,
|
||||||
max_query_length=args.max_query_length,
|
max_query_length=args.max_query_length,
|
||||||
is_training=mode == Split.train,
|
is_training=mode == Split.train,
|
||||||
threads=args.threads,
|
threads=args.threads,
|
||||||
|
return_dataset=dataset_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
torch.save(self.features, cached_features_file)
|
torch.save(
|
||||||
|
{"features": self.features, "dataset": self.dataset, "examples": self.examples},
|
||||||
|
cached_features_file,
|
||||||
|
)
|
||||||
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
|
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
|
||||||
logger.info(
|
logger.info(
|
||||||
"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
|
"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
|
||||||
|
|||||||
Reference in New Issue
Block a user