unified tokenizer api and serialization + tests

This commit is contained in:
thomwolf
2019-07-09 10:25:18 +02:00
parent 3d5f291386
commit b19786985d
34 changed files with 824 additions and 755 deletions

View File

@@ -24,7 +24,7 @@ from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
BertForNextSentencePrediction, BertForPreTraining,
BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification, BertForMultipleChoice)
from pytorch_transformers.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
@@ -267,7 +267,7 @@ class BertModelTest(unittest.TestCase):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)

View File

@@ -413,7 +413,7 @@ class GPTModelTester(object):
def create_and_check_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(self.base_model_class.PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.parent.assertIsNotNone(model)

View File

@@ -26,7 +26,7 @@ import pytest
import torch
from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
from pytorch_transformers.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
@@ -185,7 +185,7 @@ class TransfoXLModelTest(unittest.TestCase):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)

View File

@@ -20,12 +20,12 @@ import unittest
import logging
from pytorch_transformers import PretrainedConfig, PreTrainedModel
from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
class ModelUtilsTest(unittest.TestCase):
def test_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
config = BertConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, PretrainedConfig)

View File

@@ -21,7 +21,7 @@ import shutil
import pytest
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
from pytorch_transformers.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
@@ -251,7 +251,7 @@ class XLMModelTest(unittest.TestCase):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)

View File

@@ -26,7 +26,7 @@ import pytest
import torch
from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
from pytorch_transformers.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
@@ -279,7 +279,7 @@ class XLNetModelTest(unittest.TestCase):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)

View File

@@ -17,14 +17,12 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
from io import open
import shutil
import pytest
from pytorch_transformers.tokenization_bert import (BasicTokenizer,
BertTokenizer,
WordpieceTokenizer,
_is_control, _is_punctuation,
_is_whitespace)
BertTokenizer,
WordpieceTokenizer,
_is_control, _is_punctuation,
_is_whitespace, VOCAB_FILES_NAMES)
from .tokenization_tests_commons import create_and_check_tokenizer_commons
@@ -33,13 +31,15 @@ class TokenizationTest(unittest.TestCase):
def test_full_tokenizer(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing", ","
"##ing", ",", "low", "lowest",
]
with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
vocab_directory = "/tmp/"
vocab_file = os.path.join(vocab_directory, VOCAB_FILES_NAMES['vocab_file'])
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
vocab_file = vocab_writer.name
create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)
create_and_check_tokenizer_commons(self, BertTokenizer, pretrained_model_name_or_path=vocab_directory)
tokenizer = BertTokenizer(vocab_file)
@@ -80,7 +80,7 @@ class TokenizationTest(unittest.TestCase):
vocab = {}
for (i, token) in enumerate(vocab_tokens):
vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab)
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
self.assertListEqual(tokenizer.tokenize(""), [])

View File

