Flax testing should not run the full torch test suite (#10725)
* make flax tests pytorch independent * fix typo * finish * improve circle ci * fix return tensors * correct flax test * re-add sentencepiece * last tokenizer fixes * finish maybe now
This commit is contained in:
committed by
GitHub
parent
87d685b8a9
commit
9f8619c6aa
@@ -35,6 +35,9 @@ def pytest_configure(config):
|
||||
config.addinivalue_line(
|
||||
"markers", "is_pt_tf_cross_test: mark test to run only when PT and TF interactions are tested"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
|
||||
)
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
|
||||
@@ -19,7 +19,7 @@ import numpy as np
|
||||
|
||||
import transformers
|
||||
from transformers import is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import require_flax, require_torch
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -60,7 +60,6 @@ def random_attention_mask(shape, rng=None):
|
||||
return attn_mask
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxModelTesterMixin:
|
||||
model_tester = None
|
||||
all_model_classes = ()
|
||||
@@ -69,7 +68,7 @@ class FlaxModelTesterMixin:
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_pytorch(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -104,6 +103,7 @@ class FlaxModelTesterMixin:
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
|
||||
|
||||
@require_flax
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -121,6 +121,7 @@ class FlaxModelTesterMixin:
|
||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||
self.assert_almost_equals(output_loaded, output, 5e-3)
|
||||
|
||||
@require_flax
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -143,6 +144,7 @@ class FlaxModelTesterMixin:
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
@require_flax
|
||||
def test_naming_convention(self):
|
||||
for model_class in self.all_model_classes:
|
||||
model_class_name = model_class.__name__
|
||||
|
||||
@@ -24,7 +24,13 @@ from collections import OrderedDict
|
||||
from itertools import takewhile
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast, is_torch_available
|
||||
from transformers import (
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerBase,
|
||||
PreTrainedTokenizerFast,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
get_tests_dir,
|
||||
is_pt_tf_cross_test,
|
||||
@@ -2283,7 +2289,12 @@ class TokenizerTesterMixin:
|
||||
"{} ({}, {})".format(tokenizer.__class__.__name__, pretrained_name, tokenizer.__class__.__name__)
|
||||
):
|
||||
|
||||
returned_tensor = "pt" if is_torch_available() else "tf"
|
||||
if is_torch_available():
|
||||
returned_tensor = "pt"
|
||||
elif is_tf_available():
|
||||
returned_tensor = "tf"
|
||||
else:
|
||||
returned_tensor = "jax"
|
||||
|
||||
if not tokenizer.pad_token or tokenizer.pad_token_id < 0:
|
||||
return
|
||||
|
||||
@@ -21,7 +21,7 @@ from pathlib import Path
|
||||
from shutil import copyfile
|
||||
|
||||
from transformers import BatchEncoding, MarianTokenizer
|
||||
from transformers.file_utils import is_sentencepiece_available, is_torch_available
|
||||
from transformers.file_utils import is_sentencepiece_available, is_tf_available, is_torch_available
|
||||
from transformers.testing_utils import require_sentencepiece
|
||||
|
||||
|
||||
@@ -36,7 +36,13 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
|
||||
mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
|
||||
zh_code = ">>zh<<"
|
||||
ORG_NAME = "Helsinki-NLP/"
|
||||
FRAMEWORK = "pt" if is_torch_available() else "tf"
|
||||
|
||||
if is_torch_available():
|
||||
FRAMEWORK = "pt"
|
||||
elif is_tf_available():
|
||||
FRAMEWORK = "tf"
|
||||
else:
|
||||
FRAMEWORK = "jax"
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
|
||||
from transformers.file_utils import cached_property, is_torch_available
|
||||
from transformers.file_utils import cached_property, is_tf_available, is_torch_available
|
||||
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
@@ -25,7 +25,12 @@ from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||
|
||||
FRAMEWORK = "pt" if is_torch_available() else "tf"
|
||||
if is_torch_available():
|
||||
FRAMEWORK = "pt"
|
||||
elif is_tf_available():
|
||||
FRAMEWORK = "tf"
|
||||
else:
|
||||
FRAMEWORK = "jax"
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@@ -157,7 +162,12 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
|
||||
batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
result = list(batch.input_ids.numpy()[0])
|
||||
|
||||
if FRAMEWORK != "jax":
|
||||
result = list(batch.input_ids.numpy()[0])
|
||||
else:
|
||||
result = list(batch.input_ids.tolist()[0])
|
||||
|
||||
self.assertListEqual(expected_src_tokens, result)
|
||||
|
||||
self.assertEqual((2, 9), batch.input_ids.shape)
|
||||
|
||||
Reference in New Issue
Block a user