Model versioning (#8324)

* fix typo

* rm use_cdn & references, and implement new hf_bucket_url

* I'm pretty sure we don't need to `read` this file

* same here

* [BIG] file_utils.networking: do not gobble up errors anymore

* Fix CI 😇

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Tiny doc tweak

* Add doc + pass kwarg everywhere

* Add more tests and explain

cc @sshleifer let me know if better

Co-Authored-By: Sam Shleifer <sshleifer@gmail.com>

* Also implement revision in pipelines

In the case where we're passing a task name or a string model identifier

* Fix CI 😇

* Fix CI

* [hf_api] new methods + command line implem

* make style

* Final endpoints post-migration

* Fix post-migration

* Py3.6 compat

cc @stefan-it

Thank you @stas00

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Julien Chaumond
2020-11-10 13:11:02 +01:00
committed by GitHub
parent 4185b115d4
commit 70f622fab4
23 changed files with 472 additions and 210 deletions

View File

@@ -311,6 +311,10 @@ class PretrainedConfig(object):
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`False`, then this function returns just the final configuration object.
@@ -362,6 +366,7 @@ class PretrainedConfig(object):
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
@@ -369,7 +374,7 @@ class PretrainedConfig(object):
config_file = pretrained_model_name_or_path
else:
config_file = hf_bucket_url(
pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False, mirror=None
pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
)
try:
@@ -383,11 +388,10 @@ class PretrainedConfig(object):
local_files_only=local_files_only,
)
# Load config dict
if resolved_config_file is None:
raise EnvironmentError
config_dict = cls._dict_from_json_file(resolved_config_file)
except EnvironmentError:
except EnvironmentError as err:
logger.error(err)
msg = (
f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"