From d6f06c03f4658a80bef76ae226494864b476e391 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Nov 2018 14:09:06 +0100 Subject: [PATCH] fixed loading pre-trained tokenizer from directory --- pytorch_pretrained_bert/modeling.py | 2 +- pytorch_pretrained_bert/tokenization.py | 25 ++++++++++++++----------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 70a41d91e7..e8ad26a1c6 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -478,7 +478,7 @@ class PreTrainedBertModel(nn.Module): "associated to this path or url.".format( pretrained_model_name, ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), - pretrained_model_name)) + archive_file)) return None if resolved_archive_file == archive_file: logger.info("loading archive file {}".format(archive_file)) diff --git a/pytorch_pretrained_bert/tokenization.py b/pytorch_pretrained_bert/tokenization.py index fefdaa54a0..c7ef20ddef 100644 --- a/pytorch_pretrained_bert/tokenization.py +++ b/pytorch_pretrained_bert/tokenization.py @@ -39,6 +39,7 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = { 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", } +VOCAB_NAME = 'vocab.txt' def load_vocab(vocab_file): @@ -100,7 +101,7 @@ class BertTokenizer(object): return tokens @classmethod - def from_pretrained(cls, pretrained_model_name, do_lower_case=True): + def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): """ Instantiate a PreTrainedBertModel from a pre-trained model file. Download and cache the pre-trained model file if needed. @@ -109,16 +110,11 @@ class BertTokenizer(object): vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] else: vocab_file = pretrained_model_name + if os.path.isdir(vocab_file): + vocab_file = os.path.join(vocab_file, VOCAB_NAME) # redirect to the cache, if necessary try: - resolved_vocab_file = cached_path(vocab_file) - if resolved_vocab_file == vocab_file: - logger.info("loading vocabulary file {}".format(vocab_file)) - else: - logger.info("loading vocabulary file {} from cache at {}".format( - vocab_file, resolved_vocab_file)) - # Instantiate tokenizer. - tokenizer = cls(resolved_vocab_file, do_lower_case) + resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) except FileNotFoundError: logger.error( "Model name '{}' was not found in model name list ({}). " @@ -126,8 +122,15 @@ class BertTokenizer(object): "associated to this path or url.".format( pretrained_model_name, ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), - pretrained_model_name)) - tokenizer = None + vocab_file)) + return None + if resolved_vocab_file == vocab_file: + logger.info("loading vocabulary file {}".format(vocab_file)) + else: + logger.info("loading vocabulary file {} from cache at {}".format( + vocab_file, resolved_vocab_file)) + # Instantiate tokenizer. + tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) return tokenizer