more explicit variable name
This commit is contained in:
@@ -512,14 +512,14 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
|
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
|
||||||
from_tf=False, *inputs, **kwargs):
|
from_tf=False, *inputs, **kwargs):
|
||||||
"""
|
"""
|
||||||
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||||
Download and cache the pre-trained model file if needed.
|
Download and cache the pre-trained model file if needed.
|
||||||
|
|
||||||
Params:
|
Params:
|
||||||
pretrained_model_name: either:
|
pretrained_model_name_or_path: either:
|
||||||
- a str with the name of a pre-trained model to load selected in the list of:
|
- a str with the name of a pre-trained model to load selected in the list of:
|
||||||
. `bert-base-uncased`
|
. `bert-base-uncased`
|
||||||
. `bert-large-uncased`
|
. `bert-large-uncased`
|
||||||
@@ -540,10 +540,10 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
*inputs, **kwargs: additional input for the specific Bert class
|
*inputs, **kwargs: additional input for the specific Bert class
|
||||||
(ex: num_labels for BertForSequenceClassification)
|
(ex: num_labels for BertForSequenceClassification)
|
||||||
"""
|
"""
|
||||||
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
|
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
|
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
else:
|
else:
|
||||||
archive_file = pretrained_model_name
|
archive_file = pretrained_model_name_or_path
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||||
@@ -552,7 +552,7 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
"Model name '{}' was not found in model name list ({}). "
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
"We assumed '{}' was a path or url but couldn't find any file "
|
"We assumed '{}' was a path or url but couldn't find any file "
|
||||||
"associated to this path or url.".format(
|
"associated to this path or url.".format(
|
||||||
pretrained_model_name,
|
pretrained_model_name_or_path,
|
||||||
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
||||||
archive_file))
|
archive_file))
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -418,14 +418,14 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls, pretrained_model_name, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
|
cls, pretrained_model_name_or_path, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||||
Download and cache the pre-trained model file if needed.
|
Download and cache the pre-trained model file if needed.
|
||||||
|
|
||||||
Params:
|
Params:
|
||||||
pretrained_model_name: either:
|
pretrained_model_name_or_path: either:
|
||||||
- a str with the name of a pre-trained model to load selected in the list of:
|
- a str with the name of a pre-trained model to load selected in the list of:
|
||||||
. `openai-gpt`
|
. `openai-gpt`
|
||||||
- a path or url to a pretrained model archive containing:
|
- a path or url to a pretrained model archive containing:
|
||||||
@@ -440,11 +440,11 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
*inputs, **kwargs: additional input for the specific Bert class
|
*inputs, **kwargs: additional input for the specific Bert class
|
||||||
(ex: num_labels for BertForSequenceClassification)
|
(ex: num_labels for BertForSequenceClassification)
|
||||||
"""
|
"""
|
||||||
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
|
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
|
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
else:
|
else:
|
||||||
archive_file = pretrained_model_name
|
archive_file = pretrained_model_name_or_path
|
||||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
@@ -455,7 +455,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
"Model name '{}' was not found in model name list ({}). "
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||||
"at this path or url.".format(
|
"at this path or url.".format(
|
||||||
pretrained_model_name, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
|
pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
|
||||||
archive_file, config_file
|
archive_file, config_file
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -116,15 +116,15 @@ class BertTokenizer(object):
|
|||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||||
"""
|
"""
|
||||||
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
||||||
Download and cache the pre-trained model file if needed.
|
Download and cache the pre-trained model file if needed.
|
||||||
"""
|
"""
|
||||||
if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
|
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
else:
|
else:
|
||||||
vocab_file = pretrained_model_name
|
vocab_file = pretrained_model_name_or_path
|
||||||
if os.path.isdir(vocab_file):
|
if os.path.isdir(vocab_file):
|
||||||
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
|
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
@@ -135,7 +135,7 @@ class BertTokenizer(object):
|
|||||||
"Model name '{}' was not found in model name list ({}). "
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
"We assumed '{}' was a path or url but couldn't find any file "
|
"We assumed '{}' was a path or url but couldn't find any file "
|
||||||
"associated to this path or url.".format(
|
"associated to this path or url.".format(
|
||||||
pretrained_model_name,
|
pretrained_model_name_or_path,
|
||||||
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||||
vocab_file))
|
vocab_file))
|
||||||
return None
|
return None
|
||||||
@@ -144,10 +144,10 @@ class BertTokenizer(object):
|
|||||||
else:
|
else:
|
||||||
logger.info("loading vocabulary file {} from cache at {}".format(
|
logger.info("loading vocabulary file {} from cache at {}".format(
|
||||||
vocab_file, resolved_vocab_file))
|
vocab_file, resolved_vocab_file))
|
||||||
if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
|
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
|
||||||
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
|
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
|
||||||
# than the number of positional embeddings
|
# than the number of positional embeddings
|
||||||
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
|
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
|
||||||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||||
# Instantiate tokenizer.
|
# Instantiate tokenizer.
|
||||||
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
|
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user