clean up optimizer from unused functions
This commit is contained in:
@@ -90,27 +90,6 @@ class BERTAdam(Optimizer):
|
|||||||
lr.append(lr_scheduled)
|
lr.append(lr_scheduled)
|
||||||
return lr
|
return lr
|
||||||
|
|
||||||
def to(self, device):
    """Move the optimizer state to a specified device.

    ``.to(device)`` hands back its result instead of modifying the object it
    is called on, so each result is stored back into the per-parameter state
    dict; calling it without keeping the result leaves the state unchanged.

    Arguments:
        device: Target passed straight through to each buffer's ``.to()``.
    """
    for state in self.state.values():
        # Rebind both moving-average buffers to their moved counterparts.
        state['exp_avg'] = state['exp_avg'].to(device)
        state['exp_avg_sq'] = state['exp_avg_sq'].to(device)
|
|
||||||
|
|
||||||
def initialize_step(self, initial_step):
    """Start every parameter's state at a given step with zeroed averages.

    No stored moving averages are available at this point, so both buffers
    are created fresh as zeros shaped like the parameter's data.

    Arguments:
        initial_step (int): Initial step number.
    """
    every_param = (
        param
        for param_group in self.param_groups
        for param in param_group['params']
    )
    for param in every_param:
        entry = self.state[param]
        entry['step'] = initial_step
        # Exponential moving average of gradient values.
        entry['exp_avg'] = torch.zeros_like(param.data)
        # Exponential moving average of squared gradient values.
        entry['exp_avg_sq'] = torch.zeros_like(param.data)
|
|
||||||
|
|
||||||
def step(self, closure=None):
|
def step(self, closure=None):
|
||||||
"""Performs a single optimization step.
|
"""Performs a single optimization step.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user