diff --git a/run_classifier.py b/run_classifier.py index 58eb039d93..2c82bcd4c1 100644 --- a/run_classifier.py +++ b/run_classifier.py @@ -426,7 +426,7 @@ def main(): raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format( args.accumulate_gradients)) - args.train_batch_size = args.train_batch_size / args.accumulate_gradients + args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients) random.seed(args.seed) np.random.seed(args.seed) diff --git a/run_squad.py b/run_squad.py index 868dc99a23..3a961079c9 100644 --- a/run_squad.py +++ b/run_squad.py @@ -756,7 +756,7 @@ def main(): raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format( args.accumulate_gradients)) - args.train_batch_size = args.train_batch_size / args.accumulate_gradients + args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients) random.seed(args.seed) np.random.seed(args.seed)