[FlaxSpeechEncoderDecoder] Fix feature extractor gradient test (#16407)
This commit is contained in:
@@ -389,14 +389,14 @@ class FlaxEncoderDecoderMixin:
|
|||||||
feature_extractor_grads, feature_extractor_grads_frozen
|
feature_extractor_grads, feature_extractor_grads_frozen
|
||||||
):
|
):
|
||||||
self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
|
self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
|
||||||
self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-10)
|
self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-5)
|
||||||
|
|
||||||
# ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
|
# ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
|
||||||
grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
|
grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
|
||||||
grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)
|
grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)
|
||||||
|
|
||||||
for grad, grad_frozen in zip(grads, grads_frozen):
|
for grad, grad_frozen in zip(grads, grads_frozen):
|
||||||
self.assert_almost_equals(grad, grad_frozen, 1e-10)
|
self.assert_almost_equals(grad, grad_frozen, 1e-5)
|
||||||
|
|
||||||
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
||||||
|
|
||||||
@@ -507,7 +507,7 @@ class FlaxEncoderDecoderMixin:
|
|||||||
self.assertLessEqual(diff, tol, f"Difference between arrays is {diff} (>= {tol}).")
|
self.assertLessEqual(diff, tol, f"Difference between arrays is {diff} (>= {tol}).")
|
||||||
|
|
||||||
def assert_difference(self, a: np.ndarray, b: np.ndarray, tol: float):
|
def assert_difference(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||||
diff = np.abs((a - b)).min()
|
diff = np.abs((a - b)).max()
|
||||||
self.assertGreaterEqual(diff, tol, f"Difference between arrays is {diff} (<= {tol}).")
|
self.assertGreaterEqual(diff, tol, f"Difference between arrays is {diff} (<= {tol}).")
|
||||||
|
|
||||||
@is_pt_flax_cross_test
|
@is_pt_flax_cross_test
|
||||||
|
|||||||
Reference in New Issue
Block a user