From 63ae5d2134d6c66e16affaac0983a99d3c073e41 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 26 Nov 2018 10:21:56 +0100 Subject: [PATCH] added cache_dir option in from_pretrained --- pytorch_pretrained_bert/modeling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index c46ba13b01..2d6dfa531d 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -443,7 +443,7 @@ class PreTrainedBertModel(nn.Module): module.bias.data.zero_() @classmethod - def from_pretrained(cls, pretrained_model_name, *inputs, **kwargs): + def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): """ Instantiate a PreTrainedBertModel from a pre-trained model file. Download and cache the pre-trained model file if needed. @@ -468,7 +468,7 @@ class PreTrainedBertModel(nn.Module): archive_file = pretrained_model_name # redirect to the cache, if necessary try: - resolved_archive_file = cached_path(archive_file) + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) except FileNotFoundError: logger.error( "Model name '{}' was not found in model name list ({}). "