added cache_dir option in from_pretrained
This commit is contained in:
@@ -443,7 +443,7 @@ class PreTrainedBertModel(nn.Module):
|
|||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name, *inputs, **kwargs):
|
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
|
||||||
"""
|
"""
|
||||||
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
||||||
Download and cache the pre-trained model file if needed.
|
Download and cache the pre-trained model file if needed.
|
||||||
@@ -468,7 +468,7 @@ class PreTrainedBertModel(nn.Module):
|
|||||||
archive_file = pretrained_model_name
|
archive_file = pretrained_model_name
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_archive_file = cached_path(archive_file)
|
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Model name '{}' was not found in model name list ({}). "
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
|
|||||||
Reference in New Issue
Block a user