From 8d57c424e036b61f74ffa4c3e154d4d6d22fbff3 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Wed, 6 Apr 2022 15:33:32 +0200 Subject: [PATCH] [FlaxSpeechEncoderDecoderModel] More Rigorous PT-Flax Equivalence Tests (#16589) --- ...st_modeling_flax_speech_encoder_decoder.py | 27 +++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py index 403255c4ce..873c09105b 100644 --- a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py +++ b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py @@ -413,28 +413,22 @@ class FlaxEncoderDecoderMixin: pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - pt_logits = pt_outputs.logits - pt_outputs = pt_outputs.to_tuple() - - fx_outputs = fx_model(**inputs_dict) - fx_logits = fx_outputs.logits - fx_outputs = fx_outputs.to_tuple() + pt_outputs = pt_model(**pt_inputs).to_tuple() + fx_outputs = fx_model(**inputs_dict).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2) + for fx_output, pt_output in zip(fx_outputs, pt_outputs): + self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5) # PT -> Flax with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True) - fx_outputs_loaded = fx_model_loaded(**inputs_dict) - fx_logits_loaded = fx_outputs_loaded.logits - fx_outputs_loaded = fx_outputs_loaded.to_tuple() - + fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple() self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2) + for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5) # Flax -> PT with tempfile.TemporaryDirectory() as tmpdirname: @@ -445,12 +439,11 @@ class FlaxEncoderDecoderMixin: pt_model_loaded.eval() with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs) - pt_logits_loaded = pt_outputs_loaded.logits - pt_outputs_loaded = pt_outputs_loaded.to_tuple() + pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") - self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2) + for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded): + self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5) def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):