Minor bug fixes on run_ner.py
This commit is contained in:
committed by
Lysandre Debut
parent
aa92a184d2
commit
5d3b8daad2
@@ -127,7 +127,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
|
||||
"attention_mask": batch[1],
|
||||
"labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
||||
inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
||||
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||
@@ -217,7 +217,7 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
|
||||
"attention_mask": batch[1],
|
||||
"labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
||||
inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user