make examples work without apex
This commit is contained in:
@@ -36,13 +36,6 @@ from pytorch_pretrained_bert.modeling import BertForSequenceClassification
|
|||||||
from pytorch_pretrained_bert.optimization import BertAdam
|
from pytorch_pretrained_bert.optimization import BertAdam
|
||||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
|
|
||||||
try:
|
|
||||||
from apex.optimizers import FP16_Optimizer
|
|
||||||
from apex.optimizers import FusedAdam
|
|
||||||
from apex.parallel import DistributedDataParallel as DDP
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.")
|
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
level = logging.INFO)
|
level = logging.INFO)
|
||||||
@@ -467,6 +460,11 @@ def main():
|
|||||||
model.half()
|
model.half()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
|
try:
|
||||||
|
from apex.parallel import DistributedDataParallel as DDP
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
||||||
|
|
||||||
model = DDP(model)
|
model = DDP(model)
|
||||||
elif n_gpu > 1:
|
elif n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
@@ -482,6 +480,12 @@ def main():
|
|||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
t_total = t_total // torch.distributed.get_world_size()
|
t_total = t_total // torch.distributed.get_world_size()
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
|
try:
|
||||||
|
from apex.optimizers import FP16_Optimizer
|
||||||
|
from apex.optimizers import FusedAdam
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
||||||
|
|
||||||
optimizer = FusedAdam(optimizer_grouped_parameters,
|
optimizer = FusedAdam(optimizer_grouped_parameters,
|
||||||
lr=args.learning_rate,
|
lr=args.learning_rate,
|
||||||
bias_correction=False,
|
bias_correction=False,
|
||||||
|
|||||||
@@ -39,13 +39,6 @@ from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
|
|||||||
from pytorch_pretrained_bert.optimization import BertAdam
|
from pytorch_pretrained_bert.optimization import BertAdam
|
||||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
|
|
||||||
try:
|
|
||||||
from apex.optimizers import FP16_Optimizer
|
|
||||||
from apex.optimizers import FusedAdam
|
|
||||||
from apex.parallel import DistributedDataParallel as DDP
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.")
|
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
level = logging.INFO)
|
level = logging.INFO)
|
||||||
@@ -813,6 +806,11 @@ def main():
|
|||||||
model.half()
|
model.half()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
|
try:
|
||||||
|
from apex.parallel import DistributedDataParallel as DDP
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
||||||
|
|
||||||
model = DDP(model)
|
model = DDP(model)
|
||||||
elif n_gpu > 1:
|
elif n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
@@ -834,6 +832,12 @@ def main():
|
|||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
t_total = t_total // torch.distributed.get_world_size()
|
t_total = t_total // torch.distributed.get_world_size()
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
|
try:
|
||||||
|
from apex.optimizers import FP16_Optimizer
|
||||||
|
from apex.optimizers import FusedAdam
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
||||||
|
|
||||||
optimizer = FusedAdam(optimizer_grouped_parameters,
|
optimizer = FusedAdam(optimizer_grouped_parameters,
|
||||||
lr=args.learning_rate,
|
lr=args.learning_rate,
|
||||||
bias_correction=False,
|
bias_correction=False,
|
||||||
|
|||||||
Reference in New Issue
Block a user