💄 super

This commit is contained in:
Julien Chaumond
2020-01-15 18:33:50 -05:00
parent cd51893d37
commit 83a41d39b3
75 changed files with 328 additions and 328 deletions

View File

@@ -24,7 +24,7 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Applys a warmup schedule on a given learning rate decay schedule."""
def __init__(self, initial_learning_rate, decay_schedule_fn, warmup_steps, power=1.0, name=None):
super(WarmUp, self).__init__()
super().__init__()
self.initial_learning_rate = initial_learning_rate
self.warmup_steps = warmup_steps
self.power = power
@@ -102,7 +102,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
name="AdamWeightDecay",
**kwargs
):
super(AdamWeightDecay, self).__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
super().__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
self.weight_decay_rate = weight_decay_rate
self._include_in_weight_decay = include_in_weight_decay
self._exclude_from_weight_decay = exclude_from_weight_decay
@@ -111,10 +111,10 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
def from_config(cls, config):
"""Creates an optimizer from its config with WarmUp custom object."""
custom_objects = {"WarmUp": WarmUp}
return super(AdamWeightDecay, cls).from_config(config, custom_objects=custom_objects)
return super().from_config(config, custom_objects=custom_objects)
def _prepare_local(self, var_device, var_dtype, apply_state):
super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, apply_state)
super()._prepare_local(var_device, var_dtype, apply_state)
apply_state["weight_decay_rate"] = tf.constant(self.weight_decay_rate, name="adam_weight_decay_rate")
def _decay_weights_op(self, var, learning_rate, apply_state):
@@ -128,7 +128,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
def apply_gradients(self, grads_and_vars, clip_norm, name=None):
grads, tvars = list(zip(*grads_and_vars))
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars))
return super().apply_gradients(zip(grads, tvars))
def _get_lr(self, var_device, var_dtype, apply_state):
"""Retrieves the learning rate with the given state."""
@@ -147,16 +147,16 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
return super(AdamWeightDecay, self)._resource_apply_dense(grad, var, **kwargs)
return super()._resource_apply_dense(grad, var, **kwargs)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
return super(AdamWeightDecay, self)._resource_apply_sparse(grad, var, indices, **kwargs)
return super()._resource_apply_sparse(grad, var, indices, **kwargs)
def get_config(self):
config = super(AdamWeightDecay, self).get_config()
config = super().get_config()
config.update({"weight_decay_rate": self.weight_decay_rate})
return config