TF version of the trainer (#4017)
* First commit to add a TF version of the trainer. * Make the TF trainer closer to what looks the PT trainer * Refactoring common code between the PT and TF trainer into an util file. * Some bugfix + better similarity with the PT trainer * Add missing class in transformers init * Bugfix over prediction + use classification report instead of simple metrics * Fix name error * Fix optimization tests + style * Apply style * Several bugfix for multi-gpu training * Apply style * Apply style * Add glue example for the TF trainer * Several bugix + address the reviews * Fix on the TF training args file * Add a debug mode * Bugfix in utils_ner.py when segment_ids is None * Apply style * Apply style * Add TPU strategy * Fix selection strategy
This commit is contained in:
@@ -36,14 +36,13 @@ class OptimizationFTest(unittest.TestCase):
|
||||
def testGradientAccumulatorDistributionStrategy(self):
|
||||
context._context = None
|
||||
ops.enable_eager_execution_internal()
|
||||
physical_devices = tf.config.experimental.list_physical_devices("CPU")
|
||||
tf.config.experimental.set_virtual_device_configuration(
|
||||
physical_devices[0],
|
||||
[tf.config.experimental.VirtualDeviceConfiguration(), tf.config.experimental.VirtualDeviceConfiguration()],
|
||||
)
|
||||
|
||||
devices = tf.config.experimental.list_logical_devices(device_type="CPU")
|
||||
strategy = tf.distribute.MirroredStrategy(devices=[device.name for device in devices])
|
||||
physical_devices = tf.config.list_physical_devices("CPU")
|
||||
if len(physical_devices) == 1:
|
||||
tf.config.set_logical_device_configuration(
|
||||
physical_devices[0], [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()]
|
||||
)
|
||||
devices = tf.config.list_logical_devices(device_type="CPU")
|
||||
strategy = tf.distribute.MirroredStrategy(devices=devices[:2])
|
||||
|
||||
with strategy.scope():
|
||||
accumulator = GradientAccumulator()
|
||||
@@ -55,13 +54,14 @@ class OptimizationFTest(unittest.TestCase):
|
||||
accumulator([gradient])
|
||||
|
||||
def apply_on_replica():
|
||||
optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])), 1.0)
|
||||
optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])))
|
||||
|
||||
@tf.function
|
||||
def accumulate(grad1, grad2):
|
||||
with strategy.scope():
|
||||
gradient_placeholder.values[0].assign(grad1)
|
||||
gradient_placeholder.values[1].assign(grad2)
|
||||
local_variables = strategy.experimental_local_results(gradient_placeholder)
|
||||
local_variables[0].assign(grad1)
|
||||
local_variables[1].assign(grad2)
|
||||
strategy.experimental_run_v2(accumulate_on_replica, args=(gradient_placeholder,))
|
||||
|
||||
@tf.function
|
||||
@@ -69,15 +69,18 @@ class OptimizationFTest(unittest.TestCase):
|
||||
with strategy.scope():
|
||||
strategy.experimental_run_v2(apply_on_replica)
|
||||
|
||||
def _check_local_values(grad1, grad2):
|
||||
values = strategy.experimental_local_results(accumulator._gradients[0])
|
||||
self.assertListAlmostEqual(values[0].value(), grad1, tol=1e-2)
|
||||
self.assertListAlmostEqual(values[1].value(), grad2, tol=1e-2)
|
||||
|
||||
accumulate([1.0, 2.0], [-1.0, 1.0])
|
||||
accumulate([3.0, -1.0], [-1.0, -1.0])
|
||||
accumulate([-2.0, 2.0], [3.0, -2.0])
|
||||
self.assertEqual(accumulator.step, 3)
|
||||
self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [2.0, 3.0], tol=1e-2)
|
||||
self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [1.0, -2.0], tol=1e-2)
|
||||
_check_local_values([2.0, 3.0], [1.0, -2.0])
|
||||
apply_grad()
|
||||
self.assertListAlmostEqual(variable.value().numpy().tolist(), [4.0, 3.0], tol=1e-2)
|
||||
self.assertListAlmostEqual(variable.value(), [4.0, 3.0], tol=1e-2)
|
||||
accumulator.reset()
|
||||
self.assertEqual(accumulator.step, 0)
|
||||
self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)
|
||||
self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)
|
||||
_check_local_values([0.0, 0.0], [0.0, 0.0])
|
||||
|
||||
Reference in New Issue
Block a user