Add utility function for retrieving locally cached models (#8836)

* add get_cached_models function

* add List type to import

* fix code quality

* Update src/transformers/file_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/file_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/file_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/file_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/file_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Charles
2021-01-04 14:53:54 +00:00
committed by GitHub
parent 8eb7f26d5d
commit c581d8af7a

View File

@@ -32,7 +32,7 @@ from dataclasses import fields
from functools import partial, wraps
from hashlib import sha256
from pathlib import Path
from typing import Any, BinaryIO, Dict, Optional, Tuple, Union
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile
@@ -1030,6 +1030,39 @@ def filename_to_url(filename, cache_dir=None):
return url, etag
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
"""
Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape
:obj:`(model_url, etag, size_MB)`. Filenames in :obj:`cache_dir` are use to get the metadata for each model, only
urls ending with `.bin` are added.
Args:
cache_dir (:obj:`Union[str, Path]`, `optional`):
The cache directory to search for models within. Will default to the transformers cache if unset.
Returns:
List[Tuple]: List of tuples each with shape :obj:`(model_url, etag, size_MB)`
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
elif isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cached_models = []
for file in os.listdir(cache_dir):
if file.endswith(".json"):
meta_path = os.path.join(cache_dir, file)
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata["url"]
etag = metadata["etag"]
if url.endswith(".bin"):
size_MB = os.path.getsize(meta_path.strip(".json")) / 1e6
cached_models.append((url, etag, size_MB))
return cached_models
def cached_path(
url_or_filename,
cache_dir=None,