diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index df5b50b51d..ca973015a6 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -54,7 +54,7 @@ class LRSchedule(object): """ if self.t_total < 0: return 1. - progress = step / self.t_total + progress = float(step) / self.t_total ret = self.get_lr_(progress) # warning for exceeding t_total (only active with warmup_linear if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: