update jax version and re-enable some tests (#16254)

This commit is contained in:
Suraj Patil
2022-03-18 16:45:39 +01:00
committed by GitHub
parent 5709a20416
commit b25b92ac4f
3 changed files with 2 additions and 10 deletions

View File

@@ -691,10 +691,6 @@ class FlaxWav2Vec2GPT2ModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
@unittest.skip("Re-enable this test once this issue is fixed: https://github.com/google/jax/issues/9941")
def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
pass
@require_flax
class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
@@ -811,7 +807,3 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
@unittest.skip("Re-enable this test once this issue is fixed: https://github.com/google/jax/issues/9941")
def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
pass