fix input in run_glue for distilbert
This commit is contained in:
@@ -134,8 +134,9 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {'input_ids': batch[0],
|
||||||
'attention_mask': batch[1],
|
'attention_mask': batch[1],
|
||||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM, DistilBERT and RoBERTa don't use segment_ids
|
|
||||||
'labels': batch[3]}
|
'labels': batch[3]}
|
||||||
|
if args.model_type != 'distilbert':
|
||||||
|
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||||
|
|
||||||
@@ -224,8 +225,9 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {'input_ids': batch[0],
|
||||||
'attention_mask': batch[1],
|
'attention_mask': batch[1],
|
||||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM, DistilBERT and RoBERTa don't use segment_ids
|
|
||||||
'labels': batch[3]}
|
'labels': batch[3]}
|
||||||
|
if args.model_type != 'distilbert':
|
||||||
|
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
tmp_eval_loss, logits = outputs[:2]
|
tmp_eval_loss, logits = outputs[:2]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user