Fix Bug in Flax-Speech-Encoder-Decoder Test (#16041)
* Fix Bug in Flax-Speech-Encoder-Decoder Test
* Change thresholds for CPU precision
This commit is contained in:
@@ -303,14 +303,12 @@ class FlaxEncoderDecoderMixin:
|
|||||||
inputs,
|
inputs,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
decoder_attention_mask,
|
|
||||||
freeze_feature_encoder: bool = False,
|
freeze_feature_encoder: bool = False,
|
||||||
):
|
):
|
||||||
outputs_enc_dec = enc_dec_model(
|
outputs_enc_dec = enc_dec_model(
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
|
||||||
freeze_feature_encoder=freeze_feature_encoder,
|
freeze_feature_encoder=freeze_feature_encoder,
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
@@ -323,13 +321,11 @@ class FlaxEncoderDecoderMixin:
|
|||||||
grad_fn = jax.value_and_grad(compute_loss)
|
grad_fn = jax.value_and_grad(compute_loss)
|
||||||
|
|
||||||
# compute the loss and gradients for the unfrozen model
|
# compute the loss and gradients for the unfrozen model
|
||||||
loss, grads = grad_fn(
|
loss, grads = grad_fn(params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=False)
|
||||||
params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# compare to the loss and gradients for the frozen model
|
# compare to the loss and gradients for the frozen model
|
||||||
loss_frozen, grads_frozen = grad_fn(
|
loss_frozen, grads_frozen = grad_fn(
|
||||||
params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=True
|
params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=True
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assert_almost_equals(loss, loss_frozen, 1e-5)
|
self.assert_almost_equals(loss, loss_frozen, 1e-5)
|
||||||
@@ -348,14 +344,14 @@ class FlaxEncoderDecoderMixin:
|
|||||||
feature_extractor_grads, feature_extractor_grads_frozen
|
feature_extractor_grads, feature_extractor_grads_frozen
|
||||||
):
|
):
|
||||||
self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
|
self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
|
||||||
self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-8)
|
self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-10)
|
||||||
|
|
||||||
# ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
|
# ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
|
||||||
grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
|
grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
|
||||||
grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)
|
grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)
|
||||||
|
|
||||||
for grad, grad_frozen in zip(grads, grads_frozen):
|
for grad, grad_frozen in zip(grads, grads_frozen):
|
||||||
self.assert_almost_equals(grad, grad_frozen, 1e-8)
|
self.assert_almost_equals(grad, grad_frozen, 1e-10)
|
||||||
|
|
||||||
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user