Refine errors for pretrained objects (#15261)
* Refine errors for pretrained objects * PoC to avoid using get_list_of_files * Adapt tests to use new errors * Quality + Fix PoC * Revert "PoC to avoid using get_list_of_files" This reverts commit cb93b7cae8504ef837c2a7663cb7955e714f323e. * Revert "Quality + Fix PoC" This reverts commit 3ba6d0d4ca546708b31d355baa9e68ba9736508f. * Fix doc * Revert PoC * Add feature extractors * More tests and PT model * Adapt error message * Feature extractor tests * TF model * Flax model and test * Merge flax auto tests * Add tokenization * Fix test
This commit is contained in:
@@ -25,10 +25,15 @@ from typing import Any, Dict, Optional, Tuple, Union
|
|||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
|
from requests import HTTPError
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
|
EntryNotFoundError,
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
cached_path,
|
cached_path,
|
||||||
copy_func,
|
copy_func,
|
||||||
get_list_of_files,
|
get_list_of_files,
|
||||||
@@ -520,8 +525,6 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
|
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
|
||||||
[`PretrainedConfig`] using `from_dict`.
|
[`PretrainedConfig`] using `from_dict`.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||||
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
|
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
|
||||||
@@ -578,30 +581,51 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
# Load config dict
|
|
||||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
|
||||||
|
|
||||||
|
except RepositoryNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
||||||
|
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
||||||
|
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
||||||
|
"`use_auth_token=True`."
|
||||||
|
)
|
||||||
|
except RevisionNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||||
|
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
||||||
|
"available revisions."
|
||||||
|
)
|
||||||
|
except EntryNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
|
||||||
|
)
|
||||||
|
except HTTPError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
|
||||||
|
f"{pretrained_model_name_or_path} is not the path to a directory conaining a {configuration_file} "
|
||||||
|
"file.\nCheckout your internet connection or see how to run the library in offline mode at "
|
||||||
|
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||||
|
)
|
||||||
except EnvironmentError as err:
|
except EnvironmentError as err:
|
||||||
logger.error(err)
|
logger.error(err)
|
||||||
msg = (
|
raise EnvironmentError(
|
||||||
f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
|
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||||
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
|
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||||
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
|
f"containing a {configuration_file} file"
|
||||||
)
|
)
|
||||||
|
|
||||||
if revision is not None:
|
try:
|
||||||
msg += f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
|
# Load config dict
|
||||||
|
config_dict = cls._dict_from_json_file(resolved_config_file)
|
||||||
raise EnvironmentError(msg)
|
|
||||||
|
|
||||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||||
msg = (
|
raise EnvironmentError(
|
||||||
f"Couldn't reach server at '{config_file}' to download configuration file or "
|
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
|
||||||
"configuration file is not a valid JSON file. "
|
|
||||||
f"Please check network or file content here: {resolved_config_file}."
|
|
||||||
)
|
)
|
||||||
raise EnvironmentError(msg)
|
|
||||||
|
|
||||||
if resolved_config_file == config_file:
|
if resolved_config_file == config_file:
|
||||||
logger.info(f"loading configuration file {config_file}")
|
logger.info(f"loading configuration file {config_file}")
|
||||||
@@ -842,9 +866,13 @@ def get_configuration_file(
|
|||||||
`str`: The configuration file to use.
|
`str`: The configuration file to use.
|
||||||
"""
|
"""
|
||||||
# Inspect all files from the repo/folder.
|
# Inspect all files from the repo/folder.
|
||||||
all_files = get_list_of_files(
|
try:
|
||||||
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
|
all_files = get_list_of_files(
|
||||||
)
|
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return FULL_CONFIGURATION_FILE
|
||||||
|
|
||||||
configuration_files_map = {}
|
configuration_files_map = {}
|
||||||
for file_name in all_files:
|
for file_name in all_files:
|
||||||
search = _re_configuration_file.search(file_name)
|
search = _re_configuration_file.search(file_name)
|
||||||
|
|||||||
@@ -24,8 +24,13 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from requests import HTTPError
|
||||||
|
|
||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
FEATURE_EXTRACTOR_NAME,
|
FEATURE_EXTRACTOR_NAME,
|
||||||
|
EntryNotFoundError,
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
TensorType,
|
TensorType,
|
||||||
_is_jax,
|
_is_jax,
|
||||||
_is_numpy,
|
_is_numpy,
|
||||||
@@ -374,28 +379,54 @@ class FeatureExtractionMixin:
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except RepositoryNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
||||||
|
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
||||||
|
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
||||||
|
"`use_auth_token=True`."
|
||||||
|
)
|
||||||
|
except RevisionNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||||
|
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
||||||
|
"available revisions."
|
||||||
|
)
|
||||||
|
except EntryNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}."
|
||||||
|
)
|
||||||
|
except HTTPError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
|
||||||
|
f"{pretrained_model_name_or_path} is not the path to a directory conaining a "
|
||||||
|
f"{FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run the library in "
|
||||||
|
"offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||||
|
)
|
||||||
|
except EnvironmentError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load it "
|
||||||
|
"from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||||
|
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||||
|
f"containing a {FEATURE_EXTRACTOR_NAME} file"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
# Load feature_extractor dict
|
# Load feature_extractor dict
|
||||||
with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
|
with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
|
||||||
text = reader.read()
|
text = reader.read()
|
||||||
feature_extractor_dict = json.loads(text)
|
feature_extractor_dict = json.loads(text)
|
||||||
|
|
||||||
except EnvironmentError as err:
|
|
||||||
logger.error(err)
|
|
||||||
msg = (
|
|
||||||
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
|
||||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
|
|
||||||
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
|
|
||||||
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {FEATURE_EXTRACTOR_NAME} file\n\n"
|
|
||||||
)
|
|
||||||
raise EnvironmentError(msg)
|
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
msg = (
|
raise EnvironmentError(
|
||||||
f"Couldn't reach server at '{feature_extractor_file}' to download feature extractor configuration file or "
|
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
|
||||||
"feature extractor configuration file is not a valid JSON file. "
|
|
||||||
f"Please check network or file content here: {resolved_feature_extractor_file}."
|
|
||||||
)
|
)
|
||||||
raise EnvironmentError(msg)
|
|
||||||
|
|
||||||
if resolved_feature_extractor_file == feature_extractor_file:
|
if resolved_feature_extractor_file == feature_extractor_file:
|
||||||
logger.info(f"loading feature extractor configuration file {feature_extractor_file}")
|
logger.info(f"loading feature extractor configuration file {feature_extractor_file}")
|
||||||
|
|||||||
@@ -1900,6 +1900,37 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
|||||||
return ua
|
return ua
|
||||||
|
|
||||||
|
|
||||||
|
class RepositoryNotFoundError(HTTPError):
    """
    Error raised when a hf.co URL points to a repository that cannot be found: either the repository name is invalid,
    or it names a private repository the user is not allowed to access.
    """
