unified tokenizer api and serialization + tests
This commit is contained in:
@@ -16,37 +16,145 @@
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import regex as re
|
||||
import json
|
||||
import six
|
||||
from io import open
|
||||
|
||||
try:
|
||||
from functools import lru_cache
|
||||
except ImportError:
|
||||
# Just a dummy decorator to get the checks to run on python2
|
||||
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
|
||||
def lru_cache():
|
||||
return lambda func: func
|
||||
|
||||
from .file_utils import cached_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
|
||||
ADDED_TOKENS_FILE = 'added_tokens.json'
|
||||
|
||||
class PreTrainedTokenizer(object):
|
||||
""" An abstract class to handle dowloading and loading pretrained tokenizers.
|
||||
""" An abstract class to handle dowloading and loading pretrained tokenizers and adding tokens to the vocabulary.
|
||||
|
||||
Derived class can set up a few special tokens to be used in common scripts and internals:
|
||||
bos_token, eos_token, EOP_TOKEN, EOD_TOKEN, unk_token, sep_token, pad_token, cls_token, mask_token
|
||||
additional_special_tokens = []
|
||||
|
||||
We defined an added_tokens_encoder to add new tokens to the vocabulary without having to handle the
|
||||
specific vocabulary augmentation methods of the various underlying dictionnary structures (BPE, sentencepiece...).
|
||||
"""
|
||||
vocab_files_names = {}
|
||||
pretrained_vocab_files_map = {}
|
||||
max_model_input_sizes = {}
|
||||
|
||||
SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
|
||||
"pad_token", "cls_token", "mask_token",
|
||||
"additional_special_tokens"]
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
if self._bos_token is None:
|
||||
logger.error("Using bos_token, but it is not set yet.")
|
||||
return self._bos_token
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
if self._eos_token is None:
|
||||
logger.error("Using eos_token, but it is not set yet.")
|
||||
return self._eos_token
|
||||
|
||||
@property
|
||||
def unk_token(self):
|
||||
if self._unk_token is None:
|
||||
logger.error("Using unk_token, but it is not set yet.")
|
||||
return self._unk_token
|
||||
|
||||
@property
|
||||
def sep_token(self):
|
||||
if self._sep_token is None:
|
||||
logger.error("Using sep_token, but it is not set yet.")
|
||||
return self._sep_token
|
||||
|
||||
@property
|
||||
def pad_token(self):
|
||||
if self._pad_token is None:
|
||||
logger.error("Using pad_token, but it is not set yet.")
|
||||
return self._pad_token
|
||||
|
||||
@property
|
||||
def cls_token(self):
|
||||
if self._cls_token is None:
|
||||
logger.error("Using cls_token, but it is not set yet.")
|
||||
return self._cls_token
|
||||
|
||||
@property
|
||||
def mask_token(self):
|
||||
if self._mask_token is None:
|
||||
logger.error("Using mask_token, but it is not set yet.")
|
||||
return self._mask_token
|
||||
|
||||
@property
|
||||
def additional_special_tokens(self):
|
||||
if self._additional_special_tokens is None:
|
||||
logger.error("Using additional_special_tokens, but it is not set yet.")
|
||||
return self._additional_special_tokens
|
||||
|
||||
@bos_token.setter
|
||||
def bos_token(self, value):
|
||||
self._bos_token = value
|
||||
|
||||
@eos_token.setter
|
||||
def eos_token(self, value):
|
||||
self._eos_token = value
|
||||
|
||||
@unk_token.setter
|
||||
def unk_token(self, value):
|
||||
self._unk_token = value
|
||||
|
||||
@sep_token.setter
|
||||
def sep_token(self, value):
|
||||
self._sep_token = value
|
||||
|
||||
@pad_token.setter
|
||||
def pad_token(self, value):
|
||||
self._pad_token = value
|
||||
|
||||
@cls_token.setter
|
||||
def cls_token(self, value):
|
||||
self._cls_token = value
|
||||
|
||||
@mask_token.setter
|
||||
def mask_token(self, value):
|
||||
self._mask_token = value
|
||||
|
||||
@additional_special_tokens.setter
|
||||
def additional_special_tokens(self, value):
|
||||
self._additional_special_tokens = value
|
||||
|
||||
def __init__(self, max_len=None, **kwargs):
|
||||
self._bos_token = None
|
||||
self._eos_token = None
|
||||
self._unk_token = None
|
||||
self._sep_token = None
|
||||
self._pad_token = None
|
||||
self._cls_token = None
|
||||
self._mask_token = None
|
||||
self._additional_special_tokens = []
|
||||
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
self.added_tokens_encoder = {}
|
||||
self.added_tokens_decoder = {}
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.SPECIAL_TOKENS_ATTRIBUTES:
|
||||
raise ValueError(
|
||||
"PreTrainedTokenizer.__init__() argument {} should be in {}".format(
|
||||
key, ', '.join(self.SPECIAL_TOKENS_ATTRIBUTES)))
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *inputs, **kwargs):
|
||||
return cls._from_pretrained(*inputs, **kwargs)
|
||||
|
||||
|
||||
@classmethod
|
||||
def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
@@ -59,16 +167,20 @@ class PreTrainedTokenizer(object):
|
||||
for file_id, map_list in cls.pretrained_vocab_files_map.items():
|
||||
vocab_files[file_id] = map_list[pretrained_model_name_or_path]
|
||||
else:
|
||||
for file_id, file_name in cls.vocab_files_names.items():
|
||||
all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
|
||||
'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE}
|
||||
all_vocab_files_names.update(cls.vocab_files_names)
|
||||
for file_id, file_name in all_vocab_files_names.items():
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
|
||||
else:
|
||||
full_file_name = pretrained_model_name_or_path
|
||||
if not os.path.exists(full_file_name):
|
||||
logger.info("Didn't find file {}. We don't load it.".format(full_file_name))
|
||||
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
|
||||
full_file_name = None
|
||||
vocab_files[file_id] = full_file_name
|
||||
# redirect to the cache, if necessary
|
||||
|
||||
# Get files from url, cache, or disk depending on the case
|
||||
try:
|
||||
resolved_vocab_files = {}
|
||||
for file_id, file_path in vocab_files.items():
|
||||
@@ -95,6 +207,7 @@ class PreTrainedTokenizer(object):
|
||||
logger.info("loading file {} from cache at {}".format(
|
||||
file_path, resolved_vocab_files[file_id]))
|
||||
|
||||
# Set max length if needed
|
||||
if pretrained_model_name_or_path in cls.max_model_input_sizes:
|
||||
# if we're using a pretrained model, ensure the tokenizer
|
||||
# wont index sequences longer than the number of positional embeddings
|
||||
@@ -102,31 +215,255 @@ class PreTrainedTokenizer(object):
|
||||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||
|
||||
# Merge resolved_vocab_files arguments in kwargs.
|
||||
added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
|
||||
special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
|
||||
for args_name, file_path in resolved_vocab_files.items():
|
||||
kwargs[args_name] = file_path
|
||||
if args_name not in kwargs:
|
||||
kwargs[args_name] = file_path
|
||||
if special_tokens_map_file is not None:
|
||||
special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
|
||||
for key, value in special_tokens_map.items():
|
||||
if key not in kwargs:
|
||||
kwargs[key] = value
|
||||
|
||||
# Instantiate tokenizer.
|
||||
tokenizer = cls(*inputs, **kwargs)
|
||||
|
||||
# Add supplementary tokens.
|
||||
if added_tokens_file is not None:
|
||||
added_tokens = json.load(open(added_tokens_file, encoding="utf-8"))
|
||||
added_tok_encoder = dict((tok, len(tokenizer) + i) for i, tok in enumerate(added_tokens))
|
||||
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
|
||||
tokenizer.added_tokens_encoder.update(added_tok_encoder)
|
||||
tokenizer.added_tokens_decoder.update(added_tok_decoder)
|
||||
|
||||
return tokenizer
|
||||
|
||||
def tokenize(self, text):
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save the tokenizer vocabulary files (with added tokens) and the
|
||||
special-tokens-to-class-attributes-mapping to a directory, so that it
|
||||
can be re-loaded using the `from_pretrained(save_directory)` class method.
|
||||
"""
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Saving directory ({}) should be a directory".format(save_directory))
|
||||
return
|
||||
|
||||
special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
|
||||
added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
|
||||
|
||||
with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
|
||||
|
||||
with open(added_tokens_file, 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps(self.added_tokens_decoder, ensure_ascii=False))
|
||||
|
||||
vocab_files = self.save_vocabulary(save_directory)
|
||||
|
||||
return vocab_files + (special_tokens_map_file, added_tokens_file)
|
||||
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
""" Save the tokenizer vocabulary to a directory. This method doesn't save added tokens
|
||||
and special token mappings.
|
||||
|
||||
Please use `save_pretrained()` to save the full Tokenizer state so that it can be
|
||||
reloaded using the `from_pretrained(save_directory)` class method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def vocab_size(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.vocab_size + len(self.added_tokens_encoder)
|
||||
|
||||
|
||||
def add_tokens(self, new_tokens):
|
||||
""" Add a list of new tokens to the tokenizer class. If the new tokens are not in the
|
||||
vocabulary, they are added to the added_tokens_encoder with indices starting from
|
||||
the last index of the current vocabulary.
|
||||
|
||||
Returns:
|
||||
Number of tokens added to the vocabulary which can be used to correspondingly
|
||||
increase the size of the associated model embedding matrices.
|
||||
"""
|
||||
if not new_tokens:
|
||||
return 0
|
||||
|
||||
to_add_tokens = []
|
||||
for token in new_tokens:
|
||||
if self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
|
||||
to_add_tokens.append(token)
|
||||
logger.info("Adding %s to the vocabulary", token)
|
||||
|
||||
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
|
||||
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
|
||||
self.added_tokens_encoder.update(added_tok_encoder)
|
||||
self.added_tokens_decoder.update(added_tok_decoder)
|
||||
|
||||
return len(to_add_tokens)
|
||||
|
||||
|
||||
def add_special_tokens(self, special_tokens_dict):
|
||||
""" Add a dictionnary of special tokens (eos, pad, cls...) to the encoder and link them
|
||||
to class attributes. If the special tokens are not in the vocabulary, they are added
|
||||
to it and indexed starting from the last index of the current vocabulary.
|
||||
|
||||
Returns:
|
||||
Number of tokens added to the vocabulary which can be used to correspondingly
|
||||
increase the size of the associated model embedding matrices.
|
||||
"""
|
||||
if not special_tokens_dict:
|
||||
return 0
|
||||
|
||||
added_special_tokens = self.add_tokens(special_tokens_dict.values())
|
||||
for key, value in special_tokens_dict.items():
|
||||
logger.info("Assigning %s to the %s key of the tokenizer", value, key)
|
||||
setattr(self, key, value)
|
||||
|
||||
return added_special_tokens
|
||||
|
||||
|
||||
def tokenize(self, text, **kwargs):
|
||||
""" Converts a string in a sequence of tokens (string), using the tokenizer.
|
||||
Split in words for word-based vocabulary or sub-words for sub-word-based
|
||||
vocabularies (BPE/SentencePieces/WordPieces).
|
||||
|
||||
Take care of added tokens.
|
||||
"""
|
||||
def split_on_tokens(tok_list, text):
|
||||
if not text:
|
||||
return []
|
||||
if not tok_list:
|
||||
return self._tokenize(text, **kwargs)
|
||||
tok = tok_list[0]
|
||||
split_text = text.split(tok)
|
||||
return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \
|
||||
for sub_text in split_text), [])[:-1]
|
||||
|
||||
added_tokens = list(self.added_tokens_encoder.keys())
|
||||
tokenized_text = split_on_tokens(added_tokens, text)
|
||||
return tokenized_text
|
||||
|
||||
def _tokenize(self, text, **kwargs):
|
||||
""" Converts a string in a sequence of tokens (string), using the tokenizer.
|
||||
Split in words for word-based vocabulary or sub-words for sub-word-based
|
||||
vocabularies (BPE/SentencePieces/WordPieces).
|
||||
|
||||
Don't take care of added tokens.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
""" Converts a single token or a sequence of tokens (str/unicode) in a integer id
|
||||
(resp.) a sequence of ids, using the vocabulary.
|
||||
"""
|
||||
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
|
||||
return self.convert_token_to_id_with_added_voc(tokens)
|
||||
|
||||
ids = []
|
||||
for token in tokens:
|
||||
ids.append(self.convert_token_to_id_with_added_voc(token))
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
|
||||
"for this model ({} > {}). Running this sequence through the model will result in "
|
||||
"indexing errors".format(len(ids), self.max_len))
|
||||
return ids
|
||||
|
||||
|
||||
def convert_token_to_id_with_added_voc(self, token):
|
||||
if token in self.added_tokens_encoder:
|
||||
return self.added_tokens_encoder[token]
|
||||
return self._convert_token_to_id(token)
|
||||
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
raise NotImplementedError
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
""" Converts a single index or a sequence of indices (integers) in a token "
|
||||
(resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
|
||||
|
||||
Args:
|
||||
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
|
||||
"""
|
||||
if isinstance(ids, int):
|
||||
return self.convert_id_to_token(ids)
|
||||
tokens = []
|
||||
for index in ids:
|
||||
if index in self.all_special_ids and skip_special_tokens:
|
||||
continue
|
||||
if index in self.added_tokens_decoder:
|
||||
tokens.append(self.added_tokens_decoder[index])
|
||||
else:
|
||||
tokens.append(self._convert_id_to_token(index))
|
||||
return tokens
|
||||
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def encode(self, text):
|
||||
raise NotImplementedError
|
||||
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||
same as self.convert_tokens_to_ids(self.tokenize(text)).
|
||||
"""
|
||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||
|
||||
def decode(self, token_ids, *input, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
raise NotImplementedError
|
||||
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
|
||||
with options to remove special tokens and clean up tokenization spaces.
|
||||
"""
|
||||
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
text = self._convert_ids_to_string(filtered_tokens)
|
||||
if clean_up_tokenization_spaces:
|
||||
text = clean_up_tokenization(text)
|
||||
return text
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary.
|
||||
roughtly same as ' '.join(self.convert_ids_to_tokens(token_ids)).
|
||||
"""
|
||||
return ' '.join(self.convert_ids_to_tokens(tokens_ids))
|
||||
|
||||
@property
|
||||
def special_tokens_map(self):
|
||||
""" A dictionary mapping special token class attribute (cls_token, unk_token...) to their
|
||||
values ('<unk>', '<cls>'...)
|
||||
"""
|
||||
set_attr = {}
|
||||
for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
|
||||
attr_value = getattr(self, "_" + attr)
|
||||
if attr_value:
|
||||
set_attr[attr] = attr_value
|
||||
return set_attr
|
||||
|
||||
@property
|
||||
def all_special_tokens(self):
|
||||
""" List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
|
||||
(cls_token, unk_token...).
|
||||
"""
|
||||
all_toks = []
|
||||
set_attr = self.special_tokens_map
|
||||
for attr_value in set_attr.values():
|
||||
all_toks = all_toks + (attr_value if isinstance(attr_value, (list, tuple)) else [attr_value])
|
||||
all_toks = list(set(all_toks))
|
||||
return all_toks
|
||||
|
||||
@property
|
||||
def all_special_ids(self):
|
||||
""" List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
|
||||
class attributes (cls_token, unk_token...).
|
||||
"""
|
||||
all_toks = self.all_special_tokens
|
||||
all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks)
|
||||
return all_ids
|
||||
|
||||
|
||||
|
||||
def clean_up_tokenization(out_string):
|
||||
|
||||
Reference in New Issue
Block a user