From 7f5367e0b18a56448dde3c4504278e57e6f4beae Mon Sep 17 00:00:00 2001 From: Marianne Stecklina Date: Thu, 19 Sep 2019 11:29:20 +0200 Subject: [PATCH] Add cli argument for configuring labels --- examples/run_ner.py | 30 +++++++++++++++--------------- examples/utils_ner.py | 11 +++++++++-- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/examples/run_ner.py b/examples/run_ner.py index ce048ade18..f51f5ae2a1 100644 --- a/examples/run_ner.py +++ b/examples/run_ner.py @@ -55,7 +55,7 @@ def set_seed(args): torch.cuda.manual_seed_all(args.seed) -def train(args, train_dataset, model, tokenizer, pad_token_label_id): +def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id): """ Train the model """ if args.local_rank in [-1, 0]: tb_writer = SummaryWriter() @@ -148,7 +148,7 @@ def train(args, train_dataset, model, tokenizer, pad_token_label_id): if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: # Log metrics if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well - results = evaluate(args, model, tokenizer, pad_token_label_id) + results = evaluate(args, model, tokenizer, labels, pad_token_label_id) for key, value in results.items(): tb_writer.add_scalar("eval_{}".format(key), value, global_step) tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) @@ -160,8 +160,7 @@ def train(args, train_dataset, model, tokenizer, pad_token_label_id): output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) if not os.path.exists(output_dir): os.makedirs(output_dir) - model_to_save = model.module if hasattr(model, - "module") else model # Take care of distributed/parallel training + model_to_save = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training model_to_save.save_pretrained(output_dir) torch.save(args, os.path.join(output_dir, "training_args.bin")) logger.info("Saving model checkpoint to %s", output_dir) @@ -179,8 +178,8 @@ def train(args, train_dataset, model, tokenizer, pad_token_label_id): return global_step, tr_loss / global_step -def evaluate(args, model, tokenizer, pad_token_label_id, prefix=""): - eval_dataset = load_and_cache_examples(args, tokenizer, pad_token_label_id, evaluate=True) +def evaluate(args, model, tokenizer, labels, pad_token_label_id, prefix=""): + eval_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, evaluate=True) args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) # Note that DistributedSampler samples randomly @@ -220,7 +219,7 @@ def evaluate(args, model, tokenizer, pad_token_label_id, prefix=""): eval_loss = eval_loss / nb_eval_steps preds = np.argmax(preds, axis=2) - label_map = {i: label for i, label in enumerate(get_labels())} + label_map = {i: label for i, label in enumerate(labels)} out_label_list = [[] for _ in range(out_label_ids.shape[0])] preds_list = [[] for _ in range(out_label_ids.shape[0])] @@ -245,7 +244,7 @@ def evaluate(args, model, tokenizer, pad_token_label_id, prefix=""): return results -def load_and_cache_examples(args, tokenizer, pad_token_label_id, evaluate=False): +def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, evaluate=False): if args.local_rank not in [-1, 0] and not evaluate: torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache @@ -258,9 +257,8 @@ def load_and_cache_examples(args, tokenizer, pad_token_label_id, evaluate=False) features = torch.load(cached_features_file) else: logger.info("Creating features from dataset file at %s", args.data_dir) - label_list = get_labels() examples = read_examples_from_file(args.data_dir, evaluate=evaluate) - features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, + features = convert_examples_to_features(examples, labels, args.max_seq_length, tokenizer, cls_token_at_end=bool(args.model_type in ["xlnet"]), # xlnet has a cls token at the end cls_token=tokenizer.cls_token, @@ -305,6 +303,8 @@ def main(): help="The output directory where the model predictions and checkpoints will be written.") ## Other parameters + parser.add_argument("--labels", default="", type=str, + help="Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.") parser.add_argument("--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name") parser.add_argument("--tokenizer_name", default="", type=str, @@ -406,8 +406,8 @@ def main(): set_seed(args) # Prepare CONLL-2003 task - label_list = get_labels() - num_labels = len(label_list) + labels = get_labels(args.labels) + num_labels = len(labels) # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later pad_token_label_id = CrossEntropyLoss().ignore_index @@ -433,8 +433,8 @@ def main(): # Training if args.do_train: - train_dataset = load_and_cache_examples(args, tokenizer, pad_token_label_id, evaluate=False) - global_step, tr_loss = train(args, train_dataset, model, tokenizer, pad_token_label_id) + train_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, evaluate=False) + global_step, tr_loss = train(args, train_dataset, model, tokenizer, labels, pad_token_label_id) logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() @@ -466,7 +466,7 @@ def main(): global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" model = model_class.from_pretrained(checkpoint) model.to(args.device) - result = evaluate(args, model, tokenizer, pad_token_label_id, prefix=global_step) + result = evaluate(args, model, tokenizer, labels, pad_token_label_id, prefix=global_step) if global_step: result = {"{}_{}".format(global_step, k): v for k, v in result.items()} results.update(result) diff --git a/examples/utils_ner.py b/examples/utils_ner.py index 39f6d08149..27f76d5a59 100644 --- a/examples/utils_ner.py +++ b/examples/utils_ner.py @@ -202,5 +202,12 @@ def convert_examples_to_features(examples, return features -def get_labels(): - return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"] +def get_labels(path): + if path: + with open(path, "r") as f: + labels = f.read().splitlines() + if "O" not in labels: + labels = ["O"] + labels + return labels + else: + return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]