Merge pull request #1513 from slayton58/amp_fp16_einsum
Force einsum to run in fp16
This commit is contained in:
@@ -481,6 +481,16 @@ def main():
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
|
||||
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
|
||||
# remove the need for this code, but it is still valid.
|
||||
if args.fp16:
|
||||
try:
|
||||
import apex
|
||||
apex.amp.register_half_function(torch, 'einsum')
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
|
||||
|
||||
Reference in New Issue
Block a user