Fix Python 2 tests
This commit is contained in:
@@ -24,7 +24,7 @@ from pytorch_transformers.tokenization_bert import (BasicTokenizer,
|
|||||||
_is_control, _is_punctuation,
|
_is_control, _is_punctuation,
|
||||||
_is_whitespace, VOCAB_FILES_NAMES)
|
_is_whitespace, VOCAB_FILES_NAMES)
|
||||||
|
|
||||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons
|
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
|
||||||
|
|
||||||
class TokenizationTest(unittest.TestCase):
|
class TokenizationTest(unittest.TestCase):
|
||||||
|
|
||||||
@@ -33,13 +33,12 @@ class TokenizationTest(unittest.TestCase):
|
|||||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||||
"##ing", ",", "low", "lowest",
|
"##ing", ",", "low", "lowest",
|
||||||
]
|
]
|
||||||
vocab_directory = "/tmp/"
|
with TemporaryDirectory() as tmpdirname:
|
||||||
vocab_file = os.path.join(vocab_directory, VOCAB_FILES_NAMES['vocab_file'])
|
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||||
vocab_file = vocab_writer.name
|
|
||||||
|
|
||||||
create_and_check_tokenizer_commons(self, BertTokenizer, pretrained_model_name_or_path=vocab_directory)
|
create_and_check_tokenizer_commons(self, BertTokenizer, tmpdirname)
|
||||||
|
|
||||||
tokenizer = BertTokenizer(vocab_file)
|
tokenizer = BertTokenizer(vocab_file)
|
||||||
|
|
||||||
@@ -47,8 +46,6 @@ class TokenizationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||||
|
|
||||||
os.remove(vocab_file)
|
|
||||||
|
|
||||||
def test_chinese(self):
|
def test_chinese(self):
|
||||||
tokenizer = BasicTokenizer()
|
tokenizer = BasicTokenizer()
|
||||||
|
|
||||||
|
|||||||
@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import json
|
import json
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
|
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
|
||||||
|
|
||||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons
|
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
|
||||||
|
|
||||||
class GPT2TokenizationTest(unittest.TestCase):
|
class GPT2TokenizationTest(unittest.TestCase):
|
||||||
|
|
||||||
@@ -34,7 +33,7 @@ class GPT2TokenizationTest(unittest.TestCase):
|
|||||||
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
|
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
|
||||||
special_tokens_map = {"unk_token": "<unk>"}
|
special_tokens_map = {"unk_token": "<unk>"}
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with TemporaryDirectory() as tmpdirname:
|
||||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
||||||
with open(vocab_file, "w") as fp:
|
with open(vocab_file, "w") as fp:
|
||||||
|
|||||||
@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import json
|
import json
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
|
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
|
||||||
|
|
||||||
from.tokenization_tests_commons import create_and_check_tokenizer_commons
|
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
|
||||||
|
|
||||||
|
|
||||||
class OpenAIGPTTokenizationTest(unittest.TestCase):
|
class OpenAIGPTTokenizationTest(unittest.TestCase):
|
||||||
@@ -35,7 +34,7 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
|
|||||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||||
merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
|
merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with TemporaryDirectory() as tmpdirname:
|
||||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
||||||
with open(vocab_file, "w") as fp:
|
with open(vocab_file, "w") as fp:
|
||||||
|
|||||||
@@ -14,18 +14,25 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
from io import open
|
from io import open
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import shutil
|
||||||
if sys.version_info[0] == 3:
|
|
||||||
unicode = str
|
|
||||||
|
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
import cPickle as pickle
|
import cPickle as pickle
|
||||||
|
|
||||||
|
class TemporaryDirectory(object):
|
||||||
|
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
|
||||||
|
def __enter__(self):
|
||||||
|
self.name = tempfile.mkdtemp()
|
||||||
|
return self.name
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
shutil.rmtree(self.name)
|
||||||
else:
|
else:
|
||||||
import pickle
|
import pickle
|
||||||
|
TemporaryDirectory = tempfile.TemporaryDirectory
|
||||||
|
unicode = str
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||||
@@ -33,7 +40,7 @@ def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, *
|
|||||||
|
|
||||||
before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with TemporaryDirectory() as tmpdirname:
|
||||||
tokenizer.save_pretrained(tmpdirname)
|
tokenizer.save_pretrained(tmpdirname)
|
||||||
tokenizer = tokenizer.from_pretrained(tmpdirname)
|
tokenizer = tokenizer.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
|||||||
@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from io import open
|
from io import open
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
|
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
|
||||||
|
|
||||||
from.tokenization_tests_commons import create_and_check_tokenizer_commons
|
from.tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
|
||||||
|
|
||||||
class TransfoXLTokenizationTest(unittest.TestCase):
|
class TransfoXLTokenizationTest(unittest.TestCase):
|
||||||
|
|
||||||
@@ -30,7 +29,7 @@ class TransfoXLTokenizationTest(unittest.TestCase):
|
|||||||
"<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
|
"<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
|
||||||
"running", ",", "low", "l",
|
"running", ",", "low", "l",
|
||||||
]
|
]
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with TemporaryDirectory() as tmpdirname:
|
||||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||||
|
|||||||
@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import json
|
import json
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
|
from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
|
||||||
|
|
||||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons
|
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
|
||||||
|
|
||||||
class XLMTokenizationTest(unittest.TestCase):
|
class XLMTokenizationTest(unittest.TestCase):
|
||||||
|
|
||||||
@@ -34,7 +33,7 @@ class XLMTokenizationTest(unittest.TestCase):
|
|||||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||||
merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
|
merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with TemporaryDirectory() as tmpdirname:
|
||||||
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
||||||
with open(vocab_file, "w") as fp:
|
with open(vocab_file, "w") as fp:
|
||||||
|
|||||||
@@ -16,11 +16,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE, VOCAB_FILES_NAMES)
|
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
|
||||||
|
|
||||||
from.tokenization_tests_commons import create_and_check_tokenizer_commons
|
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
|
||||||
|
|
||||||
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||||
'fixtures/test_sentencepiece.model')
|
'fixtures/test_sentencepiece.model')
|
||||||
@@ -30,7 +29,7 @@ class XLNetTokenizationTest(unittest.TestCase):
|
|||||||
def test_full_tokenizer(self):
|
def test_full_tokenizer(self):
|
||||||
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with TemporaryDirectory() as tmpdirname:
|
||||||
tokenizer.save_pretrained(tmpdirname)
|
tokenizer.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
|
create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
|
||||||
|
|||||||
@@ -231,8 +231,7 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
# Add supplementary tokens.
|
# Add supplementary tokens.
|
||||||
if added_tokens_file is not None:
|
if added_tokens_file is not None:
|
||||||
added_tokens = json.load(open(added_tokens_file, encoding="utf-8"))
|
added_tok_encoder = json.load(open(added_tokens_file, encoding="utf-8"))
|
||||||
added_tok_encoder = dict((tok, len(tokenizer) + i) for i, tok in enumerate(added_tokens))
|
|
||||||
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
|
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
|
||||||
tokenizer.added_tokens_encoder.update(added_tok_encoder)
|
tokenizer.added_tokens_encoder.update(added_tok_encoder)
|
||||||
tokenizer.added_tokens_decoder.update(added_tok_decoder)
|
tokenizer.added_tokens_decoder.update(added_tok_decoder)
|
||||||
@@ -256,7 +255,11 @@ class PreTrainedTokenizer(object):
|
|||||||
f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
|
f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
|
||||||
|
|
||||||
with open(added_tokens_file, 'w', encoding='utf-8') as f:
|
with open(added_tokens_file, 'w', encoding='utf-8') as f:
|
||||||
f.write(json.dumps(self.added_tokens_decoder, ensure_ascii=False))
|
if self.added_tokens_encoder:
|
||||||
|
out_str = json.dumps(self.added_tokens_decoder, ensure_ascii=False)
|
||||||
|
else:
|
||||||
|
out_str = u"{}"
|
||||||
|
f.write(out_str)
|
||||||
|
|
||||||
vocab_files = self.save_vocabulary(save_directory)
|
vocab_files = self.save_vocabulary(save_directory)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user