Add utility function for retrieving locally cached models (#8836)
* add get_cached_models function * add List type to import * fix code quality * Update src/transformers/file_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/file_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/file_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/file_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/file_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Fix style Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -32,7 +32,7 @@ from dataclasses import fields
|
||||
from functools import partial, wraps
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import Any, BinaryIO, Dict, Optional, Tuple, Union
|
||||
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
|
||||
@@ -1030,6 +1030,39 @@ def filename_to_url(filename, cache_dir=None):
|
||||
return url, etag
|
||||
|
||||
|
||||
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
|
||||
"""
|
||||
Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape
|
||||
:obj:`(model_url, etag, size_MB)`. Filenames in :obj:`cache_dir` are use to get the metadata for each model, only
|
||||
urls ending with `.bin` are added.
|
||||
|
||||
Args:
|
||||
cache_dir (:obj:`Union[str, Path]`, `optional`):
|
||||
The cache directory to search for models within. Will default to the transformers cache if unset.
|
||||
|
||||
Returns:
|
||||
List[Tuple]: List of tuples each with shape :obj:`(model_url, etag, size_MB)`
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
elif isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
|
||||
cached_models = []
|
||||
for file in os.listdir(cache_dir):
|
||||
if file.endswith(".json"):
|
||||
meta_path = os.path.join(cache_dir, file)
|
||||
with open(meta_path, encoding="utf-8") as meta_file:
|
||||
metadata = json.load(meta_file)
|
||||
url = metadata["url"]
|
||||
etag = metadata["etag"]
|
||||
if url.endswith(".bin"):
|
||||
size_MB = os.path.getsize(meta_path.strip(".json")) / 1e6
|
||||
cached_models.append((url, etag, size_MB))
|
||||
|
||||
return cached_models
|
||||
|
||||
|
||||
def cached_path(
|
||||
url_or_filename,
|
||||
cache_dir=None,
|
||||
|
||||
Reference in New Issue
Block a user