From 26f8b2cb1029cee7d355c5c25f253c08107883fe Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 25 Feb 2021 11:42:25 -0500 Subject: [PATCH] Make Barthez tokenizer tests a bit faster (#10399) * Make Barthez tokenizer tests a bit faster * Quality --- tests/test_tokenization_barthez.py | 3 ++- tests/test_tokenization_common.py | 2 +- tests/test_tokenization_mbart.py | 5 ++--- tests/test_tokenization_mbart50.py | 4 +--- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/test_tokenization_barthez.py b/tests/test_tokenization_barthez.py index afb8e48de3..8ff33ac2ad 100644 --- a/tests/test_tokenization_barthez.py +++ b/tests/test_tokenization_barthez.py @@ -33,8 +33,9 @@ class BarthezTokenizationTest(TokenizerTesterMixin, unittest.TestCase): def setUp(self): super().setUp() - tokenizer = BarthezTokenizer.from_pretrained("moussaKam/mbarthez") + tokenizer = BarthezTokenizerFast.from_pretrained("moussaKam/mbarthez") tokenizer.save_pretrained(self.tmpdirname) + tokenizer.save_pretrained(self.tmpdirname, legacy_format=False) self.tokenizer = tokenizer @require_torch diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 761b2ee491..58b4eee3e8 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -238,7 +238,7 @@ class TokenizerTesterMixin: tokenizer = self.get_rust_tokenizer() for parameter_name, parameter in signature.parameters.items(): - if parameter.default != inspect.Parameter.empty: + if parameter.default != inspect.Parameter.empty and parameter_name != "tokenizer_file": self.assertIn(parameter_name, tokenizer.init_kwargs) def test_rust_and_python_full_tokenizers(self): diff --git a/tests/test_tokenization_mbart.py b/tests/test_tokenization_mbart.py index a67c75e1f4..83c2d33b6f 100644 --- a/tests/test_tokenization_mbart.py +++ b/tests/test_tokenization_mbart.py @@ -12,18 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import tempfile import unittest from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available -from transformers.file_utils import is_sentencepiece_available from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch from .test_tokenization_common import TokenizerTesterMixin -if is_sentencepiece_available(): - from .test_tokenization_xlm_roberta import SAMPLE_VOCAB +SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") if is_torch_available(): diff --git a/tests/test_tokenization_mbart50.py b/tests/test_tokenization_mbart50.py index f31d030c93..4c3561a907 100644 --- a/tests/test_tokenization_mbart50.py +++ b/tests/test_tokenization_mbart50.py @@ -17,14 +17,12 @@ import tempfile import unittest from transformers import SPIECE_UNDERLINE, BatchEncoding, MBart50Tokenizer, MBart50TokenizerFast, is_torch_available -from transformers.file_utils import is_sentencepiece_available from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch from .test_tokenization_common import TokenizerTesterMixin -if is_sentencepiece_available(): - SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") +SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") if is_torch_available():