more explicit variable name

This commit is contained in:
thomwolf
2019-02-08 09:54:49 +01:00
parent 6bc082da0a
commit edcb56fd96
3 changed files with 19 additions and 19 deletions

View File

@@ -512,14 +512,14 @@ class BertPreTrainedModel(nn.Module):
module.bias.data.zero_()
@classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
from_tf=False, *inputs, **kwargs):
"""
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `bert-base-uncased`
. `bert-large-uncased`
@@ -540,10 +540,10 @@ class BertPreTrainedModel(nn.Module):
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
archive_file = pretrained_model_name
archive_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
@@ -552,7 +552,7 @@ class BertPreTrainedModel(nn.Module):
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
pretrained_model_name_or_path,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
archive_file))
return None