diff --git a/pytorch_pretrained_bert/file_utils.py b/pytorch_pretrained_bert/file_utils.py index 0b5fc2c217..6954bec0e1 100644 --- a/pytorch_pretrained_bert/file_utils.py +++ b/pytorch_pretrained_bert/file_utils.py @@ -12,6 +12,7 @@ import shutil import tempfile from functools import wraps from hashlib import sha256 +import sys from io import open import boto3 @@ -60,6 +61,8 @@ def filename_to_url(filename, cache_dir=None): """ if cache_dir is None: cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) cache_path = os.path.join(cache_dir, filename) if not os.path.exists(cache_path): @@ -86,6 +89,10 @@ def cached_path(url_or_filename, cache_dir=None): """ if cache_dir is None: cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) parsed = urlparse(url_or_filename) @@ -171,6 +178,8 @@ def get_from_cache(url, cache_dir=None): """ if cache_dir is None: cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) if not os.path.exists(cache_dir): os.makedirs(cache_dir)