[Flax/JAX] Run jitted tests at every commit (#13090)

* up

* up

* up
This commit is contained in:
Patrick von Platen
2021-08-12 14:49:46 +02:00
committed by GitHub
parent 773d386041
commit 6900dded49
5 changed files with 27 additions and 5 deletions

View File

@@ -34,7 +34,6 @@ from transformers.testing_utils import (
is_pt_flax_cross_test,
is_staging_test,
require_flax,
slow,
)
from transformers.utils import logging
@@ -391,7 +390,6 @@ class FlaxModelTesterMixin:
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@slow
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()