unified tokenizer api and serialization + tests
This commit is contained in:
@@ -16,17 +16,13 @@
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from shutil import copyfile
|
||||
from io import open
|
||||
|
||||
import unicodedata
|
||||
import six
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,8 +40,6 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
'xlnet-large-cased': 512,
|
||||
}
|
||||
|
||||
VOCAB_NAME = 'spiece.model'
|
||||
|
||||
SPIECE_UNDERLINE = u'▁'
|
||||
|
||||
# Segments (not really needed)
|
||||
@@ -60,31 +54,26 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
||||
SentencePiece based tokenizer. Peculiarities:
|
||||
- requires SentencePiece: https://github.com/google/sentencepiece
|
||||
"""
|
||||
# Tokens
|
||||
special_symbols = {
|
||||
"<unk>" : 0,
|
||||
"<s>" : 1,
|
||||
"</s>" : 2,
|
||||
"<cls>" : 3,
|
||||
"<sep>" : 4,
|
||||
"<pad>" : 5,
|
||||
"<mask>" : 6,
|
||||
"<eod>" : 7,
|
||||
"<eop>" : 8,
|
||||
}
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
|
||||
def __init__(self, vocab_file, max_len=None,
|
||||
do_lower_case=False, remove_space=True, keep_accents=False):
|
||||
do_lower_case=False, remove_space=True, keep_accents=False,
|
||||
bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>",
|
||||
pad_token="<pad>", cls_token="<cls>", mask_token="<mask>",
|
||||
additional_special_tokens=["<eop>", "<eod>"], **kwargs):
|
||||
super(XLNetTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token,
|
||||
unk_token=unk_token, sep_token=sep_token,
|
||||
pad_token=pad_token, cls_token=cls_token,
|
||||
mask_token=mask_token, additional_special_tokens=
|
||||
additional_special_tokens, **kwargs)
|
||||
try:
|
||||
import sentencepiece as spm
|
||||
except ImportError:
|
||||
logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
|
||||
"pip install sentencepiece")
|
||||
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
self.do_lower_case = do_lower_case
|
||||
self.remove_space = remove_space
|
||||
self.keep_accents = keep_accents
|
||||
@@ -94,46 +83,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
||||
self.sp_model.Load(vocab_file)
|
||||
|
||||
@property
|
||||
def UNK_TOKEN(self):
|
||||
return "<unk>"
|
||||
|
||||
@property
|
||||
def SEP_TOKEN(self):
|
||||
return "<sep>"
|
||||
|
||||
@property
|
||||
def PAD_TOKEN(self):
|
||||
return "<pad>"
|
||||
|
||||
@property
|
||||
def CLS_TOKEN(self):
|
||||
return "<cls>"
|
||||
|
||||
@property
|
||||
def MASK_TOKEN(self):
|
||||
return "<mask>"
|
||||
|
||||
@property
|
||||
def UNK_ID(self):
|
||||
return self.special_symbols["<unk>"]
|
||||
|
||||
@property
|
||||
def SEP_ID(self):
|
||||
return self.special_symbols["<sep>"]
|
||||
|
||||
@property
|
||||
def PAD_ID(self):
|
||||
return self.special_symbols["<pad>"]
|
||||
|
||||
@property
|
||||
def CLS_ID(self):
|
||||
return self.special_symbols["<cls>"]
|
||||
|
||||
@property
|
||||
def MASK_ID(self):
|
||||
return self.special_symbols["<mask>"]
|
||||
|
||||
def __len__(self):
|
||||
def vocab_size(self):
|
||||
return len(self.sp_model)
|
||||
|
||||
def __getstate__(self):
|
||||
@@ -169,7 +119,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
||||
|
||||
return outputs
|
||||
|
||||
def tokenize(self, text, return_unicode=True, sample=False):
|
||||
def _tokenize(self, text, return_unicode=True, sample=False):
|
||||
""" Tokenize a string.
|
||||
return_unicode is used only for py2
|
||||
"""
|
||||
@@ -208,56 +158,30 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
||||
|
||||
return new_pieces
|
||||
|
||||
def convert_tokens_to_ids(self, tokens, sample=False):
|
||||
""" Converts a sequence of tokens into ids using the vocab. """
|
||||
ids = []
|
||||
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
|
||||
return self.sp_model.PieceToId(tokens)
|
||||
for token in tokens:
|
||||
ids.append(self.sp_model.PieceToId(token))
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning(
|
||||
"Token indices sequence length is longer than the specified maximum "
|
||||
" sequence length for this XLNet model ({} > {}). Running this"
|
||||
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
|
||||
)
|
||||
return ids
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str/unicode) in an id using the vocab. """
|
||||
return self.sp_model.PieceToId(token)
|
||||
|
||||
def convert_ids_to_tokens(self, ids, return_unicode=True):
|
||||
"""Converts a sequence of ids in tokens."""
|
||||
tokens = []
|
||||
for i in ids:
|
||||
tokens.append(self.sp_model.IdToPiece(i))
|
||||
def _convert_id_to_token(self, index, return_unicode=True):
|
||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||
token = self.sp_model.IdToPiece(index)
|
||||
if six.PY2 and return_unicode and isinstance(token, str):
|
||||
token = token.decode('utf-8')
|
||||
return token
|
||||
|
||||
if six.PY2 and return_unicode:
|
||||
ret_pieces = []
|
||||
for piece in tokens:
|
||||
if isinstance(piece, str):
|
||||
piece = piece.decode('utf-8')
|
||||
ret_pieces.append(piece)
|
||||
tokens = ret_pieces
|
||||
return tokens
|
||||
|
||||
def encode(self, text, sample=False):
|
||||
return self.convert_tokens_to_ids(self.tokenize(text, sample=sample))
|
||||
|
||||
def decode(self, ids, clean_up_tokenization_spaces=True):
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
tokens = self.convert_ids_to_tokens(ids)
|
||||
out_string = ''.join(tokens)
|
||||
if clean_up_tokenization_spaces:
|
||||
out_string = out_string.strip().replace('<unk>', '')
|
||||
out_string = clean_up_tokenization(out_string)
|
||||
out_string = ''.join(tokens_ids)
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
def save_vocabulary(self, save_directory):
|
||||
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
|
||||
to a directory.
|
||||
"""
|
||||
if not os.path.isdir(vocab_path):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
return
|
||||
out_vocab_file = os.path.join(vocab_path, VOCAB_NAME)
|
||||
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
|
||||
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user