Flax testing should not run the full torch test suite (#10725)

* make flax tests pytorch independent

* fix typo

* finish

* improve circle ci

* fix return tensors

* correct flax test

* re-add sentencepiece

* last tokenizer fixes

* finish maybe now
This commit is contained in:
Patrick von Platen
2021-03-16 08:05:37 +03:00
committed by GitHub
parent 87d685b8a9
commit 9f8619c6aa
9 changed files with 94 additions and 14 deletions

View File

@@ -35,6 +35,9 @@ def pytest_configure(config):
config.addinivalue_line(
"markers", "is_pt_tf_cross_test: mark test to run only when PT and TF interactions are tested"
)
config.addinivalue_line(
"markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
)
def pytest_addoption(parser):

View File

@@ -19,7 +19,7 @@ import numpy as np
import transformers
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import require_flax, require_torch
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
if is_flax_available():
@@ -60,7 +60,6 @@ def random_attention_mask(shape, rng=None):
return attn_mask
@require_flax
class FlaxModelTesterMixin:
model_tester = None
all_model_classes = ()
@@ -69,7 +68,7 @@ class FlaxModelTesterMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@require_torch
@is_pt_flax_cross_test
def test_equivalence_flax_pytorch(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -104,6 +103,7 @@ class FlaxModelTesterMixin:
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
@require_flax
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -121,6 +121,7 @@ class FlaxModelTesterMixin:
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 5e-3)
@require_flax
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -143,6 +144,7 @@ class FlaxModelTesterMixin:
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
@require_flax
def test_naming_convention(self):
for model_class in self.all_model_classes:
model_class_name = model_class.__name__

View File

@@ -24,7 +24,13 @@ from collections import OrderedDict
from itertools import takewhile
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast, is_torch_available
from transformers import (
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
is_tf_available,
is_torch_available,
)
from transformers.testing_utils import (
get_tests_dir,
is_pt_tf_cross_test,
@@ -2283,7 +2289,12 @@ class TokenizerTesterMixin:
"{} ({}, {})".format(tokenizer.__class__.__name__, pretrained_name, tokenizer.__class__.__name__)
):
returned_tensor = "pt" if is_torch_available() else "tf"
if is_torch_available():
returned_tensor = "pt"
elif is_tf_available():
returned_tensor = "tf"
else:
returned_tensor = "jax"
if not tokenizer.pad_token or tokenizer.pad_token_id < 0:
return

View File

@@ -21,7 +21,7 @@ from pathlib import Path
from shutil import copyfile
from transformers import BatchEncoding, MarianTokenizer
from transformers.file_utils import is_sentencepiece_available, is_torch_available
from transformers.file_utils import is_sentencepiece_available, is_tf_available, is_torch_available
from transformers.testing_utils import require_sentencepiece
@@ -36,7 +36,13 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
zh_code = ">>zh<<"
ORG_NAME = "Helsinki-NLP/"
FRAMEWORK = "pt" if is_torch_available() else "tf"
if is_torch_available():
FRAMEWORK = "pt"
elif is_tf_available():
FRAMEWORK = "tf"
else:
FRAMEWORK = "jax"
@require_sentencepiece

View File

@@ -17,7 +17,7 @@
import unittest
from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
from transformers.file_utils import cached_property, is_torch_available
from transformers.file_utils import cached_property, is_tf_available, is_torch_available
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
from .test_tokenization_common import TokenizerTesterMixin
@@ -25,7 +25,12 @@ from .test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
FRAMEWORK = "pt" if is_torch_available() else "tf"
if is_torch_available():
FRAMEWORK = "pt"
elif is_tf_available():
FRAMEWORK = "tf"
else:
FRAMEWORK = "jax"
@require_sentencepiece
@@ -157,7 +162,12 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
self.assertIsInstance(batch, BatchEncoding)
result = list(batch.input_ids.numpy()[0])
if FRAMEWORK != "jax":
result = list(batch.input_ids.numpy()[0])
else:
result = list(batch.input_ids.tolist()[0])
self.assertListEqual(expected_src_tokens, result)
self.assertEqual((2, 9), batch.input_ids.shape)