[squad] add version tag to squad cache (#5669)
This commit is contained in:
@@ -113,9 +113,12 @@ class SquadDataset(Dataset):
             raise KeyError("mode is not a valid split name")
         self.mode = mode
         # Load data features from cache or dataset file
+        version_tag = "v2" if args.version_2_with_negative else "v1"
         cached_features_file = os.path.join(
             cache_dir if cache_dir is not None else args.data_dir,
-            "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(args.max_seq_length),),
+            "cached_{}_{}_{}_{}".format(
+                mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), version_tag,
+            ),
         )
 
         # Make sure only the first process in distributed training processes the dataset,
Reference in New Issue
Block a user