From d5e60e5b7ae73f63048ab5b3ef570af22784f374 Mon Sep 17 00:00:00 2001 From: erenup Date: Tue, 20 Aug 2019 16:25:50 +0800 Subject: [PATCH] add test related code --- .../run_multiple_choice.py | 49 ++++++++++++++++--- .../utils_multiple_choice.py | 22 +++++++++ 2 files changed, 65 insertions(+), 6 deletions(-) diff --git a/examples/single_model_scripts/run_multiple_choice.py b/examples/single_model_scripts/run_multiple_choice.py index cd70084e7d..d4dd7e4b1c 100644 --- a/examples/single_model_scripts/run_multiple_choice.py +++ b/examples/single_model_scripts/run_multiple_choice.py @@ -126,6 +126,7 @@ def train(args, train_dataset, model, tokenizer): global_step = 0 tr_loss, logging_loss = 0.0, 0.0 + best_dev_acc, best_dev_loss = 0.0, 99999999999.0 model.zero_grad() train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) set_seed(args) # Added here for reproductibility (even between python 2 and 3) @@ -167,6 +168,13 @@ def train(args, train_dataset, model, tokenizer): results = evaluate(args, model, tokenizer) for key, value in results.items(): tb_writer.add_scalar('eval_{}'.format(key), value, global_step) + if results["eval_loss"] < best_dev_loss: + best_dev_acc = results["eval_acc"] + best_dev_loss = results["eval_loss"] + results_test = evaluate(args, model, tokenizer, test=True) + for key, value in results_test.items(): + tb_writer.add_scalar('test_{}'.format(key), value, global_step) + logger.info("test acc: %s, loss: %s, global steps: %s", str(results_test['eval_acc']), str(results_test['eval_loss']), str(global_step)) tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step) logger.info("Average loss: %s at global step: %s", str((tr_loss - logging_loss)/args.logging_steps), str(global_step)) @@ -196,14 +204,14 @@ def train(args, train_dataset, model, tokenizer): return global_step, tr_loss / global_step -def 
evaluate(args, model, tokenizer, prefix=""): +def evaluate(args, model, tokenizer, prefix="", test=False): # Loop to handle MNLI double evaluation (matched, mis-matched) eval_task_names = (args.task_name,) eval_outputs_dirs = (args.output_dir,) results = {} for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs): - eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True) + eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=not test, test=test) if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: os.makedirs(eval_output_dir) @@ -251,7 +259,7 @@ def evaluate(args, model, tokenizer, prefix=""): output_eval_file = os.path.join(eval_output_dir, "eval_results.txt") with open(output_eval_file, "w") as writer: - logger.info("***** Eval results {} *****".format(prefix)) + logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test))) writer.write("model =%s\n" % str(args.model_name_or_path)) writer.write("total batch size=%d\n" % (args.per_gpu_train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))) @@ -264,14 +272,21 @@ def evaluate(args, model, tokenizer, prefix=""): return results -def load_and_cache_examples(args, task, tokenizer, evaluate=False): +def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False): if args.local_rank not in [-1, 0]: torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache processor = processors[task]() # Load data features from cache or dataset file + if evaluate: + cached_mode = 'dev' + elif test: + cached_mode = 'test' + else: + cached_mode = 'train' + assert (evaluate == True and test == True) == False cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format( - 'dev' if evaluate else 'train', + cached_mode, list(filter(None, 
args.model_name_or_path.split('/'))).pop(), str(args.max_seq_length), str(task))) @@ -281,7 +296,12 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False): else: logger.info("Creating features from dataset file at %s", args.data_dir) label_list = processor.get_labels() - examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) + if evaluate: + examples = processor.get_dev_examples(args.data_dir) + elif test: + examples = processor.get_test_examples(args.data_dir) + else: + examples = processor.get_train_examples(args.data_dir) logger.info("Training number: %s", str(len(examples))) features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end @@ -337,6 +357,7 @@ def main(): help="Whether to run training.") parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") + parser.add_argument("--do_test", action='store_true', help='Whether to run test on the test set') parser.add_argument("--evaluate_during_training", action='store_true', help="Rul evaluation during training at each logging step.") parser.add_argument("--do_lower_case", action='store_true', @@ -494,6 +515,22 @@ def main(): result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) results.update(result) + if args.do_test and args.local_rank in [-1, 0]: + if not args.do_train: + args.output_dir = args.model_name_or_path + checkpoints = [args.output_dir] + if args.eval_all_checkpoints: + checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) + logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging + logger.info("Evaluate the following checkpoints: %s", checkpoints) + for checkpoint in checkpoints: + global_step = checkpoint.split('-')[-1] if 
len(checkpoints) > 1 else "" + model = model_class.from_pretrained(checkpoint) + model.to(args.device) + result = evaluate(args, model, tokenizer, prefix=global_step, test=True) + result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) + results.update(result) + return results diff --git a/examples/single_model_scripts/utils_multiple_choice.py b/examples/single_model_scripts/utils_multiple_choice.py index 6ecb8e0f55..d33477a30d 100644 --- a/examples/single_model_scripts/utils_multiple_choice.py +++ b/examples/single_model_scripts/utils_multiple_choice.py @@ -83,6 +83,10 @@ class DataProcessor(object): """Gets a collection of `InputExample`s for the dev set.""" raise NotImplementedError() + def get_test_examples(self, data_dir): + """Gets a collection of `InputExample`s for the test set.""" + raise NotImplementedError() + def get_labels(self): """Gets the list of labels for this data set.""" raise NotImplementedError() @@ -109,6 +113,15 @@ class RaceProcessor(DataProcessor): middle = self._read_txt(middle) return self._create_examples(high + middle, 'dev') + def get_test_examples(self, data_dir): + """See base class.""" + logger.info("LOOKING AT {} test".format(data_dir)) + high = os.path.join(data_dir, 'test/high') + middle = os.path.join(data_dir, 'test/middle') + high = self._read_txt(high) + middle = self._read_txt(middle) + return self._create_examples(high + middle, 'test') + def get_labels(self): """See base class.""" return ["0", "1", "2", "3"] @@ -157,6 +170,11 @@ class SwagProcessor(DataProcessor): logger.info("LOOKING AT {} dev".format(data_dir)) return self._create_examples(self._read_csv(os.path.join(data_dir, "val.csv")), "dev") + def get_test_examples(self, data_dir): + """See base class.""" + logger.info("LOOKING AT {} test".format(data_dir)) + return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test") + def get_labels(self): """See base class.""" return ["0", "1", "2", "3"] @@ -207,6 +225,10 @@ class 
ArcProcessor(DataProcessor): logger.info("LOOKING AT {} dev".format(data_dir)) return self._create_examples(self._read_json(os.path.join(data_dir, "dev.jsonl")), "dev") + def get_test_examples(self, data_dir): + logger.info("LOOKING AT {} test".format(data_dir)) + return self._create_examples(self._read_json(os.path.join(data_dir, "test.jsonl")), "test") + def get_labels(self): """See base class.""" return ["0", "1", "2", "3"]