From b25b92ac4ffb8b66fb517f8888cbdc37075a9fd7 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Fri, 18 Mar 2022 16:45:39 +0100 Subject: [PATCH] update jax version and re-enable some tests (#16254) --- setup.py | 2 +- src/transformers/dependency_versions_table.py | 2 +- .../test_modeling_flax_speech_encoder_decoder.py | 8 -------- 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/setup.py b/setup.py index 343bea3acf..d398d59618 100644 --- a/setup.py +++ b/setup.py @@ -112,7 +112,7 @@ _deps = [ "importlib_metadata", "ipadic>=1.0.0,<2.0", "isort>=5.5.4", - "jax>=0.2.8", + "jax>=0.2.8,!=0.3.2", "jaxlib>=0.1.65", "jieba", "nltk", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 1ffaa15036..ee9b0fe28e 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -22,7 +22,7 @@ deps = { "importlib_metadata": "importlib_metadata", "ipadic": "ipadic>=1.0.0,<2.0", "isort": "isort>=5.5.4", - "jax": "jax>=0.2.8", + "jax": "jax>=0.2.8,!=0.3.2", "jaxlib": "jaxlib>=0.1.65", "jieba": "jieba", "nltk": "nltk", diff --git a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py index d35c0065f6..4ceea974f3 100644 --- a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py +++ b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py @@ -691,10 +691,6 @@ class FlaxWav2Vec2GPT2ModelTest(FlaxEncoderDecoderMixin, unittest.TestCase): self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2) - @unittest.skip("Re-enable this test once this issue is fixed: https://github.com/google/jax/issues/9941") - def test_encoder_decoder_model_from_encoder_decoder_pretrained(self): - pass - @require_flax class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase): @@ -811,7 +807,3 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase): self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2) - - @unittest.skip("Re-enable this test once this issue is fixed: https://github.com/google/jax/issues/9941") - def test_encoder_decoder_model_from_encoder_decoder_pretrained(self): - pass