diff --git a/optimization.py b/optimization.py index e04f01b1d2..bce0a9bf2a 100644 --- a/optimization.py +++ b/optimization.py @@ -90,27 +90,6 @@ class BERTAdam(Optimizer): lr.append(lr_scheduled) return lr - def to(self, device): - """ Move the optimizer state to a specified device""" - for state in self.state.values(): - state['exp_avg'].to(device) - state['exp_avg_sq'].to(device) - - def initialize_step(self, initial_step): - """Initialize state with a defined step (but we don't have stored averaged). - Arguments: - initial_step (int): Initial step number. - """ - for group in self.param_groups: - for p in group['params']: - state = self.state[p] - # State initialization - state['step'] = initial_step - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data) - def step(self, closure=None): """Performs a single optimization step.