added warning
This commit is contained in:
@@ -159,8 +159,8 @@ class BertAdam(Optimizer):
|
||||
# warning for exceeding t_total (only active with warmup_linear
|
||||
if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
|
||||
logger.warning(
|
||||
"Training beyond specified 't_total' steps. Learning rate set to {}. "
|
||||
"Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__))
|
||||
"Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
|
||||
"Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
|
||||
warned_for_t_total = True
|
||||
# end warning
|
||||
else:
|
||||
|
||||
@@ -29,14 +29,14 @@ def warmup_cosine(x, warmup=0.002):
|
||||
return 0.5 * (1.0 + torch.cos(math.pi * x))
|
||||
|
||||
def warmup_constant(x, warmup=0.002):
|
||||
""" Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
|
||||
""" Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps.
|
||||
Learning rate is 1. afterwards. """
|
||||
if x < warmup:
|
||||
return x/warmup
|
||||
return 1.0
|
||||
|
||||
def warmup_linear(x, warmup=0.002):
|
||||
""" Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
|
||||
""" Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step.
|
||||
After `t_total`-th training step, learning rate is zero. """
|
||||
if x < warmup:
|
||||
return x/warmup
|
||||
@@ -142,8 +142,8 @@ class OpenAIAdam(Optimizer):
|
||||
# warning for exceeding t_total (only active with warmup_linear
|
||||
if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
|
||||
logger.warning(
|
||||
"Training beyond specified 't_total' steps. Learning rate set to {}. "
|
||||
"Please set 't_total' of {} correctly.".format(lr_scheduled, self.__class__.__name__))
|
||||
"Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
|
||||
"Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
|
||||
warned_for_t_total = True
|
||||
# end warning
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user