more explicit variable name
This commit is contained in:
@@ -512,14 +512,14 @@ class BertPreTrainedModel(nn.Module):
|
||||
module.bias.data.zero_()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
|
||||
from_tf=False, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
|
||||
Params:
|
||||
pretrained_model_name: either:
|
||||
pretrained_model_name_or_path: either:
|
||||
- a str with the name of a pre-trained model to load selected in the list of:
|
||||
. `bert-base-uncased`
|
||||
. `bert-large-uncased`
|
||||
@@ -540,10 +540,10 @@ class BertPreTrainedModel(nn.Module):
|
||||
*inputs, **kwargs: additional input for the specific Bert class
|
||||
(ex: num_labels for BertForSequenceClassification)
|
||||
"""
|
||||
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
|
||||
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
else:
|
||||
archive_file = pretrained_model_name
|
||||
archive_file = pretrained_model_name_or_path
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
@@ -552,7 +552,7 @@ class BertPreTrainedModel(nn.Module):
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find any file "
|
||||
"associated to this path or url.".format(
|
||||
pretrained_model_name,
|
||||
pretrained_model_name_or_path,
|
||||
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
||||
archive_file))
|
||||
return None
|
||||
|
||||
@@ -418,14 +418,14 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, pretrained_model_name, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
|
||||
cls, pretrained_model_name_or_path, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
|
||||
):
|
||||
"""
|
||||
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
|
||||
Params:
|
||||
pretrained_model_name: either:
|
||||
pretrained_model_name_or_path: either:
|
||||
- a str with the name of a pre-trained model to load selected in the list of:
|
||||
. `openai-gpt`
|
||||
- a path or url to a pretrained model archive containing:
|
||||
@@ -440,11 +440,11 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
*inputs, **kwargs: additional input for the specific Bert class
|
||||
(ex: num_labels for BertForSequenceClassification)
|
||||
"""
|
||||
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
|
||||
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
else:
|
||||
archive_file = pretrained_model_name
|
||||
archive_file = pretrained_model_name_or_path
|
||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
@@ -455,7 +455,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||
"at this path or url.".format(
|
||||
pretrained_model_name, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
|
||||
pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
|
||||
archive_file, config_file
|
||||
)
|
||||
)
|
||||
|
||||
@@ -116,15 +116,15 @@ class BertTokenizer(object):
|
||||
return tokens
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
"""
|
||||
if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
else:
|
||||
vocab_file = pretrained_model_name
|
||||
vocab_file = pretrained_model_name_or_path
|
||||
if os.path.isdir(vocab_file):
|
||||
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
|
||||
# redirect to the cache, if necessary
|
||||
@@ -135,7 +135,7 @@ class BertTokenizer(object):
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find any file "
|
||||
"associated to this path or url.".format(
|
||||
pretrained_model_name,
|
||||
pretrained_model_name_or_path,
|
||||
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||
vocab_file))
|
||||
return None
|
||||
@@ -144,10 +144,10 @@ class BertTokenizer(object):
|
||||
else:
|
||||
logger.info("loading vocabulary file {} from cache at {}".format(
|
||||
vocab_file, resolved_vocab_file))
|
||||
if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
|
||||
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
|
||||
# than the number of positional embeddings
|
||||
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
|
||||
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
|
||||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||
# Instantiate tokenizer.
|
||||
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user