From 27b0f86d36a1ee25dcc70ba602aefa556dc5f0a9 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 26 Jul 2019 17:09:21 +0200 Subject: [PATCH] clean up pretrained --- pytorch_transformers/tokenization_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index f603a29d74..2b3219c4cc 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -152,11 +152,13 @@ class PreTrainedTokenizer(object): @classmethod - def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): """ Instantiate a PreTrainedTokenizer from pre-trained vocabulary files. Download and cache the vocabulary files if needed. """ + cache_dir = kwargs.pop('cache_dir', None) + s3_models = list(cls.max_model_input_sizes.keys()) vocab_files = {} if pretrained_model_name_or_path in s3_models: @@ -308,7 +310,8 @@ class PreTrainedTokenizer(object): to_add_tokens = [] for token in new_tokens: - if self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token): + if token != self.unk_token and \ + self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token): to_add_tokens.append(token) logger.info("Adding %s to the vocabulary", token)