[TF 2.2 compat] use tf.VariableAggregation.ONLY_FIRST_REPLICA (#4283)
* Fix the issue to properly run the accumulator with TF 2.2
* Apply style
* Fix training_args_tf for TF 2.2
* Fix the TF training args when only one GPU is available
* Remove the fixed version of TF in setup.py
This commit is contained in:
@@ -204,7 +204,10 @@ class GradientAccumulator(object):
|
||||
"""Number of accumulated steps."""
|
||||
if self._accum_steps is None:
|
||||
self._accum_steps = tf.Variable(
|
||||
tf.constant(0, dtype=tf.int64), trainable=False, synchronization=tf.VariableSynchronization.ON_READ,
|
||||
tf.constant(0, dtype=tf.int64),
|
||||
trainable=False,
|
||||
synchronization=tf.VariableSynchronization.ON_READ,
|
||||
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
|
||||
)
|
||||
|
||||
return self._accum_steps.value()
|
||||
@@ -223,7 +226,10 @@ class GradientAccumulator(object):
|
||||
self._gradients.extend(
|
||||
[
|
||||
tf.Variable(
|
||||
tf.zeros_like(gradient), trainable=False, synchronization=tf.VariableSynchronization.ON_READ,
|
||||
tf.zeros_like(gradient),
|
||||
trainable=False,
|
||||
synchronization=tf.VariableSynchronization.ON_READ,
|
||||
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
|
||||
)
|
||||
for gradient in gradients
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user