Force einsum to fp16
This commit is contained in:
@@ -138,8 +138,8 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
model.train()
|
model.train()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {'input_ids': batch[0],
|
||||||
'attention_mask': batch[1],
|
'attention_mask': batch[1],
|
||||||
'start_positions': batch[3],
|
'start_positions': batch[3],
|
||||||
'end_positions': batch[4]}
|
'end_positions': batch[4]}
|
||||||
if args.model_type != 'distilbert':
|
if args.model_type != 'distilbert':
|
||||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
|
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
|
||||||
@@ -481,6 +481,16 @@ def main():
|
|||||||
|
|
||||||
logger.info("Training/evaluation parameters %s", args)
|
logger.info("Training/evaluation parameters %s", args)
|
||||||
|
|
||||||
|
# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
|
||||||
|
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
|
||||||
|
# remove the need for this code, but it is still valid.
|
||||||
|
if args.fp16:
|
||||||
|
try:
|
||||||
|
import apex
|
||||||
|
apex.amp.register_half_function(torch, 'einsum')
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
|
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user