added cache_dir option in from_pretrained
This commit is contained in:
@@ -443,7 +443,7 @@ class PreTrainedBertModel(nn.Module):
|
||||
module.bias.data.zero_()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, *inputs, **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
@@ -468,7 +468,7 @@ class PreTrainedBertModel(nn.Module):
|
||||
archive_file = pretrained_model_name
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file)
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
|
||||
Reference in New Issue
Block a user