Add pathlib support to file_utils.py on Python 3.5

This commit is contained in:
hzhwcmhf
2018-12-11 22:49:19 +08:00
parent bc659f86ad
commit 485adde742

View File

@@ -23,8 +23,8 @@ import requests
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
PYTORCH_PRETRAINED_BERT_CACHE = str(Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
Path.home() / '.pytorch_pretrained_bert'))) Path.home() / '.pytorch_pretrained_bert'))
def url_to_filename(url: str, etag: str = None) -> str: def url_to_filename(url: str, etag: str = None) -> str:
@@ -45,13 +45,15 @@ def url_to_filename(url: str, etag: str = None) -> str:
return filename return filename
def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]: def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
""" """
Return the url and etag (which may be ``None``) stored for `filename`. Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename) cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path): if not os.path.exists(cache_path):
@@ -69,7 +71,7 @@ def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
return url, etag return url, etag
def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str: def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
""" """
Given something that might be a URL (or might be a local path), Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and determine which. If it's a URL, download the file and cache it, and
@@ -80,6 +82,8 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(url_or_filename, Path): if isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename) url_or_filename = str(url_or_filename)
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename) parsed = urlparse(url_or_filename)
@@ -158,13 +162,15 @@ def http_get(url: str, temp_file: IO) -> None:
progress.close() progress.close()
def get_from_cache(url: str, cache_dir: str = None) -> str: def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
""" """
Given a URL, look for the corresponding dataset in the local cache. Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file. If it's not there, download it. Then return the path to the cached file.
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
os.makedirs(cache_dir, exist_ok=True) os.makedirs(cache_dir, exist_ok=True)