back compatibility with Path inputs in fle_utils

This commit is contained in:
thomwolf
2019-02-09 17:05:23 +01:00
parent 9f9909ea2f
commit f0bf81e141

View File

@@ -12,6 +12,7 @@ import shutil
import tempfile import tempfile
from functools import wraps from functools import wraps
from hashlib import sha256 from hashlib import sha256
import sys
from io import open from io import open
import boto3 import boto3
@@ -60,6 +61,8 @@ def filename_to_url(filename, cache_dir=None):
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename) cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path): if not os.path.exists(cache_path):
@@ -86,6 +89,10 @@ def cached_path(url_or_filename, cache_dir=None):
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename) parsed = urlparse(url_or_filename)
@@ -171,6 +178,8 @@ def get_from_cache(url, cache_dir=None):
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if not os.path.exists(cache_dir): if not os.path.exists(cache_dir):
os.makedirs(cache_dir) os.makedirs(cache_dir)