[FlaxSpeechEncoderDecoderModel] Skip from_encoder_decoder_pretrained (#16236)
* skip the test * fix * fix skip
This commit is contained in:
@@ -691,6 +691,10 @@ class FlaxWav2Vec2GPT2ModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
|||||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
||||||
|
|
||||||
|
@unittest.skip("Re-enable this test once this issue is fixed: https://github.com/google/jax/issues/9941")
|
||||||
|
def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||||
@@ -807,3 +811,7 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
||||||
|
|
||||||
|
@unittest.skip("Re-enable this test once this issue is fixed: https://github.com/google/jax/issues/9941")
|
||||||
|
def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
|
||||||
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user