make examples consistent, revert error in num_train_steps calculation
This commit is contained in:
@@ -757,7 +757,7 @@ def main():
|
||||
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
||||
args.gradient_accumulation_steps))
|
||||
|
||||
args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
|
||||
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
@@ -788,8 +788,8 @@ def main():
|
||||
if args.do_train:
|
||||
train_examples = read_squad_examples(
|
||||
input_file=args.train_file, is_training=True)
|
||||
num_train_steps = int(
|
||||
len(train_examples) / args.train_batch_size * args.num_train_epochs)
|
||||
num_train_steps =
|
||||
len(train_examples) // args.train_batch_size // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare model
|
||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model,
|
||||
|
||||
Reference in New Issue
Block a user