Flax testing should not run the full torch test suite (#10725)

* make flax tests pytorch independent

* fix typo

* finish

* improve circle ci

* fix return tensors

* correct flax test

* re-add sentencepiece

* last tokenizer fixes

* finish maybe now
This commit is contained in:
Patrick von Platen
2021-03-16 08:05:37 +03:00
committed by GitHub
parent 87d685b8a9
commit 9f8619c6aa
9 changed files with 94 additions and 14 deletions

View File

@@ -19,7 +19,7 @@ import numpy as np
import transformers
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import require_flax, require_torch
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
if is_flax_available():
@@ -60,7 +60,6 @@ def random_attention_mask(shape, rng=None):
return attn_mask
@require_flax
class FlaxModelTesterMixin:
model_tester = None
all_model_classes = ()
@@ -69,7 +68,7 @@ class FlaxModelTesterMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@require_torch
@is_pt_flax_cross_test
def test_equivalence_flax_pytorch(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -104,6 +103,7 @@ class FlaxModelTesterMixin:
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
@require_flax
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -121,6 +121,7 @@ class FlaxModelTesterMixin:
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 5e-3)
@require_flax
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -143,6 +144,7 @@ class FlaxModelTesterMixin:
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
@require_flax
def test_naming_convention(self):
for model_class in self.all_model_classes:
model_class_name = model_class.__name__