Fix tf2.4 (#9120)
* Fix tests for TF 2.4 * Remove <2.4 limitation * Add version condition * Update tests/test_optimization_tf.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update tests/test_optimization_tf.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update tests/test_optimization_tf.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,8 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from packaging import version
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.testing_utils import require_tf
|
||||
|
||||
@@ -76,12 +78,18 @@ class OptimizationFTest(unittest.TestCase):
|
||||
local_variables = strategy.experimental_local_results(gradient_placeholder)
|
||||
local_variables[0].assign(grad1)
|
||||
local_variables[1].assign(grad2)
|
||||
strategy.experimental_run_v2(accumulate_on_replica, args=(gradient_placeholder,))
|
||||
if version.parse(tf.version.VERSION) >= version.parse("2.2"):
|
||||
strategy.run(accumulate_on_replica, args=(gradient_placeholder,))
|
||||
else:
|
||||
strategy.experimental_run_v2(accumulate_on_replica, args=(gradient_placeholder,))
|
||||
|
||||
@tf.function
|
||||
def apply_grad():
|
||||
with strategy.scope():
|
||||
strategy.experimental_run_v2(apply_on_replica)
|
||||
if version.parse(tf.version.VERSION) >= version.parse("2.2"):
|
||||
strategy.run(apply_on_replica)
|
||||
else:
|
||||
strategy.experimental_run_v2(apply_on_replica)
|
||||
|
||||
def _check_local_values(grad1, grad2):
|
||||
values = strategy.experimental_local_results(accumulator._gradients[0])
|
||||
|
||||
Reference in New Issue
Block a user