|
||||||
|
|
||||||
|
|
||||||
|
class EntryNotFoundError(HTTPError):
    """
    Error raised when a hf.co URL targets a valid repository and revision, but the requested filename does not exist.
    """
|
||||||
|
|
||||||
|
|
||||||
|
class RevisionNotFoundError(HTTPError):
    """
    Error raised when a hf.co URL targets a valid repository, but the requested revision does not exist in it.
    """
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_for_status(request):
    """
    Stand-in for `request.raise_for_status()` that raises a more specific error whenever the `X-Error-Code` header of
    the response identifies the failure.

    The header values `RepoNotFound`, `EntryNotFound` and `RevisionNotFound` are turned into
    [`RepositoryNotFoundError`], [`EntryNotFoundError`] and [`RevisionNotFoundError`] respectively. For any other
    value, or when the header is absent, the call is delegated to `request.raise_for_status()`.
    """
    # A missing header yields `None`, which matches none of the known codes below.
    error_code = request.headers.get("X-Error-Code")

    # Every branch raises, so plain guard-style `if`s are enough.
    if error_code == "RepoNotFound":
        raise RepositoryNotFoundError(f"404 Client Error: Repository Not Found for url: {request.url}")
    if error_code == "EntryNotFound":
        raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {request.url}")
    if error_code == "RevisionNotFound":
        raise RevisionNotFoundError(f"404 Client Error: Revision Not Found for url: {request.url}")

    # No refined error applies: fall back to the generic status check.
    request.raise_for_status()
