This commit is contained in:
Peiqin Lin
2019-07-21 20:39:42 +08:00
parent a615499076
commit 76be189b08
2 changed files with 4 additions and 4 deletions

View File

@@ -116,8 +116,8 @@ def train(args, train_dataset, model, tokenizer):
'attention_mask': batch[1], 'attention_mask': batch[1],
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
'labels': batch[3]} 'labels': batch[3]}
ouputs = model(**inputs) outputs = model(**inputs)
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc) loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
if args.n_gpu > 1: if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training loss = loss.mean() # mean() to average on multi-gpu parallel training

View File

@@ -129,8 +129,8 @@ def train(args, train_dataset, model, tokenizer):
if args.model_type in ['xlnet', 'xlm']: if args.model_type in ['xlnet', 'xlm']:
inputs.update({'cls_index': batch[5], inputs.update({'cls_index': batch[5],
'p_mask': batch[6]}) 'p_mask': batch[6]})
ouputs = model(**inputs) outputs = model(**inputs)
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc) loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
if args.n_gpu > 1: if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training