@@ -17,8 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
import json
import tempfile
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import create_and_check_tokenizer_commons
@@ -28,31 +29,31 @@ class GPT2TokenizationTest(unittest.TestCase):
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
"lo", "low", "er",
"low", "lowest", "newer", "wider"]
"low", "lowest", "newer", "wider", "<unk>"]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
fp.write(json.dumps(vocab_tokens))
vocab_file = fp.name
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
fp.write("\n".join(merges))
merges_file = fp.name
special_tokens_map = {"unk_token": "<unk>"}
create_and_check_tokenizer_commons(self, GPT2Tokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
with tempfile.TemporaryDirectory() as tmpdirname:
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
with open(vocab_file, "w") as fp:
fp.write(json.dumps(vocab_tokens))
with open(merges_file, "w") as fp:
fp.write("\n".join(merges))
tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
text = "lower"
bpe_tokens = ["low", "er"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map)
input_tokens = tokens + ["<unk>"]
input_bpe_tokens = [13, 12, 16]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
text = "lower"
bpe_tokens = ["low", "er"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
os.remove(vocab_file)
os.remove(merges_file)
input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [13, 12, 17]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
if __name__ == '__main__':

View File

@@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
import json
import shutil
import pytest
import tempfile
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
from.tokenization_tests_commons import create_and_check_tokenizer_commons
@@ -32,31 +31,31 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
"w</w>", "r</w>", "t</w>",
"lo", "low", "er</w>",
"low</w>", "lowest</w>", "newer</w>", "wider</w>"]
"low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
fp.write(json.dumps(vocab_tokens))
vocab_file = fp.name
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
fp.write("\n".join(merges))
merges_file = fp.name
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
with tempfile.TemporaryDirectory() as tmpdirname:
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
with open(vocab_file, "w") as fp:
fp.write(json.dumps(vocab_tokens))
with open(merges_file, "w") as fp:
fp.write("\n".join(merges))
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
os.remove(vocab_file)
os.remove(merges_file)
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname)
text = "lower"
bpe_tokens = ["low", "er</w>"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
input_tokens = tokens + ["<unk>"]
input_bpe_tokens = [14, 15, 20]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
text = "lower"
bpe_tokens = ["low", "er</w>"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + ["<unk>"]
input_bpe_tokens = [14, 15, 20]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
if __name__ == '__main__':

View File

@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import sys
from io import open
import tempfile
if sys.version_info[0] == 3:
unicode = str
@@ -28,22 +29,19 @@ else:
def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
    """Check that a tokenizer encodes identically after a save/reload round trip.

    Args:
        tester: object exposing ``assertListEqual`` (callers pass the running
            ``unittest.TestCase``).
        tokenizer_class: class exposing ``from_pretrained``; the instance it
            returns must expose ``encode``, ``save_pretrained`` and
            ``from_pretrained``.
        *inputs, **kwargs: forwarded unchanged to ``tokenizer_class.from_pretrained``.
    """
    # Load through from_pretrained only: callers hand over a directory (or a
    # pretrained name), not the raw constructor arguments.
    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)

    sample_text = u"He is very happy, UNwant\u00E9d,running"
    before_tokens = tokenizer.encode(sample_text)

    # A private temporary directory replaces the shared "/tmp/" location, so
    # nothing is left behind and concurrent test runs cannot collide.
    with tempfile.TemporaryDirectory() as tmpdirname:
        tokenizer.save_pretrained(tmpdirname)
        tokenizer = tokenizer.from_pretrained(tmpdirname)

        # Encode while the directory still exists, in case the reloaded
        # tokenizer keeps reading from the saved files.
        after_tokens = tokenizer.encode(sample_text)

    tester.assertListEqual(before_tokens, after_tokens)
def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
tokenizer = tokenizer_class(*inputs, **kwargs)
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
text = u"Munich and Berlin are nice cities"
filename = u"/tmp/tokenizer.bin"
@@ -58,8 +56,54 @@ def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs
tester.assertListEqual(subwords, subwords_loaded)
def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
    """Check that added tokens and added special tokens extend a tokenizer.

    The tokenizer is loaded with ``tokenizer_class.from_pretrained(*inputs,
    **kwargs)``. Two plain tokens are added, then an ``eos_token`` and a
    ``pad_token``. After each step ``vocab_size`` must be unchanged while
    ``len(tokenizer)`` grows by the number of tokens added, and the encoded
    ids of the new tokens must lie beyond the base vocabulary.

    Args:
        tester: object exposing the ``unittest.TestCase`` assertion methods
            used below.
        tokenizer_class: tokenizer class under test.
        *inputs, **kwargs: forwarded unchanged to ``from_pretrained``.
    """
    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)

    # Baseline: nothing added yet, so both size measures agree.
    base_vocab_size = tokenizer.vocab_size
    base_total_size = len(tokenizer)
    tester.assertNotEqual(base_vocab_size, 0)
    tester.assertEqual(base_vocab_size, base_total_size)

    # Step 1: plain added tokens.
    plain_tokens = ["aaaaabbbbbb", "cccccccccdddddddd"]
    plain_added = tokenizer.add_tokens(plain_tokens)
    vocab_size_after_plain = tokenizer.vocab_size
    total_size_after_plain = len(tokenizer)

    tester.assertNotEqual(vocab_size_after_plain, 0)
    tester.assertEqual(base_vocab_size, vocab_size_after_plain)
    tester.assertEqual(plain_added, len(plain_tokens))
    tester.assertEqual(total_size_after_plain, base_total_size + len(plain_tokens))

    ids = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
    tester.assertGreaterEqual(len(ids), 4)
    # The first and second-to-last ids are expected to be the added tokens.
    for position in (0, -2):
        tester.assertGreater(ids[position], tokenizer.vocab_size - 1)

    # Step 2: special tokens.
    special_tokens = {'eos_token': ">>>>|||<||<<|<<",
                      'pad_token': "<<<<<|||>|>>>>|>"}
    special_added = tokenizer.add_special_tokens(special_tokens)
    vocab_size_after_special = tokenizer.vocab_size
    total_size_after_special = len(tokenizer)

    tester.assertNotEqual(vocab_size_after_special, 0)
    tester.assertEqual(base_vocab_size, vocab_size_after_special)
    tester.assertEqual(special_added, len(special_tokens))
    tester.assertEqual(total_size_after_special, total_size_after_plain + len(special_tokens))

    ids = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
    tester.assertGreaterEqual(len(ids), 6)
    # Each special token must sit beyond the base vocabulary and have a
    # larger id than the plain added token next to it.
    for position, neighbour in ((0, 1), (-2, -3)):
        tester.assertGreater(ids[position], tokenizer.vocab_size - 1)
        tester.assertGreater(ids[position], ids[neighbour])
    for position, attribute in ((0, "eos_token"), (-2, "pad_token")):
        tester.assertEqual(
            ids[position],
            tokenizer.convert_tokens_to_ids(getattr(tokenizer, attribute)))
