unified tokenizer api and serialization + tests
This commit is contained in:
@@ -24,7 +24,7 @@ from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
|
||||
BertForNextSentencePrediction, BertForPreTraining,
|
||||
BertForQuestionAnswering, BertForSequenceClassification,
|
||||
BertForTokenClassification, BertForMultipleChoice)
|
||||
from pytorch_transformers.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
|
||||
|
||||
@@ -267,7 +267,7 @@ class BertModelTest(unittest.TestCase):
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@@ -413,7 +413,7 @@ class GPTModelTester(object):
|
||||
|
||||
def create_and_check_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(self.base_model_class.PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
|
||||
model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.parent.assertIsNotNone(model)
|
||||
|
||||
@@ -26,7 +26,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
|
||||
from pytorch_transformers.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
|
||||
|
||||
@@ -185,7 +185,7 @@ class TransfoXLModelTest(unittest.TestCase):
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@@ -20,12 +20,12 @@ import unittest
|
||||
import logging
|
||||
|
||||
from pytorch_transformers import PretrainedConfig, PreTrainedModel
|
||||
from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
class ModelUtilsTest(unittest.TestCase):
|
||||
def test_model_from_pretrained(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
config = BertConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, PretrainedConfig)
|
||||
|
||||
@@ -21,7 +21,7 @@ import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
|
||||
from pytorch_transformers.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
|
||||
|
||||
@@ -251,7 +251,7 @@ class XLMModelTest(unittest.TestCase):
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@@ -26,7 +26,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
|
||||
from pytorch_transformers.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
|
||||
|
||||
@@ -279,7 +279,7 @@ class XLNetModelTest(unittest.TestCase):
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@@ -17,14 +17,12 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
import os
|
||||
import unittest
|
||||
from io import open
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers.tokenization_bert import (BasicTokenizer,
|
||||
BertTokenizer,
|
||||
WordpieceTokenizer,
|
||||
_is_control, _is_punctuation,
|
||||
_is_whitespace)
|
||||
BertTokenizer,
|
||||
WordpieceTokenizer,
|
||||
_is_control, _is_punctuation,
|
||||
_is_whitespace, VOCAB_FILES_NAMES)
|
||||
|
||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
@@ -33,13 +31,15 @@ class TokenizationTest(unittest.TestCase):
|
||||
def test_full_tokenizer(self):
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing", ","
|
||||
"##ing", ",", "low", "lowest",
|
||||
]
|
||||
with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_directory = "/tmp/"
|
||||
vocab_file = os.path.join(vocab_directory, VOCAB_FILES_NAMES['vocab_file'])
|
||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
vocab_file = vocab_writer.name
|
||||
|
||||
create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)
|
||||
create_and_check_tokenizer_commons(self, BertTokenizer, pretrained_model_name_or_path=vocab_directory)
|
||||
|
||||
tokenizer = BertTokenizer(vocab_file)
|
||||
|
||||
@@ -80,7 +80,7 @@ class TokenizationTest(unittest.TestCase):
|
||||
vocab = {}
|
||||
for (i, token) in enumerate(vocab_tokens):
|
||||
vocab[token] = i
|
||||
tokenizer = WordpieceTokenizer(vocab=vocab)
|
||||
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
|
||||
|
||||
self.assertListEqual(tokenizer.tokenize(""), [])
|
||||
|
||||
|
||||
@@ -17,8 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
import os
|
||||
import unittest
|
||||
import json
|
||||
import tempfile
|
||||
|
||||
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
|
||||
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
|
||||
|
||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
@@ -28,31 +29,31 @@ class GPT2TokenizationTest(unittest.TestCase):
|
||||
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
|
||||
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||
"lo", "low", "er",
|
||||
"low", "lowest", "newer", "wider"]
|
||||
"low", "lowest", "newer", "wider", "<unk>"]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
|
||||
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
|
||||
fp.write(json.dumps(vocab_tokens))
|
||||
vocab_file = fp.name
|
||||
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
merges_file = fp.name
|
||||
special_tokens_map = {"unk_token": "<unk>"}
|
||||
|
||||
create_and_check_tokenizer_commons(self, GPT2Tokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
||||
with open(vocab_file, "w") as fp:
|
||||
fp.write(json.dumps(vocab_tokens))
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map)
|
||||
|
||||
input_tokens = tokens + ["<unk>"]
|
||||
input_bpe_tokens = [13, 12, 16]
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
|
||||
os.remove(vocab_file)
|
||||
os.remove(merges_file)
|
||||
input_tokens = tokens + [tokenizer.unk_token]
|
||||
input_bpe_tokens = [13, 12, 17]
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
import os
|
||||
import unittest
|
||||
import json
|
||||
import shutil
|
||||
import pytest
|
||||
import tempfile
|
||||
|
||||
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer
|
||||
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
|
||||
|
||||
from.tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
@@ -32,31 +31,31 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
|
||||
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||
"w</w>", "r</w>", "t</w>",
|
||||
"lo", "low", "er</w>",
|
||||
"low</w>", "lowest</w>", "newer</w>", "wider</w>"]
|
||||
"low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
|
||||
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
|
||||
fp.write(json.dumps(vocab_tokens))
|
||||
vocab_file = fp.name
|
||||
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
merges_file = fp.name
|
||||
|
||||
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
||||
with open(vocab_file, "w") as fp:
|
||||
fp.write(json.dumps(vocab_tokens))
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
os.remove(vocab_file)
|
||||
os.remove(merges_file)
|
||||
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname)
|
||||
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er</w>"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
|
||||
|
||||
input_tokens = tokens + ["<unk>"]
|
||||
input_bpe_tokens = [14, 15, 20]
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er</w>"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
|
||||
input_tokens = tokens + ["<unk>"]
|
||||
input_bpe_tokens = [14, 15, 20]
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
import os
|
||||
import sys
|
||||
from io import open
|
||||
import tempfile
|
||||
|
||||
if sys.version_info[0] == 3:
|
||||
unicode = str
|
||||
@@ -28,22 +29,19 @@ else:
|
||||
|
||||
|
||||
def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||
tokenizer = tokenizer_class(*inputs, **kwargs)
|
||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
||||
|
||||
before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
||||
|
||||
vocab_path="/tmp/"
|
||||
output_files = tokenizer.save_vocabulary(vocab_path=vocab_path)
|
||||
tokenizer = tokenizer.from_pretrained(vocab_path)
|
||||
|
||||
for f in output_files:
|
||||
os.remove(f)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tokenizer.save_pretrained(tmpdirname)
|
||||
tokenizer = tokenizer.from_pretrained(tmpdirname)
|
||||
|
||||
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
||||
tester.assertListEqual(before_tokens, after_tokens)
|
||||
|
||||
def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||
tokenizer = tokenizer_class(*inputs, **kwargs)
|
||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
||||
|
||||
text = u"Munich and Berlin are nice cities"
|
||||
filename = u"/tmp/tokenizer.bin"
|
||||
@@ -58,8 +56,54 @@ def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs
|
||||
tester.assertListEqual(subwords, subwords_loaded)
|
||||
|
||||
|
||||
def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
||||
|
||||
vocab_size = tokenizer.vocab_size
|
||||
all_size = len(tokenizer)
|
||||
|
||||
tester.assertNotEqual(vocab_size, 0)
|
||||
tester.assertEqual(vocab_size, all_size)
|
||||
|
||||
new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
|
||||
added_toks = tokenizer.add_tokens(new_toks)
|
||||
vocab_size_2 = tokenizer.vocab_size
|
||||
all_size_2 = len(tokenizer)
|
||||
|
||||
tester.assertNotEqual(vocab_size_2, 0)
|
||||
tester.assertEqual(vocab_size, vocab_size_2)
|
||||
tester.assertEqual(added_toks, len(new_toks))
|
||||
tester.assertEqual(all_size_2, all_size + len(new_toks))
|
||||
|
||||
tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
|
||||
tester.assertGreaterEqual(len(tokens), 4)
|
||||
tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
||||
tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
||||
|
||||
new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
|
||||
'pad_token': "<<<<<|||>|>>>>|>"}
|
||||
added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
|
||||
vocab_size_3 = tokenizer.vocab_size
|
||||
all_size_3 = len(tokenizer)
|
||||
|
||||
tester.assertNotEqual(vocab_size_3, 0)
|
||||
tester.assertEqual(vocab_size, vocab_size_3)
|
||||
tester.assertEqual(added_toks_2, len(new_toks_2))
|
||||
tester.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
|
||||
|
||||
tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
|
||||
|
||||
tester.assertGreaterEqual(len(tokens), 6)
|
||||
tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
||||
tester.assertGreater(tokens[0], tokens[1])
|
||||
tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
||||
tester.assertGreater(tokens[-2], tokens[-3])
|
||||
tester.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
|
||||
tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
|
||||
|
||||
|
||||
def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||
tokenizer = tokenizer_class(*inputs, **kwargs)
|
||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
||||
|
||||
text = u"He is very happy, UNwant\u00E9d,running"
|
||||
tokens = tokenizer.tokenize(text)
|
||||
@@ -75,5 +119,6 @@ def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs
|
||||
|
||||
def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
|
||||
create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
|
||||
@@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
import os
|
||||
import unittest
|
||||
from io import open
|
||||
import shutil
|
||||
import pytest
|
||||
import tempfile
|
||||
|
||||
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer
|
||||
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
|
||||
|
||||
from.tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
@@ -28,22 +27,23 @@ class TransfoXLTokenizationTest(unittest.TestCase):
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
vocab_tokens = [
|
||||
"<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", ","
|
||||
"<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
|
||||
"running", ",", "low", "l",
|
||||
]
|
||||
with open("/tmp/transfo_xl_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
vocab_file = vocab_writer.name
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
create_and_check_tokenizer_commons(self, TransfoXLTokenizer, vocab_file=vocab_file, lower_case=True)
|
||||
create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True)
|
||||
|
||||
tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
|
||||
os.remove(vocab_file)
|
||||
tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
|
||||
|
||||
tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
|
||||
self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
|
||||
tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
|
||||
self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
|
||||
|
||||
def test_full_tokenizer_lower(self):
|
||||
tokenizer = TransfoXLTokenizer(lower_case=True)
|
||||
|
||||
@@ -17,6 +17,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import six
|
||||
|
||||
from pytorch_transformers import PreTrainedTokenizer
|
||||
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
|
||||
@@ -27,8 +28,17 @@ class TokenizerUtilsTest(unittest.TestCase):
|
||||
for model_name in s3_models[:1]:
|
||||
tokenizer = tokenizer_class.from_pretrained(model_name)
|
||||
self.assertIsNotNone(tokenizer)
|
||||
self.assertIsInstance(tokenizer, tokenizer_class)
|
||||
self.assertIsInstance(tokenizer, PreTrainedTokenizer)
|
||||
|
||||
for special_tok in tokenizer.all_special_tokens:
|
||||
if six.PY2:
|
||||
self.assertIsInstance(special_tok, unicode)
|
||||
else:
|
||||
self.assertIsInstance(special_tok, str)
|
||||
special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
|
||||
self.assertIsInstance(special_tok_id, int)
|
||||
|
||||
def test_pretrained_tokenizers(self):
|
||||
self.check_tokenizer_from_pretrained(GPT2Tokenizer)
|
||||
|
||||
|
||||
@@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
import os
|
||||
import unittest
|
||||
import json
|
||||
import shutil
|
||||
import pytest
|
||||
import tempfile
|
||||
|
||||
from pytorch_transformers.tokenization_xlm import XLMTokenizer
|
||||
from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
|
||||
|
||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
@@ -31,31 +30,31 @@ class XLMTokenizationTest(unittest.TestCase):
|
||||
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||
"w</w>", "r</w>", "t</w>",
|
||||
"lo", "low", "er</w>",
|
||||
"low</w>", "lowest</w>", "newer</w>", "wider</w>"]
|
||||
"low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
|
||||
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
|
||||
fp.write(json.dumps(vocab_tokens))
|
||||
vocab_file = fp.name
|
||||
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
merges_file = fp.name
|
||||
|
||||
create_and_check_tokenizer_commons(self, XLMTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
||||
with open(vocab_file, "w") as fp:
|
||||
fp.write(json.dumps(vocab_tokens))
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
tokenizer = XLMTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
os.remove(vocab_file)
|
||||
os.remove(merges_file)
|
||||
create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname)
|
||||
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er</w>"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
tokenizer = XLMTokenizer(vocab_file, merges_file)
|
||||
|
||||
input_tokens = tokens + ["<unk>"]
|
||||
input_bpe_tokens = [14, 15, 20]
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er</w>"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
|
||||
input_tokens = tokens + ["<unk>"]
|
||||
input_bpe_tokens = [14, 15, 20]
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -16,10 +16,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import shutil
|
||||
import pytest
|
||||
import tempfile
|
||||
|
||||
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
|
||||
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE, VOCAB_FILES_NAMES)
|
||||
|
||||
from.tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
@@ -29,34 +28,37 @@ SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
class XLNetTokenizationTest(unittest.TestCase):
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
create_and_check_tokenizer_commons(self, XLNetTokenizer, SAMPLE_VOCAB)
|
||||
|
||||
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||
|
||||
tokens = tokenizer.tokenize(u'This is a test')
|
||||
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tokenizer.save_pretrained(tmpdirname)
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
|
||||
create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
|
||||
|
||||
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
|
||||
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
|
||||
u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
|
||||
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
|
||||
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
|
||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
self.assertListEqual(
|
||||
ids, [8, 21, 84, 55, 24, 19, 7, 0,
|
||||
602, 347, 347, 347, 3, 12, 66,
|
||||
46, 72, 80, 6, 0, 4])
|
||||
tokens = tokenizer.tokenize(u'This is a test')
|
||||
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
|
||||
|
||||
back_tokens = tokenizer.convert_ids_to_tokens(ids)
|
||||
self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
|
||||
u'or', u'n', SPIECE_UNDERLINE + u'in',
|
||||
SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
|
||||
SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
|
||||
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
|
||||
u'<unk>', u'.'])
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
|
||||
|
||||
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
|
||||
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
|
||||
u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
|
||||
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
|
||||
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
|
||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
self.assertListEqual(
|
||||
ids, [8, 21, 84, 55, 24, 19, 7, 0,
|
||||
602, 347, 347, 347, 3, 12, 66,
|
||||
46, 72, 80, 6, 0, 4])
|
||||
|
||||
back_tokens = tokenizer.convert_ids_to_tokens(ids)
|
||||
self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
|
||||
u'or', u'n', SPIECE_UNDERLINE + u'in',
|
||||
SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
|
||||
SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
|
||||
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
|
||||
u'<unk>', u'.'])
|
||||
|
||||
def test_tokenizer_lower(self):
|
||||
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
|
||||
|
||||
Reference in New Issue
Block a user