This commit is contained in:
Peiqin Lin
2019-07-21 20:39:42 +08:00
parent a615499076
commit 76be189b08
2 changed files with 4 additions and 4 deletions

View File

@@ -116,8 +116,8 @@ def train(args, train_dataset, model, tokenizer):
'attention_mask': batch[1],
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM doesn't use segment_ids
'labels': batch[3]}
ouputs = model(**inputs)
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training

View File

@@ -129,8 +129,8 @@ def train(args, train_dataset, model, tokenizer):
if args.model_type in ['xlnet', 'xlm']:
inputs.update({'cls_index': batch[5],
'p_mask': batch[6]})
ouputs = model(**inputs)
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training