def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
tokenizer = tokenizer_class(*inputs, **kwargs)
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
text = u"He is very happy, UNwant\u00E9d,running"
tokens = tokenizer.tokenize(text)
@@ -75,5 +119,6 @@ def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs
def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
    """Run every shared tokenizer check against ``tokenizer_class``.

    Each check receives ``tester``, ``tokenizer_class`` and the same
    ``*inputs`` / ``**kwargs``. The checks run in a fixed order: required
    methods, added tokens, save/reload, pickling.
    """
    shared_checks = (
        create_and_check_required_methods_tokenizer,
        create_and_check_add_tokens_tokenizer,
        create_and_check_save_and_load_tokenizer,
        create_and_check_pickle_tokenizer,
    )
    for check in shared_checks:
        check(tester, tokenizer_class, *inputs, **kwargs)

View File

@@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
from io import open
import shutil
import pytest
import tempfile
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
from.tokenization_tests_commons import create_and_check_tokenizer_commons
@@ -28,22 +27,23 @@ class TransfoXLTokenizationTest(unittest.TestCase):
def test_full_tokenizer(self):
    """Build a TransfoXL tokenizer from a small vocabulary file and check it.

    The vocabulary is written to a private temporary directory under the file
    name from ``VOCAB_FILES_NAMES['vocab_file']``. That directory is passed to
    the shared tokenizer checks, and the same file is then used to verify
    tokenization and token-to-id conversion.
    """
    # "low" and "l" are the known words used by the shared add-tokens check.
    vocab_tokens = [
        "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
        "running", ",", "low", "l",
    ]

    with tempfile.TemporaryDirectory() as tmpdirname:
        vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
        with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

        create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True)

        # Built inside the ``with`` block because the vocabulary file is
        # removed together with the temporary directory.
        tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)

        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])

        # Expected ids are the positions of the tokens in ``vocab_tokens``.
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
def test_full_tokenizer_lower(self):
tokenizer = TransfoXLTokenizer(lower_case=True)

