Simplify cache vars and allow for TRANSFORMERS_CACHE env (#4226)
* simplify cache vars and allow for TRANSFORMERS_CACHE env As it currently stands, "TRANSFORMERS_CACHE" is not an accepted variable. It seems that the these variables were not updated when moving from version pytorch_transformers to transformers. In addition, the fallback procedure could be improved. and simplified. Pathlib seems redundant here. * Update file_utils.py
This commit is contained in:
@@ -15,6 +15,7 @@ import tempfile
|
||||
from contextlib import contextmanager
|
||||
from functools import partial, wraps
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
@@ -68,19 +69,10 @@ except ImportError:
|
||||
)
|
||||
default_cache_path = os.path.join(torch_cache_home, "transformers")
|
||||
|
||||
try:
|
||||
from pathlib import Path
|
||||
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = Path(
|
||||
os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
|
||||
)
|
||||
except (AttributeError, ImportError):
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
|
||||
"PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
||||
)
|
||||
|
||||
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
||||
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
||||
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
|
||||
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
||||
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
TF2_WEIGHTS_NAME = "tf_model.h5"
|
||||
|
||||
Reference in New Issue
Block a user