From 4e6a55751a510c50347226653df68b07a9caa8c7 Mon Sep 17 00:00:00 2001 From: Simon Layton Date: Fri, 13 Sep 2019 15:21:40 -0400 Subject: [PATCH] Force einsum to fp16 --- examples/run_squad.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index 43b65d2c3c..71c656a13d 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -138,8 +138,8 @@ def train(args, train_dataset, model, tokenizer): model.train() batch = tuple(t.to(args.device) for t in batch) inputs = {'input_ids': batch[0], - 'attention_mask': batch[1], - 'start_positions': batch[3], + 'attention_mask': batch[1], + 'start_positions': batch[3], 'end_positions': batch[4]} if args.model_type != 'distilbert': inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] @@ -481,6 +481,16 @@ def main(): logger.info("Training/evaluation parameters %s", args) + # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set. + # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will + # remove the need for this code, but it is still valid. + if args.fp16: + try: + import apex + apex.amp.register_half_function(torch, 'einsum') + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + # Training if args.do_train: train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)