Simplify cache vars and allow for TRANSFORMERS_CACHE env (#4226)
* simplify cache vars and allow for TRANSFORMERS_CACHE env As it currently stands, "TRANSFORMERS_CACHE" is not an accepted variable. It seems that the these variables were not updated when moving from version pytorch_transformers to transformers. In addition, the fallback procedure could be improved. and simplified. Pathlib seems redundant here. * Update file_utils.py
This commit is contained in:
@@ -15,6 +15,7 @@ import tempfile
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from zipfile import ZipFile, is_zipfile
|
from zipfile import ZipFile, is_zipfile
|
||||||
@@ -68,19 +69,10 @@ except ImportError:
|
|||||||
)
|
)
|
||||||
default_cache_path = os.path.join(torch_cache_home, "transformers")
|
default_cache_path = os.path.join(torch_cache_home, "transformers")
|
||||||
|
|
||||||
try:
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
PYTORCH_PRETRAINED_BERT_CACHE = Path(
|
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
||||||
os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
|
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
|
||||||
)
|
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
||||||
except (AttributeError, ImportError):
|
|
||||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
|
|
||||||
"PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
|
||||||
)
|
|
||||||
|
|
||||||
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
|
||||||
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
|
||||||
|
|
||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
TF2_WEIGHTS_NAME = "tf_model.h5"
|
TF2_WEIGHTS_NAME = "tf_model.h5"
|
||||||
|
|||||||
Reference in New Issue
Block a user