Flax testing should not run the full torch test suite (#10725)
* make flax tests pytorch independent * fix typo * finish * improve circle ci * fix return tensors * correct flax test * re-add sentencepiece * last tokenizer fixes * finish maybe now
This commit is contained in:
committed by
GitHub
parent
87d685b8a9
commit
9f8619c6aa
@@ -17,7 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
|
||||
from transformers.file_utils import cached_property, is_torch_available
|
||||
from transformers.file_utils import cached_property, is_tf_available, is_torch_available
|
||||
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
@@ -25,7 +25,12 @@ from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||
|
||||
FRAMEWORK = "pt" if is_torch_available() else "tf"
|
||||
if is_torch_available():
|
||||
FRAMEWORK = "pt"
|
||||
elif is_tf_available():
|
||||
FRAMEWORK = "tf"
|
||||
else:
|
||||
FRAMEWORK = "jax"
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@@ -157,7 +162,12 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
|
||||
batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
result = list(batch.input_ids.numpy()[0])
|
||||
|
||||
if FRAMEWORK != "jax":
|
||||
result = list(batch.input_ids.numpy()[0])
|
||||
else:
|
||||
result = list(batch.input_ids.tolist()[0])
|
||||
|
||||
self.assertListEqual(expected_src_tokens, result)
|
||||
|
||||
self.assertEqual((2, 9), batch.input_ids.shape)
|
||||
|
||||
Reference in New Issue
Block a user