From 702f589848baba97ea4897aa3f0bb937e1ec3bcf Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 27 Sep 2019 00:20:14 -0400 Subject: [PATCH] fix input in run_glue for distilbert --- examples/run_glue.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/run_glue.py b/examples/run_glue.py index 71dad0edbf..fc3b617da0 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -134,8 +134,9 @@ def train(args, train_dataset, model, tokenizer): batch = tuple(t.to(args.device) for t in batch) inputs = {'input_ids': batch[0], 'attention_mask': batch[1], - 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM, DistilBERT and RoBERTa don't use segment_ids 'labels': batch[3]} + if args.model_type != 'distilbert': + inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids outputs = model(**inputs) loss = outputs[0] # model outputs are always tuple in transformers (see doc) @@ -224,8 +225,9 @@ def evaluate(args, model, tokenizer, prefix=""): with torch.no_grad(): inputs = {'input_ids': batch[0], 'attention_mask': batch[1], - 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM, DistilBERT and RoBERTa don't use segment_ids 'labels': batch[3]} + if args.model_type != 'distilbert': + inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids outputs = model(**inputs) tmp_eval_loss, logits = outputs[:2]