rm boto3 dependency
This commit is contained in:
@@ -19,10 +19,7 @@ from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
|
||||
import boto3
|
||||
import requests
|
||||
from botocore.config import Config
|
||||
from botocore.exceptions import ClientError
|
||||
from filelock import FileLock
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
@@ -144,7 +141,7 @@ def add_end_docstrings(*docstr):
|
||||
|
||||
def is_remote_url(url_or_filename):
|
||||
parsed = urlparse(url_or_filename)
|
||||
return parsed.scheme in ("http", "https", "s3")
|
||||
return parsed.scheme in ("http", "https")
|
||||
|
||||
|
||||
def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
|
||||
@@ -297,55 +294,6 @@ def cached_path(
|
||||
return output_path
|
||||
|
||||
|
||||
def split_s3_path(url):
|
||||
"""Split a full s3 path into the bucket name and path."""
|
||||
parsed = urlparse(url)
|
||||
if not parsed.netloc or not parsed.path:
|
||||
raise ValueError("bad s3 path {}".format(url))
|
||||
bucket_name = parsed.netloc
|
||||
s3_path = parsed.path
|
||||
# Remove '/' at beginning of path.
|
||||
if s3_path.startswith("/"):
|
||||
s3_path = s3_path[1:]
|
||||
return bucket_name, s3_path
|
||||
|
||||
|
||||
def s3_request(func):
|
||||
"""
|
||||
Wrapper function for s3 requests in order to create more helpful error
|
||||
messages.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(url, *args, **kwargs):
|
||||
try:
|
||||
return func(url, *args, **kwargs)
|
||||
except ClientError as exc:
|
||||
if int(exc.response["Error"]["Code"]) == 404:
|
||||
raise EnvironmentError("file {} not found".format(url))
|
||||
else:
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@s3_request
|
||||
def s3_etag(url, proxies=None):
|
||||
"""Check ETag on S3 object."""
|
||||
s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
|
||||
bucket_name, s3_path = split_s3_path(url)
|
||||
s3_object = s3_resource.Object(bucket_name, s3_path)
|
||||
return s3_object.e_tag
|
||||
|
||||
|
||||
@s3_request
|
||||
def s3_get(url, temp_file, proxies=None):
|
||||
"""Pull a file directly from S3."""
|
||||
s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
|
||||
bucket_name, s3_path = split_s3_path(url)
|
||||
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
|
||||
|
||||
|
||||
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
|
||||
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
|
||||
if is_torch_available():
|
||||
@@ -406,17 +354,13 @@ def get_from_cache(
|
||||
|
||||
etag = None
|
||||
if not local_files_only:
|
||||
# Get eTag to add to filename, if it exists.
|
||||
if url.startswith("s3://"):
|
||||
etag = s3_etag(url, proxies=proxies)
|
||||
else:
|
||||
try:
|
||||
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
|
||||
if response.status_code == 200:
|
||||
etag = response.headers.get("ETag")
|
||||
except (EnvironmentError, requests.exceptions.Timeout):
|
||||
# etag is already None
|
||||
pass
|
||||
try:
|
||||
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
|
||||
if response.status_code == 200:
|
||||
etag = response.headers.get("ETag")
|
||||
except (EnvironmentError, requests.exceptions.Timeout):
|
||||
# etag is already None
|
||||
pass
|
||||
|
||||
filename = url_to_filename(url, etag)
|
||||
|
||||
@@ -483,13 +427,7 @@ def get_from_cache(
|
||||
with temp_file_manager() as temp_file:
|
||||
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||
|
||||
# GET file object
|
||||
if url.startswith("s3://"):
|
||||
if resume_download:
|
||||
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
|
||||
s3_get(url, temp_file, proxies=proxies)
|
||||
else:
|
||||
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
||||
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
||||
|
||||
logger.info("storing %s in cache at %s", url, cache_path)
|
||||
os.replace(temp_file.name, cache_path)
|
||||
|
||||
Reference in New Issue
Block a user