From edcb56fd96958984fc4bce93e931d5bac61d41c4 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 8 Feb 2019 09:54:49 +0100 Subject: [PATCH] more explicit variable name --- pytorch_pretrained_bert/modeling.py | 12 ++++++------ pytorch_pretrained_bert/modeling_openai.py | 12 ++++++------ pytorch_pretrained_bert/tokenization.py | 14 +++++++------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 0d68c2691c..e630765782 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -512,14 +512,14 @@ class BertPreTrainedModel(nn.Module): module.bias.data.zero_() @classmethod - def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, + def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs): """ Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. Params: - pretrained_model_name: either: + pretrained_model_name_or_path: either: - a str with the name of a pre-trained model to load selected in the list of: . `bert-base-uncased` . `bert-large-uncased` @@ -540,10 +540,10 @@ class BertPreTrainedModel(nn.Module): *inputs, **kwargs: additional input for the specific Bert class (ex: num_labels for BertForSequenceClassification) """ - if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: - archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] + if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: + archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] else: - archive_file = pretrained_model_name + archive_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) @@ -552,7 +552,7 @@ class BertPreTrainedModel(nn.Module): "Model name '{}' was not found in model name list ({}). " "We assumed '{}' was a path or url but couldn't find any file " "associated to this path or url.".format( - pretrained_model_name, + pretrained_model_name_or_path, ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), archive_file)) return None diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index 7e4cd63bba..ac2fb03910 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -418,14 +418,14 @@ class OpenAIGPTPreTrainedModel(nn.Module): @classmethod def from_pretrained( - cls, pretrained_model_name, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs + cls, pretrained_model_name_or_path, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs ): """ Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. Params: - pretrained_model_name: either: + pretrained_model_name_or_path: either: - a str with the name of a pre-trained model to load selected in the list of: . `openai-gpt` - a path or url to a pretrained model archive containing: @@ -440,11 +440,11 @@ class OpenAIGPTPreTrainedModel(nn.Module): *inputs, **kwargs: additional input for the specific Bert class (ex: num_labels for BertForSequenceClassification) """ - if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: - archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] + if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: + archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path] else: - archive_file = pretrained_model_name + archive_file = pretrained_model_name_or_path config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) # redirect to the cache, if necessary try: @@ -455,7 +455,7 @@ class OpenAIGPTPreTrainedModel(nn.Module): "Model name '{}' was not found in model name list ({}). " "We assumed '{}' was a path or url but couldn't find files {} and {} " "at this path or url.".format( - pretrained_model_name, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path, + pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path, archive_file, config_file ) ) diff --git a/pytorch_pretrained_bert/tokenization.py b/pytorch_pretrained_bert/tokenization.py index 65c1b56d38..d64d3512f3 100644 --- a/pytorch_pretrained_bert/tokenization.py +++ b/pytorch_pretrained_bert/tokenization.py @@ -116,15 +116,15 @@ class BertTokenizer(object): return tokens @classmethod - def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): """ Instantiate a PreTrainedBertModel from a pre-trained model file. Download and cache the pre-trained model file if needed. """ - if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: - vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] else: - vocab_file = pretrained_model_name + vocab_file = pretrained_model_name_or_path if os.path.isdir(vocab_file): vocab_file = os.path.join(vocab_file, VOCAB_NAME) # redirect to the cache, if necessary @@ -135,7 +135,7 @@ class BertTokenizer(object): "Model name '{}' was not found in model name list ({}). " "We assumed '{}' was a path or url but couldn't find any file " "associated to this path or url.".format( - pretrained_model_name, + pretrained_model_name_or_path, ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), vocab_file)) return None @@ -144,10 +144,10 @@ class BertTokenizer(object): else: logger.info("loading vocabulary file {} from cache at {}".format( vocab_file, resolved_vocab_file)) - if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: # if we're using a pretrained model, ensure the tokenizer wont index sequences longer # than the number of positional embeddings - max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name] + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) # Instantiate tokenizer. tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)