cleaning up tokenizer tests structure (at last) - last remaining ppb refs

This commit is contained in:
thomwolf
2019-08-05 14:08:56 +02:00
parent 00132b7a7a
commit 328afb7097
16 changed files with 332 additions and 233 deletions

View File

@@ -125,42 +125,34 @@ class PreTrainedTokenizer(object):
@bos_token.setter
def bos_token(self, value):
self.add_tokens([value])
self._bos_token = value
@eos_token.setter
def eos_token(self, value):
self.add_tokens([value])
self._eos_token = value
@unk_token.setter
def unk_token(self, value):
self.add_tokens([value])
self._unk_token = value
@sep_token.setter
def sep_token(self, value):
self.add_tokens([value])
self._sep_token = value
@pad_token.setter
def pad_token(self, value):
self.add_tokens([value])
self._pad_token = value
@cls_token.setter
def cls_token(self, value):
self.add_tokens([value])
self._cls_token = value
@mask_token.setter
def mask_token(self, value):
self.add_tokens([value])
self._mask_token = value
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self.add_tokens(value)
self._additional_special_tokens = value
def __init__(self, max_len=None, **kwargs):
@@ -179,6 +171,10 @@ class PreTrainedTokenizer(object):
for key, value in kwargs.items():
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == 'additional_special_tokens':
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
else:
assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
setattr(self, key, value)
@@ -415,15 +411,39 @@ class PreTrainedTokenizer(object):
Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).
Returns:
Number of tokens added to the vocabulary.
Examples::
# Let's see how to add a new classification token to GPT-2
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
special_tokens_dict = {'cls_token': '<CLS>'}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
print('We have added', num_added_toks, 'tokens')
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
assert tokenizer.cls_token == '<CLS>'
"""
if not special_tokens_dict:
return 0
added_tokens = 0
for key, value in special_tokens_dict.items():
assert key in self.SPECIAL_TOKENS_ATTRIBUTES
if key == 'additional_special_tokens':
assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
added_tokens += self.add_tokens(value)
else:
assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
added_tokens += self.add_tokens([value])
logger.info("Assigning %s to the %s key of the tokenizer", value, key)
setattr(self, key, value)
return added_tokens
def tokenize(self, text, **kwargs):
""" Converts a string in a sequence of tokens (string), using the tokenizer.