Add more tests on tokenizers serialization - fix bugs (#5056)
* update tests for fast tokenizers + fix small bug in saving/loading * better tests on serialization * fixing serialization * comment cleanup
This commit is contained in:
@@ -62,9 +62,12 @@ PreTokenizedInputPair = Tuple[List[str], List[str]]
|
||||
EncodedInputPair = Tuple[List[int], List[int]]
|
||||
|
||||
|
||||
# Slow tokenizers used to be saved in three separate files
|
||||
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
|
||||
ADDED_TOKENS_FILE = "added_tokens.json"
|
||||
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
|
||||
|
||||
# Fast tokenizers (provided by HuggingFace's tokenizers library) can be saved in a single file
|
||||
FULL_TOKENIZER_FILE = "tokenizer.json"
|
||||
|
||||
|
||||
@@ -589,10 +592,14 @@ class SpecialTokensMixin:
|
||||
self._additional_special_tokens = []
|
||||
self.verbose = verbose
|
||||
|
||||
# We directly set the hidden value to allow initialization with special tokens
|
||||
# which are not yet in the vocabulary. Necessary for serialization/de-serialization
|
||||
# TODO clean this up at some point (probably by switching to fast tokenizers)
|
||||
for key, value in kwargs.items():
|
||||
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
|
||||
if key == "additional_special_tokens":
|
||||
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
|
||||
setattr(self, key, value)
|
||||
elif isinstance(value, (str, AddedToken)):
|
||||
setattr(self, key, value)
|
||||
else:
|
||||
@@ -607,7 +614,7 @@ class SpecialTokensMixin:
|
||||
Return:
|
||||
Number of tokens added in the vocabulary during the operation.
|
||||
"""
|
||||
return self.add_tokens(self.all_special_tokens_extended, special_token=True)
|
||||
return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
|
||||
|
||||
def add_special_tokens(self, special_tokens_dict):
|
||||
"""
|
||||
@@ -652,22 +659,56 @@ class SpecialTokensMixin:
|
||||
added_tokens = 0
|
||||
for key, value in special_tokens_dict.items():
|
||||
assert key in self.SPECIAL_TOKENS_ATTRIBUTES
|
||||
|
||||
if self.verbose:
|
||||
logger.info("Assigning %s to the %s key of the tokenizer", value, key)
|
||||
setattr(self, key, value)
|
||||
|
||||
if key == "additional_special_tokens":
|
||||
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
|
||||
added_tokens += self.add_tokens(value, special_token=True)
|
||||
added_tokens += self.add_tokens(value, special_tokens=True)
|
||||
else:
|
||||
assert isinstance(value, str)
|
||||
added_tokens += self.add_tokens([value], special_token=True)
|
||||
added_tokens += self.add_tokens([value], special_tokens=True)
|
||||
|
||||
return added_tokens
|
||||
|
||||
def add_tokens(self, value, special_token=False):
|
||||
""" To be overriden by derived class to add a token in the vocabulary. """
|
||||
pass
|
||||
def add_tokens(self, new_tokens: Union[str, AddedToken, List[str], List[AddedToken]], special_tokens=False) -> int:
|
||||
"""
|
||||
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
|
||||
vocabulary, they are added to it with indices starting from length of the current vocabulary.
|
||||
|
||||
Args:
|
||||
new_tokens: string or list of string or :class:`~transformers.AddedToken`. Each string is a token to add.
|
||||
Tokens are only added if they are not already in the vocabulary. AddedToken wraps a string token to
|
||||
let you personalize its behavior (Whether this token should only match against single word, whether
|
||||
this token should strip all potential whitespaces on the left side, Whether this token should strip
|
||||
all potential whitespaces on the right side...).
|
||||
special_tokens: can be used to specify if the token is a special token. This mostly changes the normalization
|
||||
behavior (special tokens like CLS or [MASK] are usually not lower-cased for instance)
|
||||
|
||||
See details for :class:`~transformers.AddedToken` in HuggingFace tokenizers library.
|
||||
|
||||
Returns:
|
||||
Number of tokens added to the vocabulary.
|
||||
|
||||
Examples::
|
||||
|
||||
# Let's see how to increase the vocabulary of Bert model and tokenizer
|
||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||
model = BertModel.from_pretrained('bert-base-uncased')
|
||||
|
||||
num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
|
||||
print('We have added', num_added_toks, 'tokens')
|
||||
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
||||
"""
|
||||
if not new_tokens:
|
||||
return 0
|
||||
|
||||
if not isinstance(new_tokens, (list, tuple)):
|
||||
new_tokens = [new_tokens]
|
||||
|
||||
return self._add_tokens(new_tokens, special_tokens=special_tokens)
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
@@ -964,11 +1005,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
|
||||
padding_side: str = "right"
|
||||
|
||||
def __init__(self, model_max_length=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
def __init__(self, **kwargs):
|
||||
# inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
|
||||
self.init_inputs = ()
|
||||
self.init_kwargs = kwargs
|
||||
|
||||
# For backward compatibility we fall back to setting model_max_length from max_len if provided
|
||||
model_max_length = model_max_length if model_max_length is not None else kwargs.pop("max_len", None)
|
||||
model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
|
||||
self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER
|
||||
|
||||
# Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
|
||||
@@ -979,9 +1022,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
], f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
|
||||
self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
|
||||
|
||||
# inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
|
||||
self.init_inputs = ()
|
||||
self.init_kwargs = {}
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def max_len(self) -> int:
|
||||
@@ -1125,8 +1166,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
"added_tokens_file": ADDED_TOKENS_FILE,
|
||||
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
|
||||
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
|
||||
"full_tokenizer_file": FULL_TOKENIZER_FILE,
|
||||
}
|
||||
# Look for the tokenizer main vocabulary files + the additional tokens files
|
||||
# Look for the tokenizer files
|
||||
for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
|
||||
@@ -1215,18 +1257,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
|
||||
# Merge resolved_vocab_files arguments in init_kwargs.
|
||||
added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
|
||||
special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
|
||||
for args_name, file_path in resolved_vocab_files.items():
|
||||
if args_name not in init_kwargs:
|
||||
init_kwargs[args_name] = file_path
|
||||
if special_tokens_map_file is not None:
|
||||
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
|
||||
special_tokens_map = json.load(special_tokens_map_handle)
|
||||
for key, value in special_tokens_map.items():
|
||||
if isinstance(value, dict):
|
||||
value = AddedToken(**value)
|
||||
if key not in init_kwargs:
|
||||
init_kwargs[key] = value
|
||||
|
||||
# Instantiate tokenizer.
|
||||
try:
|
||||
@@ -1241,20 +1274,39 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
tokenizer.init_inputs = init_inputs
|
||||
tokenizer.init_kwargs = init_kwargs
|
||||
|
||||
# update unique_added_tokens_encoder with special tokens for correct tokenization
|
||||
if hasattr(tokenizer, "unique_added_tokens_encoder"):
|
||||
union = set(tokenizer.unique_added_tokens_encoder).union(tokenizer.all_special_tokens)
|
||||
tokenizer.unique_added_tokens_encoder = list(union)
|
||||
# If there is a complementary special token map, load it
|
||||
special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
|
||||
if special_tokens_map_file is not None:
|
||||
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
|
||||
special_tokens_map = json.load(special_tokens_map_handle)
|
||||
|
||||
for key, value in special_tokens_map.items():
|
||||
if isinstance(value, dict):
|
||||
value = AddedToken(**value)
|
||||
setattr(tokenizer, key, value)
|
||||
|
||||
# Add supplementary tokens.
|
||||
special_tokens = tokenizer.all_special_tokens
|
||||
if added_tokens_file is not None:
|
||||
with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
|
||||
added_tok_encoder = json.load(added_tokens_handle)
|
||||
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
|
||||
tokenizer.added_tokens_encoder.update(added_tok_encoder)
|
||||
tokenizer.added_tokens_decoder.update(added_tok_decoder)
|
||||
union = set(tokenizer.unique_added_tokens_encoder).union(tokenizer.added_tokens_encoder.keys())
|
||||
tokenizer.unique_added_tokens_encoder = list(union)
|
||||
|
||||
# Sort added tokens by index
|
||||
added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1]))
|
||||
|
||||
for token, index in added_tok_encoder_sorted:
|
||||
assert index == len(tokenizer), (
|
||||
f"Non-consecutive added token '{token}' found. "
|
||||
f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary."
|
||||
)
|
||||
tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens))
|
||||
|
||||
# Check all our special tokens are registered as "no split" tokens (we don't cut them) and are in the vocab
|
||||
added_tokens = tokenizer.sanitize_special_tokens()
|
||||
if added_tokens:
|
||||
logger.warning(
|
||||
"Special tokens have been added in the vocabulary, make sure the associated word emebedding are fine-tuned or trained."
|
||||
)
|
||||
|
||||
return tokenizer
|
||||
|
||||
@@ -1296,9 +1348,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
write_dict[key] = value
|
||||
f.write(json.dumps(write_dict, ensure_ascii=False))
|
||||
|
||||
if hasattr(self, "added_tokens_encoder") and len(self.added_tokens_encoder) > 0:
|
||||
added_vocab = self.get_added_vocab()
|
||||
if added_vocab:
|
||||
with open(added_tokens_file, "w", encoding="utf-8") as f:
|
||||
out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
|
||||
out_str = json.dumps(added_vocab, ensure_ascii=False)
|
||||
f.write(out_str)
|
||||
|
||||
vocab_files = self.save_vocabulary(save_directory)
|
||||
|
||||
Reference in New Issue
Block a user