From 5adb39e757183a00b946d3b0571e1983fd0e26b7 Mon Sep 17 00:00:00 2001 From: Marianne Stecklina Date: Mon, 23 Sep 2019 10:51:54 +0200 Subject: [PATCH] Add option to predict on test set --- examples/run_ner.py | 46 ++++++++++++++++++++++++++++++++++--------- examples/utils_ner.py | 19 +++++++++--------- 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/examples/run_ner.py b/examples/run_ner.py index f51f5ae2a1..6c6b0f8336 100644 --- a/examples/run_ner.py +++ b/examples/run_ner.py @@ -148,7 +148,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id): if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: # Log metrics if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well - results = evaluate(args, model, tokenizer, labels, pad_token_label_id) + results, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id) for key, value in results.items(): tb_writer.add_scalar("eval_{}".format(key), value, global_step) tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) @@ -178,8 +178,8 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id): return global_step, tr_loss / global_step -def evaluate(args, model, tokenizer, labels, pad_token_label_id, prefix=""): - eval_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, evaluate=True) +def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""): + eval_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode=mode) args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) # Note that DistributedSampler samples randomly @@ -241,15 +241,15 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, prefix=""): for key in sorted(results.keys()): logger.info(" %s = %s", key, str(results[key])) - return results + return results, preds_list -def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, evaluate=False): +def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode): if args.local_rank not in [-1, 0] and not evaluate: torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache # Load data features from cache or dataset file - cached_features_file = os.path.join(args.data_dir, "cached_{}_{}_{}".format("dev" if evaluate else "train", + cached_features_file = os.path.join(args.data_dir, "cached_{}_{}_{}".format(mode, list(filter(None, args.model_name_or_path.split("/"))).pop(), str(args.max_seq_length))) if os.path.exists(cached_features_file): @@ -257,7 +257,7 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, evaluat features = torch.load(cached_features_file) else: logger.info("Creating features from dataset file at %s", args.data_dir) - examples = read_examples_from_file(args.data_dir, evaluate=evaluate) + examples = read_examples_from_file(args.data_dir, mode) features = convert_examples_to_features(examples, labels, args.max_seq_length, tokenizer, cls_token_at_end=bool(args.model_type in ["xlnet"]), # xlnet has a cls token at the end @@ -318,6 +318,8 @@ def main(): help="Whether to run training.") parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.") + parser.add_argument("--do_predict", action="store_true", + help="Whether to run predictions on the test set.") parser.add_argument("--evaluate_during_training", action="store_true", help="Whether to run evaluation during training at each logging step.") parser.add_argument("--do_lower_case", action="store_true", @@ -433,7 +435,7 @@ def main(): # Training if args.do_train: - train_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, evaluate=False) + train_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode="train") global_step, tr_loss = train(args, train_dataset, model, tokenizer, labels, pad_token_label_id) logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) @@ -466,7 +468,7 @@ def main(): global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" model = model_class.from_pretrained(checkpoint) model.to(args.device) - result = evaluate(args, model, tokenizer, labels, pad_token_label_id, prefix=global_step) + result, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="dev", prefix=global_step) if global_step: result = {"{}_{}".format(global_step, k): v for k, v in result.items()} results.update(result) @@ -475,6 +477,32 @@ def main(): for key in sorted(results.keys()): writer.write("{} = {}\n".format(key, str(results[key]))) + if args.do_predict and args.local_rank in [-1, 0]: + tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) + model = model_class.from_pretrained(args.output_dir) + model.to(args.device) + result, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test") + # Save results + output_test_results_file = os.path.join(args.output_dir, "test_results.txt") + with open(output_test_results_file, "w") as writer: + for key in sorted(result.keys()): + writer.write("{} = {}\n".format(key, str(result[key]))) + # Save predictions + output_test_predictions_file = os.path.join(args.output_dir, "test_predictions.txt") + with open(output_test_predictions_file, "w") as writer: + with open(os.path.join(args.data_dir, "test.txt"), "r") as f: + example_id = 0 + for line in f: + if line.startswith("-DOCSTART-") or line == "" or line == "\n": + writer.write(line) + if not predictions[example_id]: + example_id += 1 + elif predictions[example_id]: + output_line = line.split()[0] + " " + predictions[example_id].pop(0) + "\n" + writer.write(output_line) + else: + logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]) + return results diff --git a/examples/utils_ner.py b/examples/utils_ner.py index 27f76d5a59..c20d7b0d1f 100644 --- a/examples/utils_ner.py +++ b/examples/utils_ner.py @@ -51,13 +51,8 @@ class InputFeatures(object): self.label_ids = label_ids -def read_examples_from_file(data_dir, evaluate=False): - if evaluate: - file_path = os.path.join(data_dir, "dev.txt") - guid_prefix = "dev" - else: - file_path = os.path.join(data_dir, "train.txt") - guid_prefix = "train" +def read_examples_from_file(data_dir, mode): + file_path = os.path.join(data_dir, "{}.txt".format(mode)) guid_index = 1 examples = [] with open(file_path, encoding="utf-8") as f: @@ -66,7 +61,7 @@ def read_examples_from_file(data_dir, evaluate=False): for line in f: if line.startswith("-DOCSTART-") or line == "" or line == "\n": if words: - examples.append(InputExample(guid="{}-{}".format(guid_prefix, guid_index), + examples.append(InputExample(guid="{}-{}".format(mode, guid_index), words=words, labels=labels)) guid_index += 1 @@ -75,9 +70,13 @@ def read_examples_from_file(data_dir, evaluate=False): else: splits = line.split(" ") words.append(splits[0]) - labels.append(splits[-1].replace("\n", "")) + if len(splits) > 1: + labels.append(splits[-1].replace("\n", "")) + else: + # Examples could have no label for mode = "test" + labels.append("O") if words: - examples.append(InputExample(guid="%s-%d".format(guid_prefix, guid_index), + examples.append(InputExample(guid="%s-%d".format(mode, guid_index), words=words, labels=labels)) return examples