View File

@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function
import unittest
import six
from pytorch_transformers import PreTrainedTokenizer
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
@@ -27,8 +28,17 @@ class TokenizerUtilsTest(unittest.TestCase):
for model_name in s3_models[:1]:
tokenizer = tokenizer_class.from_pretrained(model_name)
self.assertIsNotNone(tokenizer)
self.assertIsInstance(tokenizer, tokenizer_class)
self.assertIsInstance(tokenizer, PreTrainedTokenizer)
for special_tok in tokenizer.all_special_tokens:
if six.PY2:
self.assertIsInstance(special_tok, unicode)
else:
self.assertIsInstance(special_tok, str)
special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
self.assertIsInstance(special_tok_id, int)
def test_pretrained_tokenizers(self):
self.check_tokenizer_from_pretrained(GPT2Tokenizer)

View File

@@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
import json
import shutil
import pytest
import tempfile
from pytorch_transformers.tokenization_xlm import XLMTokenizer
from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import create_and_check_tokenizer_commons
@@ -31,31 +30,31 @@ class XLMTokenizationTest(unittest.TestCase):
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
"w</w>", "r</w>", "t</w>",
"lo", "low", "er</w>",
"low</w>", "lowest</w>", "newer</w>", "wider</w>"]
"low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
fp.write(json.dumps(vocab_tokens))
vocab_file = fp.name
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
fp.write("\n".join(merges))
merges_file = fp.name
create_and_check_tokenizer_commons(self, XLMTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
with tempfile.TemporaryDirectory() as tmpdirname:
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
with open(vocab_file, "w") as fp:
fp.write(json.dumps(vocab_tokens))
with open(merges_file, "w") as fp:
fp.write("\n".join(merges))
tokenizer = XLMTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
os.remove(vocab_file)
os.remove(merges_file)
create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname)
text = "lower"
bpe_tokens = ["low", "er</w>"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
tokenizer = XLMTokenizer(vocab_file, merges_file)
input_tokens = tokens + ["<unk>"]
input_bpe_tokens = [14, 15, 20]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
text = "lower"
bpe_tokens = ["low", "er</w>"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + ["<unk>"]
input_bpe_tokens = [14, 15, 20]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
if __name__ == '__main__':

View File

@@ -16,10 +16,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
import shutil
import pytest
import tempfile
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE, VOCAB_FILES_NAMES)
from.tokenization_tests_commons import create_and_check_tokenizer_commons
@@ -29,34 +28,37 @@ SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
class XLNetTokenizationTest(unittest.TestCase):
def test_full_tokenizer(self):
    """Check XLNet tokenization, id conversion and the shared tokenizer checks.

    A tokenizer built from ``SAMPLE_VOCAB`` is saved into a temporary
    directory so the shared checks can reload it by directory. The same
    tokenizer is then checked for token splitting, token-to-id conversion and
    id-to-token conversion.
    """
    tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)

    with tempfile.TemporaryDirectory() as tmpdirname:
        tokenizer.save_pretrained(tmpdirname)

        create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)

    tokens = tokenizer.tokenize(u'This is a test')
    self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])

    self.assertListEqual(
        tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])

    tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
    self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                  u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
                                  u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
                                  SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])

    # u'9' and u'é' are expected to map to id 0 and come back as u'<unk>'.
    ids = tokenizer.convert_tokens_to_ids(tokens)
    self.assertListEqual(
        ids, [8, 21, 84, 55, 24, 19, 7, 0,
              602, 347, 347, 347, 3, 12, 66,
              46, 72, 80, 6, 0, 4])

    back_tokens = tokenizer.convert_ids_to_tokens(ids)
    self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                       u'or', u'n', SPIECE_UNDERLINE + u'in',
                                       SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
                                       SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
                                       SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
                                       u'<unk>', u'.'])
def test_tokenizer_lower(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)