From 632ff3c39ebdbbad284b91bf9ccc5c7e365f0e03 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Thu, 17 Mar 2022 20:05:14 +0100 Subject: [PATCH] [FlaxSpeechEncoderDecoderModel] Skip from_encoder_decoder_pretrained (#16236) * skip the test * fix * fix skip --- .../test_modeling_flax_speech_encoder_decoder.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py index 4ceea974f3..d35c0065f6 100644 --- a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py +++ b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py @@ -691,6 +691,10 @@ class FlaxWav2Vec2GPT2ModelTest(FlaxEncoderDecoderMixin, unittest.TestCase): self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2) + @unittest.skip("Re-enable this test once this issue is fixed: https://github.com/google/jax/issues/9941") + def test_encoder_decoder_model_from_encoder_decoder_pretrained(self): + pass + @require_flax class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase): @@ -807,3 +811,7 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase): self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2) + + @unittest.skip("Re-enable this test once this issue is fixed: https://github.com/google/jax/issues/9941") + def test_encoder_decoder_model_from_encoder_decoder_pretrained(self): + pass