move pytroch_pretrained_bert cache folder under same path as torch
This commit is contained in:
@@ -84,7 +84,7 @@ def bertTokenizer(*args, **kwargs):
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> sentence = 'Hello, World!'
|
>>> sentence = 'Hello, World!'
|
||||||
>>> tokenizer = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
|
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT:hubconf', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
|
||||||
>>> toks = tokenizer.tokenize(sentence)
|
>>> toks = tokenizer.tokenize(sentence)
|
||||||
['Hello', '##,', 'World', '##!']
|
['Hello', '##,', 'World', '##!']
|
||||||
>>> ids = tokenizer.convert_tokens_to_ids(toks)
|
>>> ids = tokenizer.convert_tokens_to_ids(toks)
|
||||||
|
|||||||
@@ -22,6 +22,15 @@ import requests
|
|||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torch.hub import _get_torch_home
|
||||||
|
torch_cache_home = _get_torch_home()
|
||||||
|
except ImportError:
|
||||||
|
torch_cache_home = os.path.expanduser(
|
||||||
|
os.getenv('TORCH_HOME', os.path.join(
|
||||||
|
os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
|
||||||
|
default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -29,11 +38,11 @@ except ImportError:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
PYTORCH_PRETRAINED_BERT_CACHE = Path(
|
||||||
Path.home() / '.pytorch_pretrained_bert'))
|
os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
|
||||||
except (AttributeError, ImportError):
|
except (AttributeError, ImportError):
|
||||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
||||||
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
|
default_cache_path)
|
||||||
|
|
||||||
CONFIG_NAME = "config.json"
|
CONFIG_NAME = "config.json"
|
||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
|
|||||||
Reference in New Issue
Block a user