Feature/fix slow test in mluke (#14749)
* make MLukeTokenizerTest fast * make LukeTokenizerTest fast * add entry to _toctree.yaml
This commit is contained in:
1
tests/fixtures/test_entity_vocab.json
vendored
Normal file
1
tests/fixtures/test_entity_vocab.json
vendored
Normal file
@@ -0,0 +1 @@
|
||||
{"[MASK]": 0, "[UNK]": 1, "[PAD]": 2, "DUMMY": 3, "DUMMY2": 4, "[MASK2]": 5}
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from typing import Tuple
|
||||
|
||||
@@ -23,6 +23,11 @@ from transformers.testing_utils import require_torch, slow
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
|
||||
SAMPLE_MERGE_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/merges.txt")
|
||||
SAMPLE_ENTITY_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_entity_vocab.json")
|
||||
|
||||
|
||||
class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer_class = LukeTokenizer
|
||||
test_rust_tokenizer = False
|
||||
@@ -35,7 +40,15 @@ class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
def get_tokenizer(self, task=None, **kwargs):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return self.tokenizer_class.from_pretrained("studio-ousia/luke-base", task=task, **kwargs)
|
||||
tokenizer = LukeTokenizer(
|
||||
vocab_file=SAMPLE_VOCAB,
|
||||
merges_file=SAMPLE_MERGE_FILE,
|
||||
entity_vocab_file=SAMPLE_ENTITY_VOCAB,
|
||||
task=task,
|
||||
**kwargs,
|
||||
)
|
||||
tokenizer.sanitize_special_tokens()
|
||||
return tokenizer
|
||||
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "lower newer"
|
||||
@@ -43,25 +56,16 @@ class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
return input_text, output_text
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/luke-base")
|
||||
tokenizer = self.get_tokenizer()
|
||||
text = "lower newer"
|
||||
bpe_tokens = ["lower", "\u0120newer"]
|
||||
bpe_tokens = ["l", "o", "w", "er", "Ġ", "n", "e", "w", "er"]
|
||||
tokens = tokenizer.tokenize(text) # , add_prefix_space=True)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
|
||||
input_tokens = tokens + [tokenizer.unk_token]
|
||||
input_bpe_tokens = [29668, 13964, 3]
|
||||
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
def luke_dict_integration_testing(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
self.assertListEqual(tokenizer.encode("Hello world!", add_special_tokens=False), [0, 31414, 232, 328, 2])
|
||||
self.assertListEqual(
|
||||
tokenizer.encode("Hello world! cécé herlolip 418", add_special_tokens=False),
|
||||
[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_sequence_builders(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/luke-large")
|
||||
@@ -235,6 +239,7 @@ class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer(sentence, entity_spans=[0, 0, 0])
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
class LukeTokenizerIntegrationTests(unittest.TestCase):
|
||||
tokenizer_class = LukeTokenizer
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from typing import Tuple
|
||||
|
||||
@@ -23,7 +24,10 @@ from transformers.testing_utils import require_torch, slow
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
@slow
|
||||
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
|
||||
SAMPLE_ENTITY_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_entity_vocab.json")
|
||||
|
||||
|
||||
class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer_class = MLukeTokenizer
|
||||
test_rust_tokenizer = False
|
||||
@@ -37,7 +41,9 @@ class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def get_tokenizer(self, task=None, **kwargs):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
kwargs.update({"task": task})
|
||||
return self.tokenizer_class.from_pretrained("studio-ousia/mluke-base", **kwargs)
|
||||
tokenizer = MLukeTokenizer(vocab_file=SAMPLE_VOCAB, entity_vocab_file=SAMPLE_ENTITY_VOCAB, **kwargs)
|
||||
tokenizer.sanitize_special_tokens()
|
||||
return tokenizer
|
||||
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "lower newer"
|
||||
@@ -45,14 +51,14 @@ class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
return input_text, output_text
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/mluke-base")
|
||||
tokenizer = self.get_tokenizer()
|
||||
text = "lower newer"
|
||||
spm_tokens = ["▁lower", "▁new", "er"]
|
||||
spm_tokens = ["▁l", "ow", "er", "▁new", "er"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, spm_tokens)
|
||||
|
||||
input_tokens = tokens + [tokenizer.unk_token]
|
||||
input_spm_tokens = [92319, 3525, 56, 3]
|
||||
input_spm_tokens = [149, 116, 40, 410, 40] + [3]
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_spm_tokens)
|
||||
|
||||
def mluke_dict_integration_testing(self):
|
||||
@@ -140,7 +146,7 @@ class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
sentence = "ISO 639-3 uses the code fas for the dialects spoken across Iran and Afghanistan."
|
||||
entities = ["en:ISO 639-3"]
|
||||
entities = ["DUMMY"]
|
||||
spans = [(0, 9)]
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
|
||||
Reference in New Issue
Block a user