Allow for None gradients in GradientAccumulator. (#4372)

This commit is contained in:
Jared T Nielsen
2020-05-15 07:52:00 -06:00
committed by GitHub
parent edf9ac11d4
commit 34706ba050

View File

@@ -217,7 +217,7 @@ class GradientAccumulator(object):
"""The accumulated gradients on the current replica.""" """The accumulated gradients on the current replica."""
if not self._gradients: if not self._gradients:
raise ValueError("The accumulator should be called first to initialize the gradients") raise ValueError("The accumulator should be called first to initialize the gradients")
return list(gradient.value() for gradient in self._gradients) return list(gradient.value() if gradient is not None else gradient for gradient in self._gradients)
def __call__(self, gradients): def __call__(self, gradients):
"""Accumulates :obj:`gradients` on the current replica.""" """Accumulates :obj:`gradients` on the current replica."""
@@ -231,6 +231,8 @@ class GradientAccumulator(object):
synchronization=tf.VariableSynchronization.ON_READ, synchronization=tf.VariableSynchronization.ON_READ,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
) )
if gradient is not None
else gradient
for gradient in gradients for gradient in gradients
] ]
) )
@@ -238,7 +240,8 @@ class GradientAccumulator(object):
raise ValueError("Expected %s gradients, but got %d" % (len(self._gradients), len(gradients))) raise ValueError("Expected %s gradients, but got %d" % (len(self._gradients), len(gradients)))
for accum_gradient, gradient in zip(self._gradients, gradients): for accum_gradient, gradient in zip(self._gradients, gradients):
accum_gradient.assign_add(gradient) if accum_gradient is not None and gradient is not None:
accum_gradient.assign_add(gradient)
self._accum_steps.assign_add(1) self._accum_steps.assign_add(1)
@@ -248,4 +251,5 @@ class GradientAccumulator(object):
return return
self._accum_steps.assign(0) self._accum_steps.assign(0)
for gradient in self._gradients: for gradient in self._gradients:
gradient.assign(tf.zeros_like(gradient)) if gradient is not None:
gradient.assign(tf.zeros_like(gradient))