From 5d3b8daad2cc6287d30f03f8a96d0a1f7bc8d0dc Mon Sep 17 00:00:00 2001 From: manansanghi <52307004+manansanghi@users.noreply.github.com> Date: Fri, 22 Nov 2019 10:31:54 -0800 Subject: [PATCH] Minor bug fixes on run_ner.py --- examples/run_ner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/run_ner.py b/examples/run_ner.py index 127d63a6cd..1ab1236d94 100644 --- a/examples/run_ner.py +++ b/examples/run_ner.py @@ -127,7 +127,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id): "attention_mask": batch[1], "labels": batch[3]} if args.model_type != "distilbert": - inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids + inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids outputs = model(**inputs) loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc) @@ -217,7 +217,7 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix="" "attention_mask": batch[1], "labels": batch[3]} if args.model_type != "distilbert": - inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids + inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids outputs = model(**inputs) tmp_eval_loss, logits = outputs[:2]