💄 super
This commit is contained in:
@@ -24,7 +24,7 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
||||
"""Applys a warmup schedule on a given learning rate decay schedule."""
|
||||
|
||||
def __init__(self, initial_learning_rate, decay_schedule_fn, warmup_steps, power=1.0, name=None):
|
||||
super(WarmUp, self).__init__()
|
||||
super().__init__()
|
||||
self.initial_learning_rate = initial_learning_rate
|
||||
self.warmup_steps = warmup_steps
|
||||
self.power = power
|
||||
@@ -102,7 +102,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
|
||||
name="AdamWeightDecay",
|
||||
**kwargs
|
||||
):
|
||||
super(AdamWeightDecay, self).__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
|
||||
super().__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
|
||||
self.weight_decay_rate = weight_decay_rate
|
||||
self._include_in_weight_decay = include_in_weight_decay
|
||||
self._exclude_from_weight_decay = exclude_from_weight_decay
|
||||
@@ -111,10 +111,10 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
|
||||
def from_config(cls, config):
|
||||
"""Creates an optimizer from its config with WarmUp custom object."""
|
||||
custom_objects = {"WarmUp": WarmUp}
|
||||
return super(AdamWeightDecay, cls).from_config(config, custom_objects=custom_objects)
|
||||
return super().from_config(config, custom_objects=custom_objects)
|
||||
|
||||
def _prepare_local(self, var_device, var_dtype, apply_state):
|
||||
super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, apply_state)
|
||||
super()._prepare_local(var_device, var_dtype, apply_state)
|
||||
apply_state["weight_decay_rate"] = tf.constant(self.weight_decay_rate, name="adam_weight_decay_rate")
|
||||
|
||||
def _decay_weights_op(self, var, learning_rate, apply_state):
|
||||
@@ -128,7 +128,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
|
||||
def apply_gradients(self, grads_and_vars, clip_norm, name=None):
|
||||
grads, tvars = list(zip(*grads_and_vars))
|
||||
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
|
||||
return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars))
|
||||
return super().apply_gradients(zip(grads, tvars))
|
||||
|
||||
def _get_lr(self, var_device, var_dtype, apply_state):
|
||||
"""Retrieves the learning rate with the given state."""
|
||||
@@ -147,16 +147,16 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
|
||||
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
|
||||
decay = self._decay_weights_op(var, lr_t, apply_state)
|
||||
with tf.control_dependencies([decay]):
|
||||
return super(AdamWeightDecay, self)._resource_apply_dense(grad, var, **kwargs)
|
||||
return super()._resource_apply_dense(grad, var, **kwargs)
|
||||
|
||||
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
|
||||
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
|
||||
decay = self._decay_weights_op(var, lr_t, apply_state)
|
||||
with tf.control_dependencies([decay]):
|
||||
return super(AdamWeightDecay, self)._resource_apply_sparse(grad, var, indices, **kwargs)
|
||||
return super()._resource_apply_sparse(grad, var, indices, **kwargs)
|
||||
|
||||
def get_config(self):
|
||||
config = super(AdamWeightDecay, self).get_config()
|
||||
config = super().get_config()
|
||||
config.update({"weight_decay_rate": self.weight_decay_rate})
|
||||
return config
|
||||
|
||||
|
||||
Reference in New Issue
Block a user