[squad] add version tag to squad cache (#5669)
This commit is contained in:
@@ -113,9 +113,12 @@ class SquadDataset(Dataset):
|
||||
raise KeyError("mode is not a valid split name")
|
||||
self.mode = mode
|
||||
# Load data features from cache or dataset file
|
||||
version_tag = "v2" if args.version_2_with_negative else "v1"
|
||||
cached_features_file = os.path.join(
|
||||
cache_dir if cache_dir is not None else args.data_dir,
|
||||
"cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(args.max_seq_length),),
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), version_tag,
|
||||
),
|
||||
)
|
||||
|
||||
# Make sure only the first process in distributed training processes the dataset,
|
||||
|
||||
Reference in New Issue
Block a user