clean up pretrained
This commit is contained in:
@@ -152,11 +152,13 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||||
"""
|
"""
|
||||||
Instantiate a PreTrainedTokenizer from pre-trained vocabulary files.
|
Instantiate a PreTrainedTokenizer from pre-trained vocabulary files.
|
||||||
Download and cache the vocabulary files if needed.
|
Download and cache the vocabulary files if needed.
|
||||||
"""
|
"""
|
||||||
|
cache_dir = kwargs.pop('cache_dir', None)
|
||||||
|
|
||||||
s3_models = list(cls.max_model_input_sizes.keys())
|
s3_models = list(cls.max_model_input_sizes.keys())
|
||||||
vocab_files = {}
|
vocab_files = {}
|
||||||
if pretrained_model_name_or_path in s3_models:
|
if pretrained_model_name_or_path in s3_models:
|
||||||
@@ -308,7 +310,8 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
to_add_tokens = []
|
to_add_tokens = []
|
||||||
for token in new_tokens:
|
for token in new_tokens:
|
||||||
if self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
|
if token != self.unk_token and \
|
||||||
|
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
|
||||||
to_add_tokens.append(token)
|
to_add_tokens.append(token)
|
||||||
logger.info("Adding %s to the vocabulary", token)
|
logger.info("Adding %s to the vocabulary", token)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user