[Flax/JAX] Run jitted tests at every commit (#13090)

* up

* up

* up
This commit is contained in:
Patrick von Platen
2021-08-12 14:49:46 +02:00
committed by GitHub
parent 773d386041
commit 6900dded49
5 changed files with 27 additions and 5 deletions

View File

@@ -187,7 +187,6 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
expected_arg_names = ["input_values", "attention_mask"]
self.assertListEqual(arg_names[:2], expected_arg_names)
@slow
# overwrite because of `input_values`
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()