Simplify cache vars and allow for TRANSFORMERS_CACHE env (#4226)

* simplify cache vars and allow for TRANSFORMERS_CACHE env

As it currently stands, "TRANSFORMERS_CACHE" is not an accepted variable. It seems that the these variables were not updated when moving from version pytorch_transformers to transformers. In addition, the fallback procedure could be improved. and simplified. Pathlib seems redundant here.

* Update file_utils.py
This commit is contained in:
Bram Vanroy
2020-05-11 21:24:02 +02:00
committed by GitHub
parent cd40cb8879
commit 61d22f9cc7

View File

@@ -15,6 +15,7 @@ import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial, wraps from functools import partial, wraps
from hashlib import sha256 from hashlib import sha256
from pathlib import Path
from typing import Optional from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile from zipfile import ZipFile, is_zipfile
@@ -68,19 +69,10 @@ except ImportError:
) )
default_cache_path = os.path.join(torch_cache_home, "transformers") default_cache_path = os.path.join(torch_cache_home, "transformers")
try:
from pathlib import Path
PYTORCH_PRETRAINED_BERT_CACHE = Path( PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)) PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
) TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
except (AttributeError, ImportError):
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
"PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
)
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
WEIGHTS_NAME = "pytorch_model.bin" WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5" TF2_WEIGHTS_NAME = "tf_model.h5"