[squad] make examples and dataset accessible from SquadDataset object (#6710)
* [squad] make examples and dataset accessible from SquadDataset object * [squad] add support for legacy cache files
This commit is contained in:
@@ -103,6 +103,7 @@ class SquadDataset(Dataset):
|
||||
mode: Union[str, Split] = Split.train,
|
||||
is_language_sensitive: Optional[bool] = False,
|
||||
cache_dir: Optional[str] = None,
|
||||
dataset_format: Optional[str] = "pt",
|
||||
):
|
||||
self.args = args
|
||||
self.is_language_sensitive = is_language_sensitive
|
||||
@@ -128,28 +129,43 @@ class SquadDataset(Dataset):
|
||||
with FileLock(lock_path):
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
start = time.time()
|
||||
self.features = torch.load(cached_features_file)
|
||||
self.old_features = torch.load(cached_features_file)
|
||||
|
||||
# Legacy cache files have only features, while new cache files
|
||||
# will have dataset and examples also.
|
||||
self.features = self.old_features["features"]
|
||||
self.dataset = self.old_features.get("dataset", None)
|
||||
self.examples = self.old_features.get("examples", None)
|
||||
logger.info(
|
||||
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
||||
)
|
||||
|
||||
if self.dataset is None or self.examples is None:
|
||||
logger.warn(
|
||||
f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in future run"
|
||||
)
|
||||
else:
|
||||
if mode == Split.dev:
|
||||
examples = self.processor.get_dev_examples(args.data_dir)
|
||||
self.examples = self.processor.get_dev_examples(args.data_dir)
|
||||
else:
|
||||
examples = self.processor.get_train_examples(args.data_dir)
|
||||
self.examples = self.processor.get_train_examples(args.data_dir)
|
||||
|
||||
self.features = squad_convert_examples_to_features(
|
||||
examples=examples,
|
||||
self.features, self.dataset = squad_convert_examples_to_features(
|
||||
examples=self.examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=mode == Split.train,
|
||||
threads=args.threads,
|
||||
return_dataset=dataset_format,
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
torch.save(self.features, cached_features_file)
|
||||
torch.save(
|
||||
{"features": self.features, "dataset": self.dataset, "examples": self.examples},
|
||||
cached_features_file,
|
||||
)
|
||||
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
|
||||
logger.info(
|
||||
"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
|
||||
|
||||
Reference in New Issue
Block a user