From aad3a54e9ce1bdc1bcb9309e3ebdea03dbeee588 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 16:48:04 +0200 Subject: [PATCH] fix paths --- examples/run_classifier.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 0add05113f..d945b0dfc8 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -380,24 +380,22 @@ def main(): ### Evaluation if args.do_eval: eval_examples = processor.get_dev_examples(args.data_dir) - cached_train_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format( + cached_eval_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format( list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(task_name))) try: - with open(cached_train_features_file, "rb") as reader: + with open(cached_eval_features_file, "rb") as reader: train_features = pickle.load(reader) except: - train_features = convert_examples_to_features( - train_examples, label_list, args.max_seq_length, tokenizer, output_mode) + eval_features = convert_examples_to_features( + eval_examples, label_list, args.max_seq_length, tokenizer, output_mode) if args.local_rank == -1 or torch.distributed.get_rank() == 0: - logger.info(" Saving train features into cached file %s", cached_train_features_file) - with open(cached_train_features_file, "wb") as writer: - pickle.dump(train_features, writer) + logger.info(" Saving eval features into cached file %s", cached_eval_features_file) + with open(cached_eval_features_file, "wb") as writer: + pickle.dump(eval_features, writer) - eval_features = convert_examples_to_features( - eval_examples, label_list, args.max_seq_length, tokenizer, output_mode) logger.info("***** Running evaluation *****") logger.info(" Num examples = %d", len(eval_examples)) logger.info(" Batch size = %d", args.eval_batch_size)