working on automodels

This commit is contained in:
thomwolf
2019-08-05 16:06:34 +02:00
parent 58830807d1
commit b90e29d52c
6 changed files with 289 additions and 14 deletions

View File

@@ -134,7 +134,7 @@ def train(args, train_dataset, model, tokenizer):
'end_positions': batch[4]}
if args.model_type in ['xlnet', 'xlm']:
inputs.update({'cls_index': batch[5],
'p_mask': batch[6]})
'p_mask': batch[6]})
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)