update jax version and re-enable some tests (#16254)
This commit is contained in:
2
setup.py
2
setup.py
@@ -112,7 +112,7 @@ _deps = [
|
||||
"importlib_metadata",
|
||||
"ipadic>=1.0.0,<2.0",
|
||||
"isort>=5.5.4",
|
||||
"jax>=0.2.8",
|
||||
"jax>=0.2.8,!=0.3.2",
|
||||
"jaxlib>=0.1.65",
|
||||
"jieba",
|
||||
"nltk",
|
||||
|
||||
@@ -22,7 +22,7 @@ deps = {
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"ipadic": "ipadic>=1.0.0,<2.0",
|
||||
"isort": "isort>=5.5.4",
|
||||
"jax": "jax>=0.2.8",
|
||||
"jax": "jax>=0.2.8,!=0.3.2",
|
||||
"jaxlib": "jaxlib>=0.1.65",
|
||||
"jieba": "jieba",
|
||||
"nltk": "nltk",
|
||||
|
||||
@@ -691,10 +691,6 @@ class FlaxWav2Vec2GPT2ModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
||||
|
||||
@unittest.skip("Re-enable this test once this issue is fixed: https://github.com/google/jax/issues/9941")
|
||||
def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
@@ -811,7 +807,3 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
||||
|
||||
@unittest.skip("Re-enable this test once this issue is fixed: https://github.com/google/jax/issues/9941")
|
||||
def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user