clean up pretrained

This commit is contained in:
thomwolf
2019-07-26 17:09:21 +02:00
parent 57e54ec070
commit 27b0f86d36

View File

@@ -152,11 +152,13 @@ class PreTrainedTokenizer(object):
@classmethod @classmethod
def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
""" """
Instantiate a PreTrainedTokenizer from pre-trained vocabulary files. Instantiate a PreTrainedTokenizer from pre-trained vocabulary files.
Download and cache the vocabulary files if needed. Download and cache the vocabulary files if needed.
""" """
cache_dir = kwargs.pop('cache_dir', None)
s3_models = list(cls.max_model_input_sizes.keys()) s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {} vocab_files = {}
if pretrained_model_name_or_path in s3_models: if pretrained_model_name_or_path in s3_models:
@@ -308,7 +310,8 @@ class PreTrainedTokenizer(object):
to_add_tokens = [] to_add_tokens = []
for token in new_tokens: for token in new_tokens:
if self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token): if token != self.unk_token and \
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
to_add_tokens.append(token) to_add_tokens.append(token)
logger.info("Adding %s to the vocabulary", token) logger.info("Adding %s to the vocabulary", token)