From 3963d57c89a0a49c5a2e9dd17958ba29934873e0 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Sat, 27 Apr 2019 10:52:53 -0700 Subject: [PATCH] move pytorch_pretrained_bert cache folder under same path as torch --- hubconf.py | 2 +- pytorch_pretrained_bert/file_utils.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/hubconf.py b/hubconf.py index 755e181d20..193c018ee0 100644 --- a/hubconf.py +++ b/hubconf.py @@ -84,7 +84,7 @@ def bertTokenizer(*args, **kwargs): Example: >>> sentence = 'Hello, World!' - >>> tokenizer = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False) + >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT:hubconf', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False) >>> toks = tokenizer.tokenize(sentence) ['Hello', '##,', 'World', '##!'] >>> ids = tokenizer.convert_tokens_to_ids(toks) diff --git a/pytorch_pretrained_bert/file_utils.py b/pytorch_pretrained_bert/file_utils.py index 17bdd258ea..605c841235 100644 --- a/pytorch_pretrained_bert/file_utils.py +++ b/pytorch_pretrained_bert/file_utils.py @@ -22,6 +22,15 @@ import requests from botocore.exceptions import ClientError from tqdm import tqdm +try: + from torch.hub import _get_torch_home + torch_cache_home = _get_torch_home() +except ImportError: + torch_cache_home = os.path.expanduser( + os.getenv('TORCH_HOME', os.path.join( + os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) +default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert') + try: from urllib.parse import urlparse except ImportError: @@ -29,11 +38,11 @@ except ImportError: try: from pathlib import Path - PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', - Path.home() / '.pytorch_pretrained_bert')) + PYTORCH_PRETRAINED_BERT_CACHE = Path( + os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)) except 
(AttributeError, ImportError): PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', - os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) + default_cache_path) CONFIG_NAME = "config.json" WEIGHTS_NAME = "pytorch_model.bin"