From 0cf88ff084f963261000be436cd5e3ae3dd4adb7 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 13 Dec 2018 13:28:00 +0100 Subject: [PATCH] make examples work without apex --- examples/run_classifier.py | 18 +++++++++++------- examples/run_squad.py | 18 +++++++++++------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 7c4eb7da47..aca099daf2 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -36,13 +36,6 @@ from pytorch_pretrained_bert.modeling import BertForSequenceClassification from pytorch_pretrained_bert.optimization import BertAdam from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE -try: - from apex.optimizers import FP16_Optimizer - from apex.optimizers import FusedAdam - from apex.parallel import DistributedDataParallel as DDP -except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.") - logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', level = logging.INFO) @@ -467,6 +460,11 @@ def main(): model.half() model.to(device) if args.local_rank != -1: + try: + from apex.parallel import DistributedDataParallel as DDP + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") + model = DDP(model) elif n_gpu > 1: model = torch.nn.DataParallel(model) @@ -482,6 +480,12 @@ def main(): if args.local_rank != -1: t_total = t_total // torch.distributed.get_world_size() if args.fp16: + try: + from apex.optimizers import FP16_Optimizer + from apex.optimizers import FusedAdam + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") + optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, diff --git a/examples/run_squad.py b/examples/run_squad.py index 81956ad394..147cd60f29 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -39,13 +39,6 @@ from pytorch_pretrained_bert.modeling import BertForQuestionAnswering from pytorch_pretrained_bert.optimization import BertAdam from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE -try: - from apex.optimizers import FP16_Optimizer - from apex.optimizers import FusedAdam - from apex.parallel import DistributedDataParallel as DDP -except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.") - logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', level = logging.INFO) @@ -813,6 +806,11 @@ def main(): model.half() model.to(device) if args.local_rank != -1: + try: + from apex.parallel import DistributedDataParallel as DDP + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") + model = DDP(model) elif n_gpu > 1: model = torch.nn.DataParallel(model) @@ -834,6 +832,12 @@ def main(): if args.local_rank != -1: t_total = t_total // torch.distributed.get_world_size() if args.fp16: + try: + from apex.optimizers import FP16_Optimizer + from apex.optimizers import FusedAdam + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") + optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False,