[squad] add SQuAD version tag (v1/v2) to the cached features file name (#5669)

This commit is contained in:
Tomo Lazovich
2020-07-10 16:34:21 -04:00
committed by GitHub
parent 223084e42b
commit cdf4cd7068

View File

@@ -113,9 +113,12 @@ class SquadDataset(Dataset):
raise KeyError("mode is not a valid split name")
self.mode = mode
# Load data features from cache or dataset file
version_tag = "v2" if args.version_2_with_negative else "v1"
cached_features_file = os.path.join(
cache_dir if cache_dir is not None else args.data_dir,
"cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(args.max_seq_length),),
"cached_{}_{}_{}_{}".format(
mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), version_tag,
),
)
# Make sure only the first process in distributed training processes the dataset,