Merge pull request #1513 from slayton58/amp_fp16_einsum
Force einsum to run in fp16
This commit is contained in:
@@ -481,6 +481,16 @@ def main():
|
|||||||
|
|
||||||
logger.info("Training/evaluation parameters %s", args)
|
logger.info("Training/evaluation parameters %s", args)
|
||||||
|
|
||||||
|
# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
|
||||||
|
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
|
||||||
|
# remove the need for this code, but it is still valid.
|
||||||
|
if args.fp16:
|
||||||
|
try:
|
||||||
|
import apex
|
||||||
|
apex.amp.register_half_function(torch, 'einsum')
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
|
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user