move pytroch_pretrained_bert cache folder under same path as torch
This commit is contained in:
@@ -22,6 +22,15 @@ import requests
|
||||
from botocore.exceptions import ClientError
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
from torch.hub import _get_torch_home
|
||||
torch_cache_home = _get_torch_home()
|
||||
except ImportError:
|
||||
torch_cache_home = os.path.expanduser(
|
||||
os.getenv('TORCH_HOME', os.path.join(
|
||||
os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
|
||||
default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert')
|
||||
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
except ImportError:
|
||||
@@ -29,11 +38,11 @@ except ImportError:
|
||||
|
||||
try:
|
||||
from pathlib import Path
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
||||
Path.home() / '.pytorch_pretrained_bert'))
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = Path(
|
||||
os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
|
||||
except (AttributeError, ImportError):
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
||||
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
|
||||
default_cache_path)
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
||||
Reference in New Issue
Block a user