def get_cached_models(cache_dir: Optional[Union[str, Path]] = None) -> List[Tuple]:
    """
    Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape
    :obj:`(model_url, etag, size_MB)`. Filenames in :obj:`cache_dir` are used to get the metadata for each model, only
    urls ending with `.bin` are added.

    Args:
        cache_dir (:obj:`Union[str, Path]`, `optional`):
            The cache directory to search for models within. Will default to the transformers cache if unset.

    Returns:
        List[Tuple]: List of tuples each with shape :obj:`(model_url, etag, size_MB)`

    Raises:
        OSError: If :obj:`cache_dir` cannot be listed, or if a ``.json`` metadata file describing a ``.bin`` url has
            no matching cached file next to it.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    elif isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    meta_suffix = ".json"
    cached_models = []
    for filename in os.listdir(cache_dir):
        if not filename.endswith(meta_suffix):
            continue

        meta_path = os.path.join(cache_dir, filename)
        with open(meta_path, encoding="utf-8") as meta_file:
            metadata = json.load(meta_file)
        url = metadata["url"]
        etag = metadata["etag"]

        if url.endswith(".bin"):
            # The cached binary sits next to its metadata file, under the same name minus ".json".
            # Slice the suffix off explicitly: str.strip(".json") would instead remove any run of the
            # characters '.', 'j', 's', 'o', 'n' from BOTH ends of the path and corrupt it.
            model_path = meta_path[: -len(meta_suffix)]
            size_MB = os.path.getsize(model_path) / 1e6
            cached_models.append((url, etag, size_MB))

    return cached_models