typos
This commit is contained in:
@@ -116,8 +116,8 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
'attention_mask': batch[1],
|
'attention_mask': batch[1],
|
||||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||||
'labels': batch[3]}
|
'labels': batch[3]}
|
||||||
ouputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||||
|
|
||||||
if args.n_gpu > 1:
|
if args.n_gpu > 1:
|
||||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||||
|
|||||||
@@ -129,8 +129,8 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ['xlnet', 'xlm']:
|
||||||
inputs.update({'cls_index': batch[5],
|
inputs.update({'cls_index': batch[5],
|
||||||
'p_mask': batch[6]})
|
'p_mask': batch[6]})
|
||||||
ouputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||||
|
|
||||||
if args.n_gpu > 1:
|
if args.n_gpu > 1:
|
||||||
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
|
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
|
||||||
|
|||||||
Reference in New Issue
Block a user