PEP8 and formatting cleanups
This commit is contained in:
@@ -9,7 +9,7 @@ from collections import namedtuple
|
|||||||
|
|
||||||
from torch.utils.data import DataLoader, Dataset, RandomSampler
|
from torch.utils.data import DataLoader, Dataset, RandomSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm
|
||||||
|
|
||||||
from pytorch_pretrained_bert.modeling import BertForPreTraining
|
from pytorch_pretrained_bert.modeling import BertForPreTraining
|
||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
@@ -149,7 +149,8 @@ def main():
|
|||||||
help="random seed for initialization")
|
help="random seed for initialization")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
assert args.pregenerated_data.is_dir(), "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"
|
assert args.pregenerated_data.is_dir(), \
|
||||||
|
"--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"
|
||||||
|
|
||||||
samples_per_epoch = []
|
samples_per_epoch = []
|
||||||
for i in range(args.epochs):
|
for i in range(args.epochs):
|
||||||
@@ -237,7 +238,8 @@ def main():
|
|||||||
from apex.optimizers import FP16_Optimizer
|
from apex.optimizers import FP16_Optimizer
|
||||||
from apex.optimizers import FusedAdam
|
from apex.optimizers import FusedAdam
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
raise ImportError(
|
||||||
|
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
||||||
|
|
||||||
optimizer = FusedAdam(optimizer_grouped_parameters,
|
optimizer = FusedAdam(optimizer_grouped_parameters,
|
||||||
lr=args.learning_rate,
|
lr=args.learning_rate,
|
||||||
@@ -293,7 +295,8 @@ def main():
|
|||||||
if args.fp16:
|
if args.fp16:
|
||||||
# modify learning rate with special warm up BERT uses
|
# modify learning rate with special warm up BERT uses
|
||||||
# if args.fp16 is False, BertAdam is used that handles this automatically
|
# if args.fp16 is False, BertAdam is used that handles this automatically
|
||||||
lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
|
lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps,
|
||||||
|
args.warmup_proportion)
|
||||||
for param_group in optimizer.param_groups:
|
for param_group in optimizer.param_groups:
|
||||||
param_group['lr'] = lr_this_step
|
param_group['lr'] = lr_this_step
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|||||||
@@ -269,6 +269,5 @@ def main():
|
|||||||
metrics_file.write(json.dumps(metrics))
|
metrics_file.write(json.dumps(metrics))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user