diff --git a/pytorch_pretrained_bert/optimization.py b/pytorch_pretrained_bert/optimization.py index cea35c39e9..84f329feae 100644 --- a/pytorch_pretrained_bert/optimization.py +++ b/pytorch_pretrained_bert/optimization.py @@ -118,7 +118,7 @@ class BertAdam(Optimizer): max_grad_norm=1.0): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) - if schedule not in SCHEDULES: + if not isinstance(schedule, LRSchedule) and schedule not in SCHEDULES: raise ValueError("Invalid schedule parameter: {}".format(schedule)) if not 0.0 <= b1 < 1.0: raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))