From fb650df8590f796663226132482d09da5b0fb613 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 16 Dec 2020 16:09:57 +0100 Subject: [PATCH] Support for private models from huggingface.co (#9141) * minor wording tweaks * Create private model repo + exist_ok flag * file_utils: `use_auth_token` * Update src/transformers/file_utils.py Co-authored-by: Patrick von Platen * Propagate doc from @sgugger Co-Authored-By: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/configuration_utils.py | 9 ++++++++ src/transformers/file_utils.py | 23 +++++++++++++++++---- src/transformers/hf_api.py | 18 +++++++++++++--- src/transformers/modeling_flax_utils.py | 3 +++ src/transformers/modeling_tf_utils.py | 10 +++++++++ src/transformers/modeling_utils.py | 10 +++++++++ src/transformers/pipelines.py | 4 ++-- src/transformers/tokenization_utils_base.py | 9 ++++++++ 8 files changed, 77 insertions(+), 9 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 2825e7efa5..d85e62d269 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -317,6 +317,9 @@ class PretrainedConfig(object): proxies (:obj:`Dict[str, str]`, `optional`): A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (:obj:`str` or `bool`, `optional`): + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any @@ -332,6 +335,10 @@ class PretrainedConfig(object): values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the ``return_unused_kwargs`` keyword parameter. + .. note:: + + Passing :obj:`use_auth_token=True` is required when you want to use a private model. + Returns: :class:`PretrainedConfig`: The configuration object instantiated from this pretrained model. @@ -373,6 +380,7 @@ class PretrainedConfig(object): force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) + use_auth_token = kwargs.pop("use_auth_token", None) local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) @@ -395,6 +403,7 @@ class PretrainedConfig(object): proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, + use_auth_token=use_auth_token, ) # Load config dict config_dict = cls._dict_from_json_file(resolved_config_file) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 420a6ff21a..27e28cbfa3 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -16,6 +16,7 @@ Utilities for working with the local dataset cache. Parts of this file is adapte https://github.com/allenai/allennlp. """ +import copy import fnmatch import io import json @@ -42,6 +43,7 @@ import requests from filelock import FileLock from . import __version__ +from .hf_api import HfFolder from .utils import logging @@ -1024,6 +1026,7 @@ def cached_path( user_agent: Union[Dict, str, None] = None, extract_compressed_file=False, force_extract=False, + use_auth_token: Union[bool, str, None] = None, local_files_only=False, ) -> Optional[str]: """ @@ -1036,6 +1039,8 @@ def cached_path( force_download: if True, re-download the file even if it's already cached in the cache dir. resume_download: if True, resume the download if incompletely received file is found. user_agent: Optional string or dict that will be appended to the user-agent on remote requests. + use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True, + will get token from ~/.huggingface. extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed file in a folder along the archive. force_extract: if True when extract_compressed_file is True and the archive was already extracted, @@ -1063,6 +1068,7 @@ def cached_path( proxies=proxies, resume_download=resume_download, user_agent=user_agent, + use_auth_token=use_auth_token, local_files_only=local_files_only, ) elif os.path.exists(url_or_filename): @@ -1125,11 +1131,11 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: return ua -def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None): +def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None): """ Donwload remote file. Do not gobble up errors. """ - headers = {"user-agent": http_user_agent(user_agent)} + headers = copy.deepcopy(headers) if resume_size > 0: headers["Range"] = "bytes=%d-" % (resume_size,) r = requests.get(url, stream=True, proxies=proxies, headers=headers) @@ -1159,6 +1165,7 @@ def get_from_cache( etag_timeout=10, resume_download=False, user_agent: Union[Dict, str, None] = None, + use_auth_token: Union[bool, str, None] = None, local_files_only=False, ) -> Optional[str]: """ @@ -1178,11 +1185,19 @@ def get_from_cache( os.makedirs(cache_dir, exist_ok=True) + headers = {"user-agent": http_user_agent(user_agent)} + if isinstance(use_auth_token, str): + headers["authorization"] = "Bearer {}".format(use_auth_token) + elif use_auth_token: + token = HfFolder.get_token() + if token is None: + raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.") + headers["authorization"] = "Bearer {}".format(token) + url_to_download = url etag = None if not local_files_only: try: - headers = {"user-agent": http_user_agent(user_agent)} r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout) r.raise_for_status() etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag") @@ -1272,7 +1287,7 @@ def get_from_cache( with temp_file_manager() as temp_file: logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) - http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent) + http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers) logger.info("storing %s in cache at %s", url, cache_path) os.replace(temp_file.name, cache_path) diff --git a/src/transformers/hf_api.py b/src/transformers/hf_api.py index 59578a2a79..c95a992039 100644 --- a/src/transformers/hf_api.py +++ b/src/transformers/hf_api.py @@ -206,7 +206,7 @@ class HfApi: def model_list(self) -> List[ModelInfo]: """ - Get the public list of all the models on huggingface, including the community models + Get the public list of all the models on huggingface.co """ path = "{}/api/models".format(self.endpoint) r = requests.get(path) @@ -228,7 +228,13 @@ class HfApi: return [RepoObj(**x) for x in d] def create_repo( - self, token: str, name: str, organization: Optional[str] = None, lfsmultipartthresh: Optional[int] = None + self, + token: str, + name: str, + organization: Optional[str] = None, + private: Optional[bool] = None, + exist_ok=False, + lfsmultipartthresh: Optional[int] = None, ) -> str: """ HuggingFace git-based system, used for models. @@ -236,10 +242,14 @@ class HfApi: Call HF API to create a whole repo. Params: + private: Whether the model repo should be private (requires a paid huggingface.co account) + + exist_ok: Do not raise an error if repo already exists + lfsmultipartthresh: Optional: internal param for testing purposes. """ path = "{}/api/repos/create".format(self.endpoint) - json = {"name": name, "organization": organization} + json = {"name": name, "organization": organization, "private": private} if lfsmultipartthresh is not None: json["lfsmultipartthresh"] = lfsmultipartthresh r = requests.post( @@ -247,6 +257,8 @@ class HfApi: headers={"authorization": "Bearer {}".format(token)}, json=json, ) + if exist_ok and r.status_code == 409: + return "" r.raise_for_status() d = r.json() return d["url"] diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 0823ddf1bc..4a3b5a95b3 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -226,6 +226,7 @@ class FlaxPreTrainedModel(ABC): resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) # Load config if we don't provide a configuration @@ -240,6 +241,7 @@ class FlaxPreTrainedModel(ABC): resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, + use_auth_token=use_auth_token, revision=revision, **kwargs, ) @@ -283,6 +285,7 @@ class FlaxPreTrainedModel(ABC): proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, + use_auth_token=use_auth_token, ) except EnvironmentError as err: logger.error(err) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index b15bbe0690..9b78555c18 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -894,6 +894,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to only look at local files (e.g., not try doanloading the model). + use_auth_token (:obj:`str` or `bool`, `optional`): + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any @@ -916,6 +919,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. + .. note:: + + Passing :obj:`use_auth_token=True` is required when you want to use a private model. + Examples:: >>> from transformers import BertConfig, TFBertModel @@ -939,6 +946,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) mirror = kwargs.pop("mirror", None) @@ -954,6 +962,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, + use_auth_token=use_auth_token, revision=revision, **kwargs, ) @@ -996,6 +1005,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, + use_auth_token=use_auth_token, ) except EnvironmentError as err: logger.error(err) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index be0056f9c0..fba7aa89cb 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -886,6 +886,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (:obj:`str` or `bool`, `optional`): + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any @@ -908,6 +911,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. + .. note:: + + Passing :obj:`use_auth_token=True` is required when you want to use a private model. + Examples:: >>> from transformers import BertConfig, BertModel @@ -931,6 +938,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) mirror = kwargs.pop("mirror", None) @@ -946,6 +954,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, + use_auth_token=use_auth_token, revision=revision, **kwargs, ) @@ -998,6 +1007,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, + use_auth_token=use_auth_token, ) except EnvironmentError as err: logger.error(err) diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index f328aba117..224f5f3ac0 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -744,8 +744,8 @@ class TextGenerationPipeline(Pipeline): task identifier: :obj:`"text-generation"`. The models that this pipeline can use are models that have been trained with an autoregressive language modeling - objective, which includes the uni-directional models in the library (e.g. gpt2). See the list of available - community models on `huggingface.co/models `__. + objective, which includes the uni-directional models in the library (e.g. gpt2). See the list of available models + on `huggingface.co/models `__. """ # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 2261773c28..beff5ed2f4 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1648,6 +1648,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): proxies (:obj:`Dict[str, str], `optional`): A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + use_auth_token (:obj:`str` or `bool`, `optional`): + The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token + generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any @@ -1662,6 +1665,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__`` for more details. + .. note:: + + Passing :obj:`use_auth_token=True` is required when you want to use a private model. + Examples:: # We can't instantiate directly the base class `PreTrainedTokenizerBase` so let's show our examples on a derived class: BertTokenizer @@ -1689,6 +1696,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) @@ -1770,6 +1778,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, + use_auth_token=use_auth_token, ) except requests.exceptions.HTTPError as err: if "404 Client Error" in str(err):