|
||||||
|
|
||||||
|
|
||||||
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
|
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
|
||||||
"""
|
"""
|
||||||
Download remote file. Do not gobble up errors.
|
Download remote file. Do not gobble up errors.
|
||||||
@@ -1908,7 +1939,7 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers
|
|||||||
if resume_size > 0:
|
if resume_size > 0:
|
||||||
headers["Range"] = f"bytes={resume_size}-"
|
headers["Range"] = f"bytes={resume_size}-"
|
||||||
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
||||||
r.raise_for_status()
|
_raise_for_status(r)
|
||||||
content_length = r.headers.get("Content-Length")
|
content_length = r.headers.get("Content-Length")
|
||||||
total = resume_size + int(content_length) if content_length is not None else None
|
total = resume_size + int(content_length) if content_length is not None else None
|
||||||
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
|
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
|
||||||
@@ -1970,7 +2001,7 @@ def get_from_cache(
|
|||||||
if not local_files_only:
|
if not local_files_only:
|
||||||
try:
|
try:
|
||||||
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
|
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
|
||||||
r.raise_for_status()
|
_raise_for_status(r)
|
||||||
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
|
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
|
||||||
# We favor a custom header indicating the etag of the linked resource, and
|
# We favor a custom header indicating the etag of the linked resource, and
|
||||||
# we fallback to the regular etag header.
|
# we fallback to the regular etag header.
|
||||||
@@ -2081,6 +2112,56 @@ def get_from_cache(
|
|||||||
return cache_path
|
return cache_path
|
||||||
|
|
||||||
|
|
||||||
|
def has_file(
    path_or_repo: Union[str, os.PathLike],
    filename: str,
    revision: Optional[str] = None,
    mirror: Optional[str] = None,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
) -> bool:
    """
    Checks if a repo contains a given file without downloading it. Works for remote repos and local folders.

    <Tip warning={false}>

    This function will raise an error if the repository `path_or_repo` is not valid or if `revision` does not exist for
    this repo, but will return False when the file is missing or when the server answers with any other HTTP error.

    </Tip>

    Args:
        path_or_repo (`str` or `os.PathLike`):
            A local directory, or the identifier of a remote repository.
        filename (`str`):
            The name of the file to look for.
        revision (`str`, *optional*):
            The revision forwarded to `hf_bucket_url` when building the URL of the remote file.
        mirror (`str`, *optional*):
            The mirror forwarded to `hf_bucket_url` when building the URL of the remote file.
        proxies (`Dict[str, str]`, *optional*):
            Proxies forwarded to the HEAD request.
        use_auth_token (`str` or `bool`, *optional*):
            A string is sent as the bearer token. Any other truthy value makes the function use the token returned by
            `HfFolder.get_token()`.

    Returns:
        `bool`: Whether `filename` exists in `path_or_repo`.

    Raises:
        `EnvironmentError`: If the repository or the revision does not exist, or if `use_auth_token` is truthy (and
        not a string) but no stored token is found.
    """
    # Local folder: a plain filesystem check is enough, no request is made.
    if os.path.isdir(path_or_repo):
        return os.path.isfile(os.path.join(path_or_repo, filename))

    url = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=mirror)

    headers = {"user-agent": http_user_agent()}
    if isinstance(use_auth_token, str):
        headers["authorization"] = f"Bearer {use_auth_token}"
    elif use_auth_token:
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = f"Bearer {token}"

    # NOTE(review): the HEAD request is issued outside the `try` block, so any exception it raises itself propagates
    # to the caller instead of yielding `False` -- confirm this is intended.
    r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
    try:
        _raise_for_status(r)
        return True
    except RepositoryNotFoundError as e:
        logger.error(e)
        raise EnvironmentError(f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'.")
    except RevisionNotFoundError as e:
        logger.error(e)
        # Both literals must be f-strings: the second one interpolates `path_or_repo` into the model page URL.
        raise EnvironmentError(
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
            f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
        )
    except requests.HTTPError:
        # `EntryNotFoundError` (the file is simply not there) and any other HTTP error status end up here.
        return False
|
||||||
|
|
||||||
|
|
||||||
def get_list_of_files(
|
def get_list_of_files(
|
||||||
path_or_repo: Union[str, os.PathLike],
|
path_or_repo: Union[str, os.PathLike],
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
|
|||||||
@@ -26,16 +26,21 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
|
|||||||
from flax.serialization import from_bytes, to_bytes
|
from flax.serialization import from_bytes, to_bytes
|
||||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
from requests import HTTPError
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
FLAX_WEIGHTS_NAME,
|
FLAX_WEIGHTS_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
|
EntryNotFoundError,
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
cached_path,
|
cached_path,
|
||||||
copy_func,
|
copy_func,
|
||||||
|
has_file,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
@@ -450,17 +455,25 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
||||||
# Load from a Flax checkpoint
|
# Load from a Flax checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
||||||
|
# At this stage we don't have a weight file so we will raise an error.
|
||||||
|
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
|
||||||
|
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
|
||||||
|
"weights."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
|
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
||||||
f"{pretrained_model_name_or_path} or `from_pt` set to False"
|
f"{pretrained_model_name_or_path}."
|
||||||
)
|
)
|
||||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
else:
|
else:
|
||||||
|
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
|
||||||
archive_file = hf_bucket_url(
|
archive_file = hf_bucket_url(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
|
filename=filename,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -476,15 +489,59 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except RepositoryNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||||
|
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||||
|
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||||
|
"login` and pass `use_auth_token=True`."
|
||||||
|
)
|
||||||
|
except RevisionNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||||
|
"this model name. Check the model page at "
|
||||||
|
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||||
|
)
|
||||||
|
except EntryNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
if filename == FLAX_WEIGHTS_NAME:
|
||||||
|
has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
|
||||||
|
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
|
||||||
|
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from "
|
||||||
|
"those weights."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
|
||||||
|
f"or {WEIGHTS_NAME}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
|
||||||
|
)
|
||||||
|
except HTTPError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
|
||||||
|
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
|
||||||
|
f"{FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
|
||||||
|
"Checkout your internet connection or see how to run the library in offline mode at "
|
||||||
|
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||||
|
)
|
||||||
except EnvironmentError as err:
|
except EnvironmentError as err:
|
||||||
logger.error(err)
|
logger.error(err)
|
||||||
msg = (
|
raise EnvironmentError(
|
||||||
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
|
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||||
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
|
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||||
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
|
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||||
)
|
)
|
||||||
raise EnvironmentError(msg)
|
|
||||||
|
|
||||||
if resolved_archive_file == archive_file:
|
if resolved_archive_file == archive_file:
|
||||||
logger.info(f"loading weights file {archive_file}")
|
logger.info(f"loading weights file {archive_file}")
|
||||||
|
|||||||
@@ -32,16 +32,21 @@ from tensorflow.python.keras.engine.keras_tensor import KerasTensor
|
|||||||
from tensorflow.python.keras.saving import hdf5_format
|
from tensorflow.python.keras.saving import hdf5_format
|
||||||
|
|
||||||
from huggingface_hub import Repository, list_repo_files
|
from huggingface_hub import Repository, list_repo_files
|
||||||
|
from requests import HTTPError
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
TF2_WEIGHTS_NAME,
|
TF2_WEIGHTS_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
|
EntryNotFoundError,
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
cached_path,
|
cached_path,
|
||||||
copy_func,
|
copy_func,
|
||||||
|
has_file,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
@@ -1542,19 +1547,27 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||||
# Load from a TF 2.0 checkpoint
|
# Load from a TF 2.0 checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
||||||
|
# At this stage we don't have a weight file so we will raise an error.
|
||||||
|
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
|
||||||
|
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
|
||||||
|
"weights."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME]} found in directory "
|
f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
||||||
f"{pretrained_model_name_or_path} or `from_pt` set to False"
|
f"{pretrained_model_name_or_path}."
|
||||||
)
|
)
|
||||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||||
archive_file = pretrained_model_name_or_path + ".index"
|
archive_file = pretrained_model_name_or_path + ".index"
|
||||||
else:
|
else:
|
||||||
|
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
|
||||||
archive_file = hf_bucket_url(
|
archive_file = hf_bucket_url(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
|
filename=filename,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
mirror=mirror,
|
mirror=mirror,
|
||||||
)
|
)
|
||||||
@@ -1571,15 +1584,65 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except RepositoryNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||||
|
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||||
|
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||||
|
"login` and pass `use_auth_token=True`."
|
||||||
|
)
|
||||||
|
except RevisionNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||||
|
"this model name. Check the model page at "
|
||||||
|
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||||
|
)
|
||||||
|
except EntryNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
if filename == TF2_WEIGHTS_NAME:
|
||||||
|
has_file_kwargs = {
|
||||||
|
"revision": revision,
|
||||||
|
"mirror": mirror,
|
||||||
|
"proxies": proxies,
|
||||||
|
"use_auth_token": use_auth_token,
|
||||||
|
}
|
||||||
|
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {TF2_WEIGHTS_NAME} "
|
||||||
|
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from "
|
||||||
|
"those weights."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {TF2_WEIGHTS_NAME} "
|
||||||
|
f"or {WEIGHTS_NAME}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
|
||||||
|
)
|
||||||
|
except HTTPError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
|
||||||
|
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
|
||||||
|
f"{TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
|
||||||
|
"Checkout your internet connection or see how to run the library in offline mode at "
|
||||||
|
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||||
|
)
|
||||||
except EnvironmentError as err:
|
except EnvironmentError as err:
|
||||||
logger.error(err)
|
logger.error(err)
|
||||||
msg = (
|
raise EnvironmentError(
|
||||||
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
|
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||||
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
|
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||||
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
|
f"containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||||
)
|
)
|
||||||
raise EnvironmentError(msg)
|
|
||||||
if resolved_archive_file == archive_file:
|
if resolved_archive_file == archive_file:
|
||||||
logger.info(f"loading weights file {archive_file}")
|
logger.info(f"loading weights file {archive_file}")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ from packaging import version
|
|||||||
from torch import Tensor, device, nn
|
from torch import Tensor, device, nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
|
from requests import HTTPError
|
||||||
|
|
||||||
from .activations import get_activation
|
from .activations import get_activation
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
|
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
|
||||||
@@ -36,10 +38,14 @@ from .file_utils import (
|
|||||||
TF2_WEIGHTS_NAME,
|
TF2_WEIGHTS_NAME,
|
||||||
TF_WEIGHTS_NAME,
|
TF_WEIGHTS_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
|
EntryNotFoundError,
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
cached_path,
|
cached_path,
|
||||||
copy_func,
|
copy_func,
|
||||||
|
has_file,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
@@ -1292,10 +1298,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||||
|
# At this stage we don't have a weight file so we will raise an error.
|
||||||
|
elif os.path.isfile(
|
||||||
|
os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
|
||||||
|
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
|
||||||
|
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
|
||||||
|
"weights."
|
||||||
|
)
|
||||||
|
elif os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME):
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
|
||||||
|
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
|
||||||
|
"weights."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index', FLAX_WEIGHTS_NAME]} found in "
|
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
|
||||||
f"directory {pretrained_model_name_or_path} or `from_tf` and `from_flax` set to False."
|
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
||||||
)
|
)
|
||||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
@@ -1334,20 +1355,72 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except RepositoryNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||||
|
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||||
|
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||||
|
"login` and pass `use_auth_token=True`."
|
||||||
|
)
|
||||||
|
except RevisionNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||||
|
"this model name. Check the model page at "
|
||||||
|
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||||
|
)
|
||||||
|
except EntryNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
if filename == WEIGHTS_NAME:
|
||||||
|
has_file_kwargs = {
|
||||||
|
"revision": revision,
|
||||||
|
"mirror": mirror,
|
||||||
|
"proxies": proxies,
|
||||||
|
"use_auth_token": use_auth_token,
|
||||||
|
}
|
||||||
|
if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
|
||||||
|
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
|
||||||
|
"weights."
|
||||||
|
)
|
||||||
|
elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
|
||||||
|
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
|
||||||
|
"weights."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}, "
|
||||||
|
f"{TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
|
||||||
|
)
|
||||||
|
except HTTPError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
|
||||||
|
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
|
||||||
|
f"{WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}.\n"
|
||||||
|
"Checkout your internet connection or see how to run the library in offline mode at "
|
||||||
|
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||||
|
)
|
||||||
except EnvironmentError as err:
|
except EnvironmentError as err:
|
||||||
logger.error(err)
|
logger.error(err)
|
||||||
msg = (
|
raise EnvironmentError(
|
||||||
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
|
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||||
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
|
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||||
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}\n\n"
|
f"containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or "
|
||||||
|
f"{FLAX_WEIGHTS_NAME}."
|
||||||
)
|
)
|
||||||
|
|
||||||
if revision is not None:
|
|
||||||
msg += f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
|
|
||||||
|
|
||||||
raise EnvironmentError(msg)
|
|
||||||
|
|
||||||
if resolved_archive_file == archive_file:
|
if resolved_archive_file == archive_file:
|
||||||
logger.info(f"loading weights file {archive_file}")
|
logger.info(f"loading weights file {archive_file}")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -18,13 +18,13 @@ import importlib
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from pathlib import Path
|
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
cached_path,
|
cached_path,
|
||||||
get_list_of_files,
|
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_sentencepiece_available,
|
is_sentencepiece_available,
|
||||||
@@ -333,16 +333,6 @@ def get_tokenizer_config(
|
|||||||
logger.info("Offline mode: forcing local_files_only=True")
|
logger.info("Offline mode: forcing local_files_only=True")
|
||||||
local_files_only = True
|
local_files_only = True
|
||||||
|
|
||||||
# Will raise a ValueError if `pretrained_model_name_or_path` is not a valid path or model identifier
|
|
||||||
repo_files = get_list_of_files(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
revision=revision,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
)
|
|
||||||
if TOKENIZER_CONFIG_FILE not in [Path(f).name for f in repo_files]:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
|
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
|
||||||
@@ -363,6 +353,21 @@ def get_tokenizer_config(
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except RepositoryNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||||
|
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
|
||||||
|
"pass a token having permission to this repo with `use_auth_token` or log in with "
|
||||||
|
"`huggingface-cli login` and pass `use_auth_token=True`."
|
||||||
|
)
|
||||||
|
except RevisionNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
|
||||||
|
"for this model name. Check the model page at "
|
||||||
|
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||||
|
)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
|
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
|
||||||
return {}
|
return {}
|
||||||
|
|||||||
@@ -31,13 +31,16 @@ from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequenc
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
import requests
|
from requests import HTTPError
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
|
EntryNotFoundError,
|
||||||
ExplicitEnum,
|
ExplicitEnum,
|
||||||
PaddingStrategy,
|
PaddingStrategy,
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
TensorType,
|
TensorType,
|
||||||
_is_jax,
|
_is_jax,
|
||||||
_is_numpy,
|
_is_numpy,
|
||||||
@@ -1704,9 +1707,28 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
else:
|
else:
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
except requests.exceptions.HTTPError as err:
|
except RepositoryNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||||
|
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
|
||||||
|
"pass a token having permission to this repo with `use_auth_token` or log in with "
|
||||||
|
"`huggingface-cli login` and pass `use_auth_token=True`."
|
||||||
|
)
|
||||||
|
except RevisionNotFoundError as err:
|
||||||
|
logger.error(err)
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
|
||||||
|
"for this model name. Check the model page at "
|
||||||
|
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||||
|
)
|
||||||
|
except EntryNotFoundError:
|
||||||
|
logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.")
|
||||||
|
resolved_vocab_files[file_id] = None
|
||||||
|
|
||||||
|
except HTTPError as err:
|
||||||
if "404 Client Error" in str(err):
|
if "404 Client Error" in str(err):
|
||||||
logger.debug(err)
|
logger.debug(f"Connection problem to access {file_path}.")
|
||||||
resolved_vocab_files[file_id] = None
|
resolved_vocab_files[file_id] = None
|
||||||
else:
|
else:
|
||||||
raise err
|
raise err
|
||||||
@@ -1718,18 +1740,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
|
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
|
||||||
msg = (
|
raise EnvironmentError(
|
||||||
f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
|
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||||
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
|
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||||
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing relevant tokenizer files\n\n"
|
f"containing all relevant tokenizer files."
|
||||||
)
|
)
|
||||||
|
|
||||||
if revision is not None:
|
|
||||||
msg += f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
|
|
||||||
|
|
||||||
raise EnvironmentError(msg)
|
|
||||||
|
|
||||||
for file_id, file_path in vocab_files.items():
|
for file_id, file_path in vocab_files.items():
|
||||||
if file_id not in resolved_vocab_files:
|
if file_id not in resolved_vocab_files:
|
||||||
continue
|
continue
|
||||||
@@ -3504,9 +3521,13 @@ def get_fast_tokenizer_file(
|
|||||||
`str`: The tokenizer file to use.
|
`str`: The tokenizer file to use.
|
||||||
"""
|
"""
|
||||||
# Inspect all files from the repo/folder.
|
# Inspect all files from the repo/folder.
|
||||||
all_files = get_list_of_files(
|
try:
|
||||||
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
|
all_files = get_list_of_files(
|
||||||
)
|
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return FULL_TOKENIZER_FILE
|
||||||
|
|
||||||
tokenizer_files_map = {}
|
tokenizer_files_map = {}
|
||||||
for file_name in all_files:
|
for file_name in all_files:
|
||||||
search = _re_tokenizer_file.search(file_name)
|
search = _re_tokenizer_file.search(file_name)
|
||||||
|
|||||||
@@ -83,3 +83,22 @@ class AutoConfigTest(unittest.TestCase):
|
|||||||
finally:
|
finally:
|
||||||
if "new-model" in CONFIG_MAPPING._extra_content:
|
if "new-model" in CONFIG_MAPPING._extra_content:
|
||||||
del CONFIG_MAPPING._extra_content["new-model"]
|
del CONFIG_MAPPING._extra_content["new-model"]
|
||||||
|
|
||||||
|
def test_repo_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||||
|
):
|
||||||
|
_ = AutoConfig.from_pretrained("bert-base")
|
||||||
|
|
||||||
|
def test_revision_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||||
|
):
|
||||||
|
_ = AutoConfig.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||||
|
|
||||||
|
def test_configuration_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError,
|
||||||
|
"hf-internal-testing/no-config-test-repo does not appear to have a file named config.json.",
|
||||||
|
):
|
||||||
|
_ = AutoConfig.from_pretrained("hf-internal-testing/no-config-test-repo")
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import AutoFeatureExtractor, Wav2Vec2Config, Wav2Vec2FeatureExtractor
|
from transformers import AutoFeatureExtractor, Wav2Vec2Config, Wav2Vec2FeatureExtractor
|
||||||
|
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
||||||
@@ -62,3 +63,22 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||||||
def test_feature_extractor_from_local_file(self):
|
def test_feature_extractor_from_local_file(self):
|
||||||
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
|
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
|
||||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||||
|
|
||||||
|
def test_repo_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||||
|
):
|
||||||
|
_ = AutoFeatureExtractor.from_pretrained("bert-base")
|
||||||
|
|
||||||
|
def test_revision_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||||
|
):
|
||||||
|
_ = AutoFeatureExtractor.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||||
|
|
||||||
|
def test_feature_extractor_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError,
|
||||||
|
"hf-internal-testing/config-no-model does not appear to have a file named preprocessor_config.json.",
|
||||||
|
):
|
||||||
|
_ = AutoFeatureExtractor.from_pretrained("hf-internal-testing/config-no-model")
|
||||||
|
|||||||
@@ -17,17 +17,22 @@ import importlib
|
|||||||
import io
|
import io
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import requests
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
# Try to import everything from transformers to ensure every object can be loaded.
|
# Try to import everything from transformers to ensure every object can be loaded.
|
||||||
from transformers import * # noqa F406
|
from transformers import * # noqa F406
|
||||||
from transformers.file_utils import (
|
from transformers.file_utils import (
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
|
FLAX_WEIGHTS_NAME,
|
||||||
|
TF2_WEIGHTS_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
ContextManagers,
|
ContextManagers,
|
||||||
|
EntryNotFoundError,
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
filename_to_url,
|
filename_to_url,
|
||||||
get_from_cache,
|
get_from_cache,
|
||||||
|
has_file,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
||||||
@@ -83,13 +88,19 @@ class GetFromCacheTests(unittest.TestCase):
|
|||||||
def test_file_not_found(self):
|
def test_file_not_found(self):
|
||||||
# Valid revision (None) but missing file.
|
# Valid revision (None) but missing file.
|
||||||
url = hf_bucket_url(MODEL_ID, filename="missing.bin")
|
url = hf_bucket_url(MODEL_ID, filename="missing.bin")
|
||||||
with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
|
with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"):
|
||||||
|
_ = get_from_cache(url)
|
||||||
|
|
||||||
|
def test_model_not_found(self):
|
||||||
|
# Invalid model file.
|
||||||
|
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
|
||||||
|
with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
|
||||||
_ = get_from_cache(url)
|
_ = get_from_cache(url)
|
||||||
|
|
||||||
def test_revision_not_found(self):
|
def test_revision_not_found(self):
|
||||||
# Valid file but missing revision
|
# Valid file but missing revision
|
||||||
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
|
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
|
||||||
with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
|
with self.assertRaisesRegex(RevisionNotFoundError, "404 Client Error"):
|
||||||
_ = get_from_cache(url)
|
_ = get_from_cache(url)
|
||||||
|
|
||||||
def test_standard_object(self):
|
def test_standard_object(self):
|
||||||
@@ -112,6 +123,11 @@ class GetFromCacheTests(unittest.TestCase):
|
|||||||
metadata = filename_to_url(filepath)
|
metadata = filename_to_url(filepath)
|
||||||
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
|
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
|
||||||
|
|
||||||
|
def test_has_file(self):
|
||||||
|
self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME))
|
||||||
|
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))
|
||||||
|
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", FLAX_WEIGHTS_NAME))
|
||||||
|
|
||||||
|
|
||||||
class ContextManagerTests(unittest.TestCase):
|
class ContextManagerTests(unittest.TestCase):
|
||||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||||
|
|||||||
@@ -389,3 +389,30 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
):
|
):
|
||||||
if NewModelConfig in mapping._extra_content:
|
if NewModelConfig in mapping._extra_content:
|
||||||
del mapping._extra_content[NewModelConfig]
|
del mapping._extra_content[NewModelConfig]
|
||||||
|
|
||||||
|
def test_repo_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||||
|
):
|
||||||
|
_ = AutoModel.from_pretrained("bert-base")
|
||||||
|
|
||||||
|
def test_revision_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||||
|
):
|
||||||
|
_ = AutoModel.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||||
|
|
||||||
|
def test_model_file_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError,
|
||||||
|
"hf-internal-testing/config-no-model does not appear to have a file named pytorch_model.bin",
|
||||||
|
):
|
||||||
|
_ = AutoModel.from_pretrained("hf-internal-testing/config-no-model")
|
||||||
|
|
||||||
|
def test_model_from_tf_suggestion(self):
|
||||||
|
with self.assertRaisesRegex(EnvironmentError, "Use `from_tf=True` to load this model"):
|
||||||
|
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
|
||||||
|
|
||||||
|
def test_model_from_flax_suggestion(self):
|
||||||
|
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
|
||||||
|
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import AutoConfig, AutoTokenizer, BertConfig, TensorType, is_flax_available
|
from transformers import AutoConfig, AutoTokenizer, BertConfig, TensorType, is_flax_available
|
||||||
from transformers.testing_utils import require_flax, slow
|
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, require_flax, slow
|
||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
@@ -76,3 +76,26 @@ class FlaxAutoModelTest(unittest.TestCase):
|
|||||||
return model(**kwargs)
|
return model(**kwargs)
|
||||||
|
|
||||||
eval(**tokens).block_until_ready()
|
eval(**tokens).block_until_ready()
|
||||||
|
|
||||||
|
def test_repo_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||||
|
):
|
||||||
|
_ = FlaxAutoModel.from_pretrained("bert-base")
|
||||||
|
|
||||||
|
def test_revision_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||||
|
):
|
||||||
|
_ = FlaxAutoModel.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||||
|
|
||||||
|
def test_model_file_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError,
|
||||||
|
"hf-internal-testing/config-no-model does not appear to have a file named flax_model.msgpack",
|
||||||
|
):
|
||||||
|
_ = FlaxAutoModel.from_pretrained("hf-internal-testing/config-no-model")
|
||||||
|
|
||||||
|
def test_model_from_pt_suggestion(self):
|
||||||
|
with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
|
||||||
|
_ = FlaxAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||||
@@ -309,3 +309,26 @@ class TFAutoModelTest(unittest.TestCase):
|
|||||||
):
|
):
|
||||||
if NewModelConfig in mapping._extra_content:
|
if NewModelConfig in mapping._extra_content:
|
||||||
del mapping._extra_content[NewModelConfig]
|
del mapping._extra_content[NewModelConfig]
|
||||||
|
|
||||||
|
def test_repo_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||||
|
):
|
||||||
|
_ = TFAutoModel.from_pretrained("bert-base")
|
||||||
|
|
||||||
|
def test_revision_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||||
|
):
|
||||||
|
_ = TFAutoModel.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||||
|
|
||||||
|
def test_model_file_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError,
|
||||||
|
"hf-internal-testing/config-no-model does not appear to have a file named tf_model.h5",
|
||||||
|
):
|
||||||
|
_ = TFAutoModel.from_pretrained("hf-internal-testing/config-no-model")
|
||||||
|
|
||||||
|
def test_model_from_pt_suggestion(self):
|
||||||
|
with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
|
||||||
|
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||||
|
|||||||
@@ -150,7 +150,8 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||||||
def test_tokenizer_identifier_non_existent(self):
|
def test_tokenizer_identifier_non_existent(self):
|
||||||
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
|
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, ".*is not a local path or a model identifier on the model Hub. Did you make a typo?"
|
EnvironmentError,
|
||||||
|
"julien-c/herlolip-not-exists is not a local folder and is not a valid model identifier",
|
||||||
):
|
):
|
||||||
_ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")
|
_ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")
|
||||||
|
|
||||||
@@ -310,3 +311,15 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||||||
del CONFIG_MAPPING._extra_content["new-model"]
|
del CONFIG_MAPPING._extra_content["new-model"]
|
||||||
if NewConfig in TOKENIZER_MAPPING._extra_content:
|
if NewConfig in TOKENIZER_MAPPING._extra_content:
|
||||||
del TOKENIZER_MAPPING._extra_content[NewConfig]
|
del TOKENIZER_MAPPING._extra_content[NewConfig]
|
||||||
|
|
||||||
|
def test_repo_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||||
|
):
|
||||||
|
_ = AutoTokenizer.from_pretrained("bert-base")
|
||||||
|
|
||||||
|
def test_revision_not_found(self):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||||
|
):
|
||||||
|
_ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||||
|
|||||||
@@ -255,7 +255,7 @@ SPECIAL_MODULE_TO_TEST_MAP = {
|
|||||||
"modeling_tf_utils.py": ["test_modeling_tf_common.py", "test_modeling_tf_core.py"],
|
"modeling_tf_utils.py": ["test_modeling_tf_common.py", "test_modeling_tf_core.py"],
|
||||||
"modeling_utils.py": ["test_modeling_common.py", "test_offline.py"],
|
"modeling_utils.py": ["test_modeling_common.py", "test_offline.py"],
|
||||||
"models/auto/modeling_auto.py": ["test_modeling_auto.py", "test_modeling_tf_pytorch.py", "test_modeling_bort.py"],
|
"models/auto/modeling_auto.py": ["test_modeling_auto.py", "test_modeling_tf_pytorch.py", "test_modeling_bort.py"],
|
||||||
"models/auto/modeling_flax_auto.py": "test_flax_auto.py",
|
"models/auto/modeling_flax_auto.py": "test_modeling_flax_auto.py",
|
||||||
"models/auto/modeling_tf_auto.py": [
|
"models/auto/modeling_tf_auto.py": [
|
||||||
"test_modeling_tf_auto.py",
|
"test_modeling_tf_auto.py",
|
||||||
"test_modeling_tf_pytorch.py",
|
"test_modeling_tf_pytorch.py",
|
||||||
|
|||||||
Reference in New Issue
Block a user