Flax testing should not run the full torch test suite (#10725)
* make flax tests pytorch independent * fix typo * finish * improve circle ci * fix return tensors * correct flax test * re-add sentencepiece * last tokenizer fixes * finish maybe now
This commit is contained in:
committed by
GitHub
parent
87d685b8a9
commit
9f8619c6aa
@@ -19,7 +19,7 @@ import numpy as np
|
||||
|
||||
import transformers
|
||||
from transformers import is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import require_flax, require_torch
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -60,7 +60,6 @@ def random_attention_mask(shape, rng=None):
|
||||
return attn_mask
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxModelTesterMixin:
|
||||
model_tester = None
|
||||
all_model_classes = ()
|
||||
@@ -69,7 +68,7 @@ class FlaxModelTesterMixin:
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_pytorch(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -104,6 +103,7 @@ class FlaxModelTesterMixin:
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
|
||||
|
||||
@require_flax
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -121,6 +121,7 @@ class FlaxModelTesterMixin:
|
||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||
self.assert_almost_equals(output_loaded, output, 5e-3)
|
||||
|
||||
@require_flax
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -143,6 +144,7 @@ class FlaxModelTesterMixin:
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
@require_flax
|
||||
def test_naming_convention(self):
|
||||
for model_class in self.all_model_classes:
|
||||
model_class_name = model_class.__name__
|
||||
|
||||
Reference in New Issue
Block a user