From 9e136ff57c268bfe0c0bfb322401684aaba12d15 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 4 Oct 2019 15:00:56 -0400 Subject: [PATCH] Honor args.overwrite_cache (h/t @erenup) --- examples/run_glue.py | 2 +- examples/run_multiple_choice.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/run_glue.py b/examples/run_glue.py index fc3b617da0..3a6eac63ad 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -270,7 +270,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): list(filter(None, args.model_name_or_path.split('/'))).pop(), str(args.max_seq_length), str(task))) - if os.path.exists(cached_features_file): + if os.path.exists(cached_features_file) and not args.overwrite_cache: logger.info("Loading features from cached file %s", cached_features_file) features = torch.load(cached_features_file) else: diff --git a/examples/run_multiple_choice.py b/examples/run_multiple_choice.py index 2ce1327c98..223316cbcb 100644 --- a/examples/run_multiple_choice.py +++ b/examples/run_multiple_choice.py @@ -293,7 +293,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False): list(filter(None, args.model_name_or_path.split('/'))).pop(), str(args.max_seq_length), str(task))) - if os.path.exists(cached_features_file): + if os.path.exists(cached_features_file) and not args.overwrite_cache: logger.info("Loading features from cached file %s", cached_features_file) features = torch.load(cached_features_file) else: