Adafactor: avoid updating group["lr"] attributes (#9751)
This affects Adafactor with relative_step=False and scale_parameter=True. Updating group["lr"] makes the result of ._get_lr() depends on the previous call, i.e., on the scale of other parameters. This isn't supposed to happen.
This commit is contained in:
@@ -546,7 +546,7 @@ class Adafactor(Optimizer):
|
|||||||
|
|
||||||
state["step"] += 1
|
state["step"] += 1
|
||||||
state["RMS"] = self._rms(p_data_fp32)
|
state["RMS"] = self._rms(p_data_fp32)
|
||||||
group["lr"] = self._get_lr(group, state)
|
lr = self._get_lr(group, state)
|
||||||
|
|
||||||
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
|
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
|
||||||
update = (grad ** 2) + group["eps"][0]
|
update = (grad ** 2) + group["eps"][0]
|
||||||
@@ -567,7 +567,7 @@ class Adafactor(Optimizer):
|
|||||||
update = exp_avg_sq.rsqrt().mul_(grad)
|
update = exp_avg_sq.rsqrt().mul_(grad)
|
||||||
|
|
||||||
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
|
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
|
||||||
update.mul_(group["lr"])
|
update.mul_(lr)
|
||||||
|
|
||||||
if use_first_moment:
|
if use_first_moment:
|
||||||
exp_avg = state["exp_avg"]
|
exp_avg = state["exp_avg"]
|
||||||
@@ -575,7 +575,7 @@ class Adafactor(Optimizer):
|
|||||||
update = exp_avg
|
update = exp_avg
|
||||||
|
|
||||||
if group["weight_decay"] != 0:
|
if group["weight_decay"] != 0:
|
||||||
p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
|
p_data_fp32.add_(-group["weight_decay"] * lr, p_data_fp32)
|
||||||
|
|
||||||
p_data_fp32.add_(-update)
|
p_data_fp32.add_(-update)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user