This commit is contained in:
Peiqin Lin
2019-07-21 20:39:42 +08:00
parent a615499076
commit 76be189b08
2 changed files with 4 additions and 4 deletions

View File

@@ -129,8 +129,8 @@ def train(args, train_dataset, model, tokenizer):
if args.model_type in ['xlnet', 'xlm']:
inputs.update({'cls_index': batch[5],
'p_mask': batch[6]})
ouputs = model(**inputs)
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training