Allow for None gradients in GradientAccumulator. (#4372)
This commit is contained in:
@@ -217,7 +217,7 @@ class GradientAccumulator(object):
|
|||||||
"""The accumulated gradients on the current replica."""
|
"""The accumulated gradients on the current replica."""
|
||||||
if not self._gradients:
|
if not self._gradients:
|
||||||
raise ValueError("The accumulator should be called first to initialize the gradients")
|
raise ValueError("The accumulator should be called first to initialize the gradients")
|
||||||
return list(gradient.value() for gradient in self._gradients)
|
return list(gradient.value() if gradient is not None else gradient for gradient in self._gradients)
|
||||||
|
|
||||||
def __call__(self, gradients):
|
def __call__(self, gradients):
|
||||||
"""Accumulates :obj:`gradients` on the current replica."""
|
"""Accumulates :obj:`gradients` on the current replica."""
|
||||||
@@ -231,6 +231,8 @@ class GradientAccumulator(object):
|
|||||||
synchronization=tf.VariableSynchronization.ON_READ,
|
synchronization=tf.VariableSynchronization.ON_READ,
|
||||||
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
|
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
|
||||||
)
|
)
|
||||||
|
if gradient is not None
|
||||||
|
else gradient
|
||||||
for gradient in gradients
|
for gradient in gradients
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -238,6 +240,7 @@ class GradientAccumulator(object):
|
|||||||
raise ValueError("Expected %s gradients, but got %d" % (len(self._gradients), len(gradients)))
|
raise ValueError("Expected %s gradients, but got %d" % (len(self._gradients), len(gradients)))
|
||||||
|
|
||||||
for accum_gradient, gradient in zip(self._gradients, gradients):
|
for accum_gradient, gradient in zip(self._gradients, gradients):
|
||||||
|
if accum_gradient is not None and gradient is not None:
|
||||||
accum_gradient.assign_add(gradient)
|
accum_gradient.assign_add(gradient)
|
||||||
|
|
||||||
self._accum_steps.assign_add(1)
|
self._accum_steps.assign_add(1)
|
||||||
@@ -248,4 +251,5 @@ class GradientAccumulator(object):
|
|||||||
return
|
return
|
||||||
self._accum_steps.assign(0)
|
self._accum_steps.assign(0)
|
||||||
for gradient in self._gradients:
|
for gradient in self._gradients:
|
||||||
|
if gradient is not None:
|
||||||
gradient.assign(tf.zeros_like(gradient))
|
gradient.assign(tf.zeros_like(gradient))
|
||||||
|
|||||||
Reference in New Issue
Block a user