Merge pull request #1155 from anhnt170489/apex_fp16
Update apex fp16 implementation
This commit is contained in:
@@ -235,8 +235,9 @@ def main():
|
|||||||
|
|
||||||
# Prepare model
|
# Prepare model
|
||||||
model = BertForPreTraining.from_pretrained(args.bert_model)
|
model = BertForPreTraining.from_pretrained(args.bert_model)
|
||||||
if args.fp16:
|
# We don't need to manually call model.half() following Apex's recommend
|
||||||
model.half()
|
# if args.fp16:
|
||||||
|
# model.half()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
try:
|
try:
|
||||||
@@ -257,25 +258,36 @@ def main():
|
|||||||
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||||
|
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps,
|
||||||
|
t_total=num_train_optimization_steps)
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex.optimizers import FP16_Optimizer
|
# from apex.optimizers import FP16_Optimizer
|
||||||
from apex.optimizers import FusedAdam
|
# from apex.optimizers import FusedAdam
|
||||||
|
from apex import amp
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
||||||
|
|
||||||
optimizer = FusedAdam(optimizer_grouped_parameters,
|
# This below line of code is the main upgrade of Apex Fp16 implementation. I chose opt_leve="01"
|
||||||
lr=args.learning_rate,
|
# because it's recommended for typical use by Apex. We can make it configured
|
||||||
bias_correction=False,
|
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
|
||||||
max_grad_norm=1.0)
|
|
||||||
if args.loss_scale == 0:
|
# We don't need to use FP16_Optimizer wrapping over FusedAdam as well. Now Apex supports all Pytorch Optimizer
|
||||||
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
|
|
||||||
else:
|
# optimizer = FusedAdam(optimizer_grouped_parameters,
|
||||||
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
|
# lr=args.learning_rate,
|
||||||
else:
|
# bias_correction=False,
|
||||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
# max_grad_norm=1.0)
|
||||||
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)
|
# if args.loss_scale == 0:
|
||||||
|
# optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
|
||||||
|
# else:
|
||||||
|
# optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
|
||||||
|
# else:
|
||||||
|
# optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||||
|
# scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
logging.info("***** Running training *****")
|
logging.info("***** Running training *****")
|
||||||
@@ -304,7 +316,10 @@ def main():
|
|||||||
if args.gradient_accumulation_steps > 1:
|
if args.gradient_accumulation_steps > 1:
|
||||||
loss = loss / args.gradient_accumulation_steps
|
loss = loss / args.gradient_accumulation_steps
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
optimizer.backward(loss)
|
# I depricate FP16_Optimizer's backward func and replace as Apex document
|
||||||
|
# optimizer.backward(loss)
|
||||||
|
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||||
|
scaled_loss.backward()
|
||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
tr_loss += loss.item()
|
tr_loss += loss.item()
|
||||||
|
|||||||
Reference in New Issue
Block a user