Feature/fix slow test in mluke (#14749)

* make MLukeTokenizerTest fast

* make LukeTokenizerTest fast

* add entry to _toctree.yaml
This commit is contained in:
Ryokan RI
2021-12-22 20:35:59 +09:00
committed by GitHub
parent c94c1b8967
commit 824fd44fc3
6 changed files with 92 additions and 44 deletions

1
tests/fixtures/test_entity_vocab.json vendored Normal file
View File

@@ -0,0 +1 @@
{"[MASK]": 0, "[UNK]": 1, "[PAD]": 2, "DUMMY": 3, "DUMMY2": 4, "[MASK2]": 5}

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from typing import Tuple
@@ -23,6 +23,11 @@ from transformers.testing_utils import require_torch, slow
from .test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
SAMPLE_MERGE_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/merges.txt")
SAMPLE_ENTITY_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_entity_vocab.json")
class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = LukeTokenizer
test_rust_tokenizer = False
@@ -35,7 +40,15 @@ class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
def get_tokenizer(self, task=None, **kwargs):
kwargs.update(self.special_tokens_map)
return self.tokenizer_class.from_pretrained("studio-ousia/luke-base", task=task, **kwargs)
tokenizer = LukeTokenizer(
vocab_file=SAMPLE_VOCAB,
merges_file=SAMPLE_MERGE_FILE,
entity_vocab_file=SAMPLE_ENTITY_VOCAB,
task=task,
**kwargs,
)
tokenizer.sanitize_special_tokens()
return tokenizer
def get_input_output_texts(self, tokenizer):
input_text = "lower newer"
@@ -43,25 +56,16 @@ class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/luke-base")
tokenizer = self.get_tokenizer()
text = "lower newer"
bpe_tokens = ["lower", "\u0120newer"]
bpe_tokens = ["l", "o", "w", "er", "Ġ", "n", "e", "w", "er"]
tokens = tokenizer.tokenize(text) # , add_prefix_space=True)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [29668, 13964, 3]
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
def luke_dict_integration_testing(self):
tokenizer = self.get_tokenizer()
self.assertListEqual(tokenizer.encode("Hello world!", add_special_tokens=False), [0, 31414, 232, 328, 2])
self.assertListEqual(
tokenizer.encode("Hello world! cécé herlolip 418", add_special_tokens=False),
[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2],
)
@slow
def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/luke-large")
@@ -235,6 +239,7 @@ class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer(sentence, entity_spans=[0, 0, 0])
@slow
@require_torch
class LukeTokenizerIntegrationTests(unittest.TestCase):
tokenizer_class = LukeTokenizer

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import os
import unittest
from typing import Tuple
@@ -23,7 +24,10 @@ from transformers.testing_utils import require_torch, slow
from .test_tokenization_common import TokenizerTesterMixin
@slow
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
SAMPLE_ENTITY_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_entity_vocab.json")
class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = MLukeTokenizer
test_rust_tokenizer = False
@@ -37,7 +41,9 @@ class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
def get_tokenizer(self, task=None, **kwargs):
kwargs.update(self.special_tokens_map)
kwargs.update({"task": task})
return self.tokenizer_class.from_pretrained("studio-ousia/mluke-base", **kwargs)
tokenizer = MLukeTokenizer(vocab_file=SAMPLE_VOCAB, entity_vocab_file=SAMPLE_ENTITY_VOCAB, **kwargs)
tokenizer.sanitize_special_tokens()
return tokenizer
def get_input_output_texts(self, tokenizer):
input_text = "lower newer"
@@ -45,14 +51,14 @@ class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/mluke-base")
tokenizer = self.get_tokenizer()
text = "lower newer"
spm_tokens = ["▁lower", "▁new", "er"]
spm_tokens = ["▁l", "ow", "er", "▁new", "er"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, spm_tokens)
input_tokens = tokens + [tokenizer.unk_token]
input_spm_tokens = [92319, 3525, 56, 3]
input_spm_tokens = [149, 116, 40, 410, 40] + [3]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_spm_tokens)
def mluke_dict_integration_testing(self):
@@ -140,7 +146,7 @@ class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = self.get_tokenizer()
sentence = "ISO 639-3 uses the code fas for the dialects spoken across Iran and Afghanistan."
entities = ["en:ISO 639-3"]
entities = ["DUMMY"]
spans = [(0, 9)]
with self.assertRaises(ValueError):