Fix learning rate schedule when using gradient_accumulation_steps

This commit is contained in:
thomwolf
2018-11-10 16:11:14 +01:00
parent ea85cca8ab
commit a81a1ef8e9
2 changed files with 18 additions and 2 deletions

View File

@@ -464,7 +464,7 @@ def main():
if args.do_train:
train_examples = processor.get_train_examples(args.data_dir)
num_train_steps = int(
- len(train_examples) / args.train_batch_size * args.num_train_epochs)
+ len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
model = BertForSequenceClassification(bert_config, len(label_list))
if args.init_checkpoint is not None: