Support for private models from huggingface.co (#9141)
* minor wording tweaks * Create private model repo + exist_ok flag * file_utils: `use_auth_token` * Update src/transformers/file_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Propagate doc from @sgugger Co-Authored-By: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -317,6 +317,9 @@ class PretrainedConfig(object):
|
|||||||
proxies (:obj:`Dict[str, str]`, `optional`):
|
proxies (:obj:`Dict[str, str]`, `optional`):
|
||||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
|
||||||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
||||||
|
use_auth_token (:obj:`str` or `bool`, `optional`):
|
||||||
|
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
|
||||||
|
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
|
||||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||||
@@ -332,6 +335,10 @@ class PretrainedConfig(object):
|
|||||||
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
|
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
|
||||||
by the ``return_unused_kwargs`` keyword parameter.
|
by the ``return_unused_kwargs`` keyword parameter.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Passing :obj:`use_auth_token=True` is required when you want to use a private model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:class:`PretrainedConfig`: The configuration object instantiated from this pretrained model.
|
:class:`PretrainedConfig`: The configuration object instantiated from this pretrained model.
|
||||||
|
|
||||||
@@ -373,6 +380,7 @@ class PretrainedConfig(object):
|
|||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
local_files_only = kwargs.pop("local_files_only", False)
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
|
|
||||||
@@ -395,6 +403,7 @@ class PretrainedConfig(object):
|
|||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
# Load config dict
|
# Load config dict
|
||||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
config_dict = cls._dict_from_json_file(resolved_config_file)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ Utilities for working with the local dataset cache. Parts of this file is adapte
|
|||||||
https://github.com/allenai/allennlp.
|
https://github.com/allenai/allennlp.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
@@ -42,6 +43,7 @@ import requests
|
|||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
|
from .hf_api import HfFolder
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -1024,6 +1026,7 @@ def cached_path(
|
|||||||
user_agent: Union[Dict, str, None] = None,
|
user_agent: Union[Dict, str, None] = None,
|
||||||
extract_compressed_file=False,
|
extract_compressed_file=False,
|
||||||
force_extract=False,
|
force_extract=False,
|
||||||
|
use_auth_token: Union[bool, str, None] = None,
|
||||||
local_files_only=False,
|
local_files_only=False,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
@@ -1036,6 +1039,8 @@ def cached_path(
|
|||||||
force_download: if True, re-download the file even if it's already cached in the cache dir.
|
force_download: if True, re-download the file even if it's already cached in the cache dir.
|
||||||
resume_download: if True, resume the download if incompletely received file is found.
|
resume_download: if True, resume the download if incompletely received file is found.
|
||||||
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
|
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
|
||||||
|
use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,
|
||||||
|
will get token from ~/.huggingface.
|
||||||
extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed
|
extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed
|
||||||
file in a folder along the archive.
|
file in a folder along the archive.
|
||||||
force_extract: if True when extract_compressed_file is True and the archive was already extracted,
|
force_extract: if True when extract_compressed_file is True and the archive was already extracted,
|
||||||
@@ -1063,6 +1068,7 @@ def cached_path(
|
|||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
)
|
)
|
||||||
elif os.path.exists(url_or_filename):
|
elif os.path.exists(url_or_filename):
|
||||||
@@ -1125,11 +1131,11 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
|||||||
return ua
|
return ua
|
||||||
|
|
||||||
|
|
||||||
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
|
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
|
||||||
"""
|
"""
|
||||||
Donwload remote file. Do not gobble up errors.
|
Donwload remote file. Do not gobble up errors.
|
||||||
"""
|
"""
|
||||||
headers = {"user-agent": http_user_agent(user_agent)}
|
headers = copy.deepcopy(headers)
|
||||||
if resume_size > 0:
|
if resume_size > 0:
|
||||||
headers["Range"] = "bytes=%d-" % (resume_size,)
|
headers["Range"] = "bytes=%d-" % (resume_size,)
|
||||||
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
||||||
@@ -1159,6 +1165,7 @@ def get_from_cache(
|
|||||||
etag_timeout=10,
|
etag_timeout=10,
|
||||||
resume_download=False,
|
resume_download=False,
|
||||||
user_agent: Union[Dict, str, None] = None,
|
user_agent: Union[Dict, str, None] = None,
|
||||||
|
use_auth_token: Union[bool, str, None] = None,
|
||||||
local_files_only=False,
|
local_files_only=False,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
@@ -1178,11 +1185,19 @@ def get_from_cache(
|
|||||||
|
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
headers = {"user-agent": http_user_agent(user_agent)}
|
||||||
|
if isinstance(use_auth_token, str):
|
||||||
|
headers["authorization"] = "Bearer {}".format(use_auth_token)
|
||||||
|
elif use_auth_token:
|
||||||
|
token = HfFolder.get_token()
|
||||||
|
if token is None:
|
||||||
|
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
|
||||||
|
headers["authorization"] = "Bearer {}".format(token)
|
||||||
|
|
||||||
url_to_download = url
|
url_to_download = url
|
||||||
etag = None
|
etag = None
|
||||||
if not local_files_only:
|
if not local_files_only:
|
||||||
try:
|
try:
|
||||||
headers = {"user-agent": http_user_agent(user_agent)}
|
|
||||||
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
|
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
|
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
|
||||||
@@ -1272,7 +1287,7 @@ def get_from_cache(
|
|||||||
with temp_file_manager() as temp_file:
|
with temp_file_manager() as temp_file:
|
||||||
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||||
|
|
||||||
http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
|
||||||
|
|
||||||
logger.info("storing %s in cache at %s", url, cache_path)
|
logger.info("storing %s in cache at %s", url, cache_path)
|
||||||
os.replace(temp_file.name, cache_path)
|
os.replace(temp_file.name, cache_path)
|
||||||
|
|||||||
@@ -206,7 +206,7 @@ class HfApi:
|
|||||||
|
|
||||||
def model_list(self) -> List[ModelInfo]:
|
def model_list(self) -> List[ModelInfo]:
|
||||||
"""
|
"""
|
||||||
Get the public list of all the models on huggingface, including the community models
|
Get the public list of all the models on huggingface.co
|
||||||
"""
|
"""
|
||||||
path = "{}/api/models".format(self.endpoint)
|
path = "{}/api/models".format(self.endpoint)
|
||||||
r = requests.get(path)
|
r = requests.get(path)
|
||||||
@@ -228,7 +228,13 @@ class HfApi:
|
|||||||
return [RepoObj(**x) for x in d]
|
return [RepoObj(**x) for x in d]
|
||||||
|
|
||||||
def create_repo(
|
def create_repo(
|
||||||
self, token: str, name: str, organization: Optional[str] = None, lfsmultipartthresh: Optional[int] = None
|
self,
|
||||||
|
token: str,
|
||||||
|
name: str,
|
||||||
|
organization: Optional[str] = None,
|
||||||
|
private: Optional[bool] = None,
|
||||||
|
exist_ok=False,
|
||||||
|
lfsmultipartthresh: Optional[int] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
HuggingFace git-based system, used for models.
|
HuggingFace git-based system, used for models.
|
||||||
@@ -236,10 +242,14 @@ class HfApi:
|
|||||||
Call HF API to create a whole repo.
|
Call HF API to create a whole repo.
|
||||||
|
|
||||||
Params:
|
Params:
|
||||||
|
private: Whether the model repo should be private (requires a paid huggingface.co account)
|
||||||
|
|
||||||
|
exist_ok: Do not raise an error if repo already exists
|
||||||
|
|
||||||
lfsmultipartthresh: Optional: internal param for testing purposes.
|
lfsmultipartthresh: Optional: internal param for testing purposes.
|
||||||
"""
|
"""
|
||||||
path = "{}/api/repos/create".format(self.endpoint)
|
path = "{}/api/repos/create".format(self.endpoint)
|
||||||
json = {"name": name, "organization": organization}
|
json = {"name": name, "organization": organization, "private": private}
|
||||||
if lfsmultipartthresh is not None:
|
if lfsmultipartthresh is not None:
|
||||||
json["lfsmultipartthresh"] = lfsmultipartthresh
|
json["lfsmultipartthresh"] = lfsmultipartthresh
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
@@ -247,6 +257,8 @@ class HfApi:
|
|||||||
headers={"authorization": "Bearer {}".format(token)},
|
headers={"authorization": "Bearer {}".format(token)},
|
||||||
json=json,
|
json=json,
|
||||||
)
|
)
|
||||||
|
if exist_ok and r.status_code == 409:
|
||||||
|
return ""
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
d = r.json()
|
d = r.json()
|
||||||
return d["url"]
|
return d["url"]
|
||||||
|
|||||||
@@ -226,6 +226,7 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
local_files_only = kwargs.pop("local_files_only", False)
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
|
|
||||||
# Load config if we don't provide a configuration
|
# Load config if we don't provide a configuration
|
||||||
@@ -240,6 +241,7 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@@ -283,6 +285,7 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
except EnvironmentError as err:
|
except EnvironmentError as err:
|
||||||
logger.error(err)
|
logger.error(err)
|
||||||
|
|||||||
@@ -894,6 +894,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
|||||||
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||||
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not to only look at local files (e.g., not try doanloading the model).
|
Whether or not to only look at local files (e.g., not try doanloading the model).
|
||||||
|
use_auth_token (:obj:`str` or `bool`, `optional`):
|
||||||
|
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
|
||||||
|
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
|
||||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||||
@@ -916,6 +919,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
|||||||
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
|
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
|
||||||
attribute will be passed to the underlying model's ``__init__`` function.
|
attribute will be passed to the underlying model's ``__init__`` function.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Passing :obj:`use_auth_token=True` is required when you want to use a private model.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
>>> from transformers import BertConfig, TFBertModel
|
>>> from transformers import BertConfig, TFBertModel
|
||||||
@@ -939,6 +946,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
|||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
output_loading_info = kwargs.pop("output_loading_info", False)
|
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||||
local_files_only = kwargs.pop("local_files_only", False)
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
mirror = kwargs.pop("mirror", None)
|
mirror = kwargs.pop("mirror", None)
|
||||||
|
|
||||||
@@ -954,6 +962,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
|||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@@ -996,6 +1005,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
|||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
except EnvironmentError as err:
|
except EnvironmentError as err:
|
||||||
logger.error(err)
|
logger.error(err)
|
||||||
|
|||||||
@@ -886,6 +886,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||||
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not to only look at local files (i.e., do not try to download the model).
|
Whether or not to only look at local files (i.e., do not try to download the model).
|
||||||
|
use_auth_token (:obj:`str` or `bool`, `optional`):
|
||||||
|
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
|
||||||
|
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
|
||||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||||
@@ -908,6 +911,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
|
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
|
||||||
attribute will be passed to the underlying model's ``__init__`` function.
|
attribute will be passed to the underlying model's ``__init__`` function.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Passing :obj:`use_auth_token=True` is required when you want to use a private model.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
>>> from transformers import BertConfig, BertModel
|
>>> from transformers import BertConfig, BertModel
|
||||||
@@ -931,6 +938,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
output_loading_info = kwargs.pop("output_loading_info", False)
|
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||||
local_files_only = kwargs.pop("local_files_only", False)
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
mirror = kwargs.pop("mirror", None)
|
mirror = kwargs.pop("mirror", None)
|
||||||
|
|
||||||
@@ -946,6 +954,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@@ -998,6 +1007,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
except EnvironmentError as err:
|
except EnvironmentError as err:
|
||||||
logger.error(err)
|
logger.error(err)
|
||||||
|
|||||||
@@ -744,8 +744,8 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
task identifier: :obj:`"text-generation"`.
|
task identifier: :obj:`"text-generation"`.
|
||||||
|
|
||||||
The models that this pipeline can use are models that have been trained with an autoregressive language modeling
|
The models that this pipeline can use are models that have been trained with an autoregressive language modeling
|
||||||
objective, which includes the uni-directional models in the library (e.g. gpt2). See the list of available
|
objective, which includes the uni-directional models in the library (e.g. gpt2). See the list of available models
|
||||||
community models on `huggingface.co/models <https://huggingface.co/models?filter=causal-lm>`__.
|
on `huggingface.co/models <https://huggingface.co/models?filter=causal-lm>`__.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
# Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
||||||
|
|||||||
@@ -1648,6 +1648,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
|||||||
proxies (:obj:`Dict[str, str], `optional`):
|
proxies (:obj:`Dict[str, str], `optional`):
|
||||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
|
||||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||||
|
use_auth_token (:obj:`str` or `bool`, `optional`):
|
||||||
|
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
|
||||||
|
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
|
||||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||||
@@ -1662,6 +1665,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
|||||||
``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
|
``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
|
||||||
``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__`` for more details.
|
``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__`` for more details.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Passing :obj:`use_auth_token=True` is required when you want to use a private model.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
# We can't instantiate directly the base class `PreTrainedTokenizerBase` so let's show our examples on a derived class: BertTokenizer
|
# We can't instantiate directly the base class `PreTrainedTokenizerBase` so let's show our examples on a derived class: BertTokenizer
|
||||||
@@ -1689,6 +1696,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
|||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
local_files_only = kwargs.pop("local_files_only", False)
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
subfolder = kwargs.pop("subfolder", None)
|
subfolder = kwargs.pop("subfolder", None)
|
||||||
|
|
||||||
@@ -1770,6 +1778,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
|||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
except requests.exceptions.HTTPError as err:
|
except requests.exceptions.HTTPError as err:
|
||||||
if "404 Client Error" in str(err):
|
if "404 Client Error" in str(err):
|
||||||
|
|||||||
Reference in New Issue
Block a user