From 71d597dad0a28ccc397308146844486e0031d701 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 17 Jul 2019 13:51:09 +0200 Subject: [PATCH] fix #800 --- examples/run_squad.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index e920ebe378..d72d67b87d 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -122,9 +122,9 @@ def train(args, train_dataset, model, tokenizer): model.train() batch = tuple(t.to(args.device) for t in batch) inputs = {'input_ids': batch[0], - 'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids - 'attention_mask': batch[2], - 'start_positions': batch[3], + 'attention_mask': batch[1], + 'token_type_ids': None if args.model_type == 'xlm' else batch[2], + 'start_positions': batch[3], 'end_positions': batch[4]} if args.model_type in ['xlnet', 'xlm']: inputs.update({'cls_index': batch[5], @@ -206,8 +206,9 @@ def evaluate(args, model, tokenizer, prefix=""): batch = tuple(t.to(args.device) for t in batch) with torch.no_grad(): inputs = {'input_ids': batch[0], - 'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids - 'attention_mask': batch[2]} + 'attention_mask': batch[1], + 'token_type_ids': None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids + } example_indices = batch[3] if args.model_type in ['xlnet', 'xlm']: inputs.update({'cls_index': batch[4],