Minor bug fixes on run_ner.py

This commit is contained in:
manansanghi
2019-11-22 10:31:54 -08:00
committed by Lysandre Debut
parent aa92a184d2
commit 5d3b8daad2

View File

@@ -127,7 +127,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
"attention_mask": batch[1],
"labels": batch[3]}
if args.model_type != "distilbert":
inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
@@ -217,7 +217,7 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
"attention_mask": batch[1],
"labels": batch[3]}
if args.model_type != "distilbert":
inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
outputs = model(**inputs)
tmp_eval_loss, logits = outputs[:2]