python 2 compatibility

This commit is contained in:
thomwolf
2019-02-06 00:07:46 +01:00
parent ba37ddc5ce
commit 448937c00d
17 changed files with 246 additions and 184 deletions

View File

@@ -3,31 +3,39 @@ Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)
import os
import json
import logging
import os
import shutil
import tempfile
import json
from urllib.parse import urlparse
from pathlib import Path
from typing import Optional, Tuple, Union, IO, Callable, Set
from hashlib import sha256
from functools import wraps
from tqdm import tqdm
from hashlib import sha256
from io import open
import boto3
from botocore.exceptions import ClientError
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
try:
from pathlib import Path
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
Path.home() / '.pytorch_pretrained_bert'))
except ImportError:
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
Path.home() / '.pytorch_pretrained_bert'))
def url_to_filename(url: str, etag: str = None) -> str:
def url_to_filename(url, etag=None):
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
@@ -45,25 +53,23 @@ def url_to_filename(url: str, etag: str = None) -> str:
return filename
def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
def filename_to_url(filename, cache_dir=None):
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise FileNotFoundError("file {} not found".format(cache_path))
raise EnvironmentError("file {} not found".format(cache_path))
meta_path = cache_path + '.json'
if not os.path.exists(meta_path):
raise FileNotFoundError("file {} not found".format(meta_path))
raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path) as meta_file:
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata['url']
etag = metadata['etag']
@@ -71,7 +77,7 @@ def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[
return url, etag
def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
def cached_path(url_or_filename, cache_dir=None):
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
@@ -80,10 +86,6 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] =
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename)
@@ -95,13 +97,13 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] =
return url_or_filename
elif parsed.scheme == '':
# File, but it doesn't exist.
raise FileNotFoundError("file {} not found".format(url_or_filename))
raise EnvironmentError("file {} not found".format(url_or_filename))
else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url: str) -> Tuple[str, str]:
def split_s3_path(url):
"""Split a full s3 path into the bucket name and path."""
parsed = urlparse(url)
if not parsed.netloc or not parsed.path:
@@ -114,19 +116,19 @@ def split_s3_path(url: str) -> Tuple[str, str]:
return bucket_name, s3_path
def s3_request(func: Callable):
def s3_request(func):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
"""
@wraps(func)
def wrapper(url: str, *args, **kwargs):
def wrapper(url, *args, **kwargs):
try:
return func(url, *args, **kwargs)
except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404:
raise FileNotFoundError("file {} not found".format(url))
raise EnvironmentError("file {} not found".format(url))
else:
raise
@@ -134,7 +136,7 @@ def s3_request(func: Callable):
@s3_request
def s3_etag(url: str) -> Optional[str]:
def s3_etag(url):
"""Check ETag on S3 object."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
@@ -143,14 +145,14 @@ def s3_etag(url: str) -> Optional[str]:
@s3_request
def s3_get(url: str, temp_file: IO) -> None:
def s3_get(url, temp_file):
"""Pull a file directly from S3."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url: str, temp_file: IO) -> None:
def http_get(url, temp_file):
req = requests.get(url, stream=True)
content_length = req.headers.get('Content-Length')
total = int(content_length) if content_length is not None else None
@@ -162,17 +164,16 @@ def http_get(url: str, temp_file: IO) -> None:
progress.close()
def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
def get_from_cache(url, cache_dir=None):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
@@ -213,7 +214,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w') as meta_file:
with open(meta_path, 'w', encoding="utf-8") as meta_file:
json.dump(meta, meta_file)
logger.info("removing temp file %s", temp_file.name)
@@ -221,7 +222,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
return cache_path
def read_set_from_file(filename: str) -> Set[str]:
def read_set_from_file(filename):
'''
Extract a de-duped collection (set) of text from a file.
Expected file format is one item per line.
@@ -233,7 +234,7 @@ def read_set_from_file(filename: str) -> Set[str]:
return collection
def get_file_extension(path: str, dot=True, lower: bool = True):
def get_file_extension(path, dot=True, lower=True):
ext = os.path.splitext(path)[1]
ext = ext if dot else ext[1:]
return ext.lower() if lower else ext