Model versioning (#8324)
* fix typo * rm use_cdn & references, and implement new hf_bucket_url * I'm pretty sure we don't need to `read` this file * same here * [BIG] file_utils.networking: do not gobble up errors anymore * Fix CI 😇 * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Tiny doc tweak * Add doc + pass kwarg everywhere * Add more tests and explain cc @sshleifer let me know if better Co-Authored-By: Sam Shleifer <sshleifer@gmail.com> * Also implement revision in pipelines In the case where we're passing a task name or a string model identifier * Fix CI 😇 * Fix CI * [hf_api] new methods + command line implem * make style * Final endpoints post-migration * Fix post-migration * Py3.6 compat cc @stefan-it Thank you @stas00 Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -813,9 +813,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to only look at local files (e.g., not try doanloading the model).
|
||||
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
|
||||
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
|
||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||
identifier allowed by git.
|
||||
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||
@@ -857,7 +858,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_cdn = kwargs.pop("use_cdn", True)
|
||||
revision = kwargs.pop("revision", None)
|
||||
mirror = kwargs.pop("mirror", None)
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
@@ -872,6 +873,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
@@ -909,7 +911,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path,
|
||||
filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
|
||||
use_cdn=use_cdn,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
)
|
||||
|
||||
@@ -923,9 +925,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
if resolved_archive_file is None:
|
||||
raise EnvironmentError
|
||||
except EnvironmentError:
|
||||
except EnvironmentError as err:
|
||||
logger.error(err)
|
||||
msg = (
|
||||
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
|
||||
|
||||
Reference in New Issue
Block a user