update jax version and re-enable some tests (#16254)
This commit is contained in:
2
setup.py
2
setup.py
@@ -112,7 +112,7 @@ _deps = [
|
|||||||
"importlib_metadata",
|
"importlib_metadata",
|
||||||
"ipadic>=1.0.0,<2.0",
|
"ipadic>=1.0.0,<2.0",
|
||||||
"isort>=5.5.4",
|
"isort>=5.5.4",
|
||||||
"jax>=0.2.8",
|
"jax>=0.2.8,!=0.3.2",
|
||||||
"jaxlib>=0.1.65",
|
"jaxlib>=0.1.65",
|
||||||
"jieba",
|
"jieba",
|
||||||
"nltk",
|
"nltk",
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ deps = {
|
|||||||
"importlib_metadata": "importlib_metadata",
|
"importlib_metadata": "importlib_metadata",
|
||||||
"ipadic": "ipadic>=1.0.0,<2.0",
|
"ipadic": "ipadic>=1.0.0,<2.0",
|
||||||
"isort": "isort>=5.5.4",
|
"isort": "isort>=5.5.4",
|
||||||
"jax": "jax>=0.2.8",
|
"jax": "jax>=0.2.8,!=0.3.2",
|
||||||
"jaxlib": "jaxlib>=0.1.65",
|
"jaxlib": "jaxlib>=0.1.65",
|
||||||
"jieba": "jieba",
|
"jieba": "jieba",
|
||||||
"nltk": "nltk",
|
"nltk": "nltk",
|
||||||
|
|||||||
@@ -691,10 +691,6 @@ class FlaxWav2Vec2GPT2ModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
|||||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
||||||
|
|
||||||
@unittest.skip("Re-enable this test once this issue is fixed: https://github.com/google/jax/issues/9941")
|
|
||||||
def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||||
@@ -811,7 +807,3 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
||||||
|
|
||||||
@unittest.skip("Re-enable this test once this issue is fixed: https://github.com/google/jax/issues/9941")
|
|
||||||
def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
|
|
||||||
pass
|
|
||||||
|
|||||||
Reference in New Issue
Block a user