fix #800
This commit is contained in:
@@ -122,9 +122,9 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
model.train()
|
model.train()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {'input_ids': batch[0],
|
||||||
'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids
|
'attention_mask': batch[1],
|
||||||
'attention_mask': batch[2],
|
'token_type_ids': None if args.model_type == 'xlm' else batch[2],
|
||||||
'start_positions': batch[3],
|
'start_positions': batch[3],
|
||||||
'end_positions': batch[4]}
|
'end_positions': batch[4]}
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ['xlnet', 'xlm']:
|
||||||
inputs.update({'cls_index': batch[5],
|
inputs.update({'cls_index': batch[5],
|
||||||
@@ -206,8 +206,9 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {'input_ids': batch[0],
|
||||||
'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids
|
'attention_mask': batch[1],
|
||||||
'attention_mask': batch[2]}
|
'token_type_ids': None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
|
||||||
|
}
|
||||||
example_indices = batch[3]
|
example_indices = batch[3]
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ['xlnet', 'xlm']:
|
||||||
inputs.update({'cls_index': batch[4],
|
inputs.update({'cls_index': batch[4],
|
||||||
|
|||||||
Reference in New Issue
Block a user