From 61d22f9cc7fbb0556d1e14447957a7ac6441421b Mon Sep 17 00:00:00 2001 From: Bram Vanroy Date: Mon, 11 May 2020 21:24:02 +0200 Subject: [PATCH] Simplify cache vars and allow for TRANSFORMERS_CACHE env (#4226) * simplify cache vars and allow for TRANSFORMERS_CACHE env As it currently stands, "TRANSFORMERS_CACHE" is not an accepted variable. It seems that the these variables were not updated when moving from version pytorch_transformers to transformers. In addition, the fallback procedure could be improved. and simplified. Pathlib seems redundant here. * Update file_utils.py --- src/transformers/file_utils.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index cacd15a369..d5abb77aa8 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -15,6 +15,7 @@ import tempfile from contextlib import contextmanager from functools import partial, wraps from hashlib import sha256 +from pathlib import Path from typing import Optional from urllib.parse import urlparse from zipfile import ZipFile, is_zipfile @@ -68,19 +69,10 @@ except ImportError: ) default_cache_path = os.path.join(torch_cache_home, "transformers") -try: - from pathlib import Path - PYTORCH_PRETRAINED_BERT_CACHE = Path( - os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)) - ) -except (AttributeError, ImportError): - PYTORCH_PRETRAINED_BERT_CACHE = os.getenv( - "PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) - ) - -PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility -TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility +PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) +PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE) +TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE) WEIGHTS_NAME = "pytorch_model.bin" TF2_WEIGHTS_NAME = "tf_model.h5"