[core] Large/full refactor of from_pretrained (#36033)

* squash everything together
start to simplify inner logic

Update modeling_utils.py

Update modeling_utils.py

Update modeling_utils.py

Update modeling_utils.py

continue refactor

fix

small fixes

add type hints/docstring

Update modeling_utils.py

remove _fast_init

keep improving

Update modeling_utils.py

Update modeling_utils.py

new first tp loading version

style

fix weird in-place op

trigger CIs

Update modeling_utils.py

much clearer renaming of keys

fix

update

Update test_modeling_common.py

trigger CIs

update

update

style

Update modeling_utils.py

Update modeling_utils.py

Update modeling_utils.py

fix

fast download first prototype

remove old function

remove old functions

Remove unused function and move back _get_tp_registry

fix tp plan registry

simplify

CIs

Update hub.py

Update modeling_utils.py

simplify

simplify renaming logic

remove unused check

add sanity check back (a test depends on it)

Update modeling_utils.py

finalize sound renaming logic

style

add forgotten check

Update modeling_utils.py

add key_mapping keyword

style

Update modeling_utils.py

add comment

minor updates

minor change for clarity

fix small prefix issue and simplify

style

trigger CIs

typo fix

Post rebase fix

post rebase cleanup

simplify tp

typo

oupsi

typo

correctly escape

improvements based on Marc's review

finalize Marc's review comments

 squash everything

* improve

* Update modeling_utils.py

* Update modeling_utils.py

* fix

* Update modeling_utils.py

* Update modeling_utils.py

* style

* Update modeling_utils.py

* simplify

* style

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* fix dtype issue

* Update modeling_utils.py

* style

* remove test that does not make sense

* style

* small fixes

* style

* fix

* cleanup after rebase

* style

* typo

* escape

* tp for task specific top modules

* Update modeling_utils.py

* Update modeling_utils.py

* fix allocation

* CIs

* CIs

* CIs

* improve docstring

* CIs

* Update modeling_utils.py

* fix
This commit is contained in:
Cyril Vallez
2025-03-12 13:39:25 +01:00
committed by GitHub
parent 7652804d23
commit 071a161d3e
15 changed files with 1525 additions and 1542 deletions

View File

@@ -71,7 +71,6 @@ from .utils import (
copy_func, copy_func,
default_cache_path, default_cache_path,
define_sagemaker_information, define_sagemaker_information,
get_file_from_repo,
get_torch_version, get_torch_version,
has_file, has_file,
http_user_agent, http_user_agent,

View File

@@ -306,7 +306,7 @@ def deepspeed_config():
return None return None
def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False): def _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_to_params_buffers=False):
""" """
Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers` Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
tensor parallelism API. tensor parallelism API.
@@ -349,7 +349,7 @@ def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, a
if child is not None: if child is not None:
load(child, state_dict, prefix + name + ".", assign_to_params_buffers) load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers) load(model_to_load, state_dict, assign_to_params_buffers=assign_to_params_buffers)
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
# it's safe to delete it. # it's safe to delete it.
del state_dict del state_dict

File diff suppressed because it is too large Load Diff

View File

@@ -25,7 +25,7 @@ from typing import Dict, Optional, Union
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...feature_extraction_utils import FeatureExtractionMixin from ...feature_extraction_utils import FeatureExtractionMixin
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo, logging from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, cached_file, logging
from .auto_factory import _LazyAutoMapping from .auto_factory import _LazyAutoMapping
from .configuration_auto import ( from .configuration_auto import (
CONFIG_MAPPING_NAMES, CONFIG_MAPPING_NAMES,
@@ -220,7 +220,7 @@ def get_feature_extractor_config(
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
token = use_auth_token token = use_auth_token
resolved_config_file = get_file_from_repo( resolved_config_file = cached_file(
pretrained_model_name_or_path, pretrained_model_name_or_path,
FEATURE_EXTRACTOR_NAME, FEATURE_EXTRACTOR_NAME,
cache_dir=cache_dir, cache_dir=cache_dir,
@@ -230,6 +230,9 @@ def get_feature_extractor_config(
token=token, token=token,
revision=revision, revision=revision,
local_files_only=local_files_only, local_files_only=local_files_only,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
) )
if resolved_config_file is None: if resolved_config_file is None:
logger.info( logger.info(

View File

@@ -29,7 +29,7 @@ from ...image_processing_utils_fast import BaseImageProcessorFast
from ...utils import ( from ...utils import (
CONFIG_NAME, CONFIG_NAME,
IMAGE_PROCESSOR_NAME, IMAGE_PROCESSOR_NAME,
get_file_from_repo, cached_file,
is_timm_config_dict, is_timm_config_dict,
is_timm_local_checkpoint, is_timm_local_checkpoint,
is_torchvision_available, is_torchvision_available,
@@ -288,7 +288,7 @@ def get_image_processor_config(
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
token = use_auth_token token = use_auth_token
resolved_config_file = get_file_from_repo( resolved_config_file = cached_file(
pretrained_model_name_or_path, pretrained_model_name_or_path,
IMAGE_PROCESSOR_NAME, IMAGE_PROCESSOR_NAME,
cache_dir=cache_dir, cache_dir=cache_dir,
@@ -298,6 +298,9 @@ def get_image_processor_config(
token=token, token=token,
revision=revision, revision=revision,
local_files_only=local_files_only, local_files_only=local_files_only,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
) )
if resolved_config_file is None: if resolved_config_file is None:
logger.info( logger.info(

View File

@@ -28,7 +28,7 @@ from ...feature_extraction_utils import FeatureExtractionMixin
from ...image_processing_utils import ImageProcessingMixin from ...image_processing_utils import ImageProcessingMixin
from ...processing_utils import ProcessorMixin from ...processing_utils import ProcessorMixin
from ...tokenization_utils import TOKENIZER_CONFIG_FILE from ...tokenization_utils import TOKENIZER_CONFIG_FILE
from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, get_file_from_repo, logging from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, cached_file, logging
from .auto_factory import _LazyAutoMapping from .auto_factory import _LazyAutoMapping
from .configuration_auto import ( from .configuration_auto import (
CONFIG_MAPPING_NAMES, CONFIG_MAPPING_NAMES,
@@ -254,15 +254,21 @@ class AutoProcessor:
processor_auto_map = None processor_auto_map = None
# First, let's see if we have a processor or preprocessor config. # First, let's see if we have a processor or preprocessor config.
# Filter the kwargs for `get_file_from_repo`. # Filter the kwargs for `cached_file`.
get_file_from_repo_kwargs = { cached_file_kwargs = {
key: kwargs[key] for key in inspect.signature(get_file_from_repo).parameters.keys() if key in kwargs key: kwargs[key] for key in inspect.signature(cached_file).parameters.keys() if key in kwargs
} }
# We don't want to raise on gated repos, missing entries or connection errors (`cached_file` returns `None` instead)
cached_file_kwargs.update(
{
"_raise_exceptions_for_gated_repo": False,
"_raise_exceptions_for_missing_entries": False,
"_raise_exceptions_for_connection_errors": False,
}
)
# Let's start by checking whether the processor class is saved in a processor config # Let's start by checking whether the processor class is saved in a processor config
processor_config_file = get_file_from_repo( processor_config_file = cached_file(pretrained_model_name_or_path, PROCESSOR_NAME, **cached_file_kwargs)
pretrained_model_name_or_path, PROCESSOR_NAME, **get_file_from_repo_kwargs
)
if processor_config_file is not None: if processor_config_file is not None:
config_dict, _ = ProcessorMixin.get_processor_dict(pretrained_model_name_or_path, **kwargs) config_dict, _ = ProcessorMixin.get_processor_dict(pretrained_model_name_or_path, **kwargs)
processor_class = config_dict.get("processor_class", None) processor_class = config_dict.get("processor_class", None)
@@ -271,8 +277,8 @@ class AutoProcessor:
if processor_class is None: if processor_class is None:
# If not found, let's check whether the processor class is saved in an image processor config # If not found, let's check whether the processor class is saved in an image processor config
preprocessor_config_file = get_file_from_repo( preprocessor_config_file = cached_file(
pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **get_file_from_repo_kwargs pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
) )
if preprocessor_config_file is not None: if preprocessor_config_file is not None:
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs) config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
@@ -291,8 +297,8 @@ class AutoProcessor:
if processor_class is None: if processor_class is None:
# Next, let's check whether the processor class is saved in a tokenizer # Next, let's check whether the processor class is saved in a tokenizer
tokenizer_config_file = get_file_from_repo( tokenizer_config_file = cached_file(
pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **get_file_from_repo_kwargs pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **cached_file_kwargs
) )
if tokenizer_config_file is not None: if tokenizer_config_file is not None:
with open(tokenizer_config_file, encoding="utf-8") as reader: with open(tokenizer_config_file, encoding="utf-8") as reader:

View File

@@ -25,7 +25,7 @@ import numpy as np
from ...feature_extraction_utils import BatchFeature from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ProcessorMixin from ...processing_utils import ProcessorMixin
from ...utils import logging from ...utils import logging
from ...utils.hub import get_file_from_repo from ...utils.hub import cached_file
from ..auto import AutoTokenizer from ..auto import AutoTokenizer
@@ -86,7 +86,7 @@ class BarkProcessor(ProcessorMixin):
""" """
if speaker_embeddings_dict_path is not None: if speaker_embeddings_dict_path is not None:
speaker_embeddings_path = get_file_from_repo( speaker_embeddings_path = cached_file(
pretrained_processor_name_or_path, pretrained_processor_name_or_path,
speaker_embeddings_dict_path, speaker_embeddings_dict_path,
subfolder=kwargs.pop("subfolder", None), subfolder=kwargs.pop("subfolder", None),
@@ -97,6 +97,9 @@ class BarkProcessor(ProcessorMixin):
local_files_only=kwargs.pop("local_files_only", False), local_files_only=kwargs.pop("local_files_only", False),
token=kwargs.pop("use_auth_token", None), token=kwargs.pop("use_auth_token", None),
revision=kwargs.pop("revision", None), revision=kwargs.pop("revision", None),
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
) )
if speaker_embeddings_path is None: if speaker_embeddings_path is None:
logger.warning( logger.warning(
@@ -182,7 +185,7 @@ class BarkProcessor(ProcessorMixin):
f"Voice preset unrecognized, missing {key} as a key in self.speaker_embeddings[{voice_preset}]." f"Voice preset unrecognized, missing {key} as a key in self.speaker_embeddings[{voice_preset}]."
) )
path = get_file_from_repo( path = cached_file(
self.speaker_embeddings.get("repo_or_path", "/"), self.speaker_embeddings.get("repo_or_path", "/"),
voice_preset_paths[key], voice_preset_paths[key],
subfolder=kwargs.pop("subfolder", None), subfolder=kwargs.pop("subfolder", None),
@@ -193,6 +196,9 @@ class BarkProcessor(ProcessorMixin):
local_files_only=kwargs.pop("local_files_only", False), local_files_only=kwargs.pop("local_files_only", False),
token=kwargs.pop("use_auth_token", None), token=kwargs.pop("use_auth_token", None),
revision=kwargs.pop("revision", None), revision=kwargs.pop("revision", None),
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
) )
if path is None: if path is None:
raise ValueError( raise ValueError(

View File

@@ -544,7 +544,7 @@ class CvtPreTrainedModel(PreTrainedModel):
elif isinstance(module, CvtStage): elif isinstance(module, CvtStage):
if self.config.cls_token[module.stage]: if self.config.cls_token[module.stage]:
module.cls_token.data = nn.init.trunc_normal_( module.cls_token.data = nn.init.trunc_normal_(
torch.zeros(1, 1, self.config.embed_dim[-1]), mean=0.0, std=self.config.initializer_range module.cls_token.data, mean=0.0, std=self.config.initializer_range
) )

View File

@@ -35,7 +35,7 @@ from torch import Tensor
from vissl.models.model_helpers import get_trunk_forward_outputs from vissl.models.model_helpers import get_trunk_forward_outputs
from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict
from transformers.utils import logging from transformers.utils import logging
@@ -244,14 +244,18 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
our_model_func = RegNetModel our_model_func = RegNetModel
if "in1k" in model_name: if "in1k" in model_name:
our_model_func = RegNetForImageClassification our_model_func = RegNetForImageClassification
our_model = our_model_func(our_config) with torch.device("meta"):
# place our model to the meta device (so remove all the weights) our_model = our_model_func(our_config)
our_model.to(torch.device("meta"))
logger.info("Loading state_dict in our model.") logger.info("Loading state_dict in our model.")
# load state dict # load state dict
state_dict_keys = our_model.state_dict().keys() state_dict_keys = our_model.state_dict().keys()
PreTrainedModel._load_pretrained_model_low_mem( state_dict = load_state_dict(save_directory / f"{model_name}.pth", weights_only=True)
our_model, state_dict_keys, [save_directory / f"{model_name}.pth"] fixed_state_dict = state_dict = {our_model._fix_state_dict_key_on_load(k)[0]: v for k, v in state_dict.items()}
_load_state_dict_into_meta_model(
our_model,
fixed_state_dict,
start_prefix="",
expected_keys=state_dict_keys,
) )
logger.info("Finally, pushing!") logger.info("Finally, pushing!")
# push it to hub # push it to hub

View File

@@ -113,7 +113,7 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
Override original method to fix state_dict keys on load for cases when weights are loaded Override original method to fix state_dict keys on load for cases when weights are loaded
without using the `from_pretrained` method (e.g., in Trainer to resume from checkpoint). without using the `from_pretrained` method (e.g., in Trainer to resume from checkpoint).
""" """
state_dict = self._fix_state_dict_keys_on_load(state_dict) state_dict = {self._fix_state_dict_key_on_load(k)[0]: v for k, v in state_dict.items()}
return super().load_state_dict(state_dict, *args, **kwargs) return super().load_state_dict(state_dict, *args, **kwargs)
def _init_weights(self, module): def _init_weights(self, module):

View File

@@ -91,7 +91,6 @@ from .hub import (
define_sagemaker_information, define_sagemaker_information,
download_url, download_url,
extract_commit_hash, extract_commit_hash,
get_file_from_repo,
has_file, has_file,
http_user_agent, http_user_agent,
is_offline_mode, is_offline_mode,

View File

@@ -40,6 +40,7 @@ from huggingface_hub import (
create_repo, create_repo,
hf_hub_download, hf_hub_download,
hf_hub_url, hf_hub_url,
snapshot_download,
try_to_load_from_cache, try_to_load_from_cache,
) )
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
@@ -47,7 +48,6 @@ from huggingface_hub.utils import (
EntryNotFoundError, EntryNotFoundError,
GatedRepoError, GatedRepoError,
HfHubHTTPError, HfHubHTTPError,
HFValidationError,
LocalEntryNotFoundError, LocalEntryNotFoundError,
OfflineModeIsEnabled, OfflineModeIsEnabled,
RepositoryNotFoundError, RepositoryNotFoundError,
@@ -69,7 +69,6 @@ from .import_utils import (
is_torch_available, is_torch_available,
is_training_run_on_sagemaker, is_training_run_on_sagemaker,
) )
from .logging import tqdm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -209,21 +208,7 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
def cached_file( def cached_file(
path_or_repo_id: Union[str, os.PathLike], path_or_repo_id: Union[str, os.PathLike],
filename: str, filename: str,
cache_dir: Optional[Union[str, os.PathLike]] = None, **kwargs,
force_download: bool = False,
resume_download: Optional[bool] = None,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
subfolder: str = "",
repo_type: Optional[str] = None,
user_agent: Optional[Union[str, Dict[str, str]]] = None,
_raise_exceptions_for_gated_repo: bool = True,
_raise_exceptions_for_missing_entries: bool = True,
_raise_exceptions_for_connection_errors: bool = True,
_commit_hash: Optional[str] = None,
**deprecated_kwargs,
) -> Optional[str]: ) -> Optional[str]:
""" """
Tries to locate a file in a local folder and repo, downloads and cache it if necessary. Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
@@ -231,7 +216,6 @@ def cached_file(
Args: Args:
path_or_repo_id (`str` or `os.PathLike`): path_or_repo_id (`str` or `os.PathLike`):
This can be either: This can be either:
- a string, the *model id* of a model repo on huggingface.co. - a string, the *model id* of a model repo on huggingface.co.
- a path to a *directory* potentially containing the file. - a path to a *directory* potentially containing the file.
filename (`str`): filename (`str`):
@@ -274,6 +258,94 @@ def cached_file(
Examples: Examples:
```python
# Download a model weight from the Hub and cache it.
model_weights_file = cached_file("google-bert/bert-base-uncased", "pytorch_model.bin")
```
"""
file = cached_files(path_or_repo_id=path_or_repo_id, filenames=[filename], **kwargs)
file = file[0] if file is not None else file
return file
def cached_files(
path_or_repo_id: Union[str, os.PathLike],
filenames: List[str],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: Optional[bool] = None,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
subfolder: str = "",
repo_type: Optional[str] = None,
user_agent: Optional[Union[str, Dict[str, str]]] = None,
_raise_exceptions_for_gated_repo: bool = True,
_raise_exceptions_for_missing_entries: bool = True,
_raise_exceptions_for_connection_errors: bool = True,
_commit_hash: Optional[str] = None,
**deprecated_kwargs,
) -> Optional[str]:
"""
Tries to locate several files in a local folder and repo, downloads and caches them if necessary.
Args:
path_or_repo_id (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a model repo on huggingface.co.
- a path to a *directory* potentially containing the file.
filenames (`List[str]`):
The names of all the files to locate in `path_or_repo_id`.
cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if they
exist.
resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible.
Will be removed in v5 of Transformers.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
repo_type (`str`, *optional*):
Specify the repo type (useful when downloading from a space for instance).
Private args:
_raise_exceptions_for_gated_repo (`bool`):
if False, do not raise an exception for gated repo error but return None.
_raise_exceptions_for_missing_entries (`bool`):
if False, do not raise an exception for missing entries but return None.
_raise_exceptions_for_connection_errors (`bool`):
if False, do not raise an exception for connection errors but return None.
_commit_hash (`str`, *optional*):
passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
<Tip>
Passing `token=True` is required when you want to use a private model.
</Tip>
Returns:
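`Optional[List[str]]`: Returns the list of resolved files (to the cache folder if downloaded from a repo), or `None` if nothing could be resolved and the corresponding `_raise_exceptions_for_*` flag is disabled.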
Examples:
```python ```python
# Download a model weight from the Hub and cache it. # Download a model weight from the Hub and cache it.
model_weights_file = cached_file("google-bert/bert-base-uncased", "pytorch_model.bin") model_weights_file = cached_file("google-bert/bert-base-uncased", "pytorch_model.bin")
@@ -289,144 +361,176 @@ def cached_file(
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
token = use_auth_token token = use_auth_token
# Private arguments
# _raise_exceptions_for_gated_repo: if False, do not raise an exception for gated repo error but return
# None.
# _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
# None.
# _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
# None.
# _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
# a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
if is_offline_mode() and not local_files_only: if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True") logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True local_files_only = True
if subfolder is None: if subfolder is None:
subfolder = "" subfolder = ""
# Add folder to filenames
full_filenames = [os.path.join(subfolder, file) for file in filenames]
path_or_repo_id = str(path_or_repo_id) path_or_repo_id = str(path_or_repo_id)
full_filename = os.path.join(subfolder, filename) existing_files = []
if os.path.isdir(path_or_repo_id): for filename in full_filenames:
resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename) if os.path.isdir(path_or_repo_id):
if not os.path.isfile(resolved_file): resolved_file = os.path.join(path_or_repo_id, filename)
if _raise_exceptions_for_missing_entries and filename not in ["config.json", f"{subfolder}/config.json"]: if not os.path.isfile(resolved_file):
raise EnvironmentError( if _raise_exceptions_for_missing_entries and filename != os.path.join(subfolder, "config.json"):
f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout " revision_ = "main" if revision is None else revision
f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files." raise EnvironmentError(
) f"{path_or_repo_id} does not appear to have a file named {filename}. Checkout "
else: f"'https://huggingface.co/{path_or_repo_id}/tree/{revision_}' for available files."
return None )
return resolved_file else:
return None
existing_files.append(resolved_file)
# All files exist
if len(existing_files) == len(full_filenames):
return existing_files
if cache_dir is None: if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE cache_dir = TRANSFORMERS_CACHE
if isinstance(cache_dir, Path): if isinstance(cache_dir, Path):
cache_dir = str(cache_dir) cache_dir = str(cache_dir)
existing_files = []
file_counter = 0
if _commit_hash is not None and not force_download: if _commit_hash is not None and not force_download:
# If the file is cached under that commit hash, we return it directly. for filename in full_filenames:
resolved_file = try_to_load_from_cache( # If the file is cached under that commit hash, we return it directly.
path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type resolved_file = try_to_load_from_cache(
) path_or_repo_id, filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
if resolved_file is not None: )
if resolved_file is not _CACHED_NO_EXIST: if resolved_file is not None:
return resolved_file if resolved_file is not _CACHED_NO_EXIST:
elif not _raise_exceptions_for_missing_entries: file_counter += 1
return None existing_files.append(resolved_file)
else: elif not _raise_exceptions_for_missing_entries:
raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.") file_counter += 1
else:
raise EnvironmentError(f"Could not locate {filename} inside {path_or_repo_id}.")
# Either all the files were found, or some were _CACHED_NO_EXIST but we do not raise for missing entries
if file_counter == len(full_filenames):
return existing_files if len(existing_files) > 0 else None
user_agent = http_user_agent(user_agent) user_agent = http_user_agent(user_agent)
# download the files if needed
try: try:
# Load from URL or cache if already cached if len(full_filenames) == 1:
resolved_file = hf_hub_download( # This is slightly better for only 1 file
path_or_repo_id, hf_hub_download(
filename, path_or_repo_id,
subfolder=None if len(subfolder) == 0 else subfolder, filenames[0],
repo_type=repo_type, subfolder=None if len(subfolder) == 0 else subfolder,
revision=revision, repo_type=repo_type,
cache_dir=cache_dir, revision=revision,
user_agent=user_agent, cache_dir=cache_dir,
force_download=force_download, user_agent=user_agent,
proxies=proxies, force_download=force_download,
resume_download=resume_download, proxies=proxies,
token=token, resume_download=resume_download,
local_files_only=local_files_only, token=token,
local_files_only=local_files_only,
)
else:
snapshot_download(
path_or_repo_id,
allow_patterns=full_filenames,
repo_type=repo_type,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except Exception as e:
# We cannot recover from them
if isinstance(e, RepositoryNotFoundError) and not isinstance(e, GatedRepoError):
raise EnvironmentError(
f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token "
"having permission to this repo either by logging in with `huggingface-cli login` or by passing "
"`token=<your_token>`"
) from e
elif isinstance(e, RevisionNotFoundError):
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
"for this model name. Check the model page at "
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
) from e
# Now we try to recover if we can find all files correctly in the cache
resolved_files = [
_get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision) for filename in full_filenames
]
if all(file is not None for file in resolved_files):
return resolved_files
# Raise based on the flags. Note that we will raise for missing entries at the very end, even when
# not entering this Except block, as it may also happen when `snapshot_download` does not raise
if isinstance(e, GatedRepoError):
if not _raise_exceptions_for_gated_repo:
return None
raise EnvironmentError(
"You are trying to access a gated repo.\nMake sure to have access to it at "
f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}"
) from e
elif isinstance(e, LocalEntryNotFoundError):
if not _raise_exceptions_for_connection_errors:
return None
# Here we only raise if both flags for missing entry and connection errors are True (because it can be raised
# even when `local_files_only` is True, in which case raising for connections errors only would not make sense)
elif _raise_exceptions_for_missing_entries:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load the files, and couldn't find them in the"
f" cached files.\nCheckout your internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
) from e
# snapshot_download will not raise EntryNotFoundError, but hf_hub_download can. If this is the case, it will be treated
# later on anyway and re-raised if needed
elif isinstance(e, HTTPError) and not isinstance(e, EntryNotFoundError):
if not _raise_exceptions_for_connection_errors:
return None
raise EnvironmentError(
f"There was a specific connection error when trying to load {path_or_repo_id}:\n{e}"
)
resolved_files = [
_get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision) for filename in full_filenames
]
# If there are any missing file and the flag is active, raise
if any(file is None for file in resolved_files) and _raise_exceptions_for_missing_entries:
missing_entries = [original for original, resolved in zip(full_filenames, resolved_files) if resolved is None]
# Last escape
if len(resolved_files) == 1 and missing_entries[0] == os.path.join(subfolder, "config.json"):
return None
# Now we raise for missing entries
revision_ = "main" if revision is None else revision
msg = f"a file named {missing_entries[0]}" if len(missing_entries) == 1 else f"files named {*missing_entries,}"
raise EnvironmentError(
f"{path_or_repo_id} does not appear to have {msg}. Checkout 'https://huggingface.co/{path_or_repo_id}/tree/{revision_}'"
"for available files."
) )
except GatedRepoError as e:
resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision) # Remove potential missing entries (we can silently remove them at this point based on the flags)
if resolved_file is not None or not _raise_exceptions_for_gated_repo: resolved_files = [file for file in resolved_files if file is not None]
return resolved_file # Return `None` if the list is empty, coherent with other Exception when the flag is not active
raise EnvironmentError( resolved_files = None if len(resolved_files) == 0 else resolved_files
"You are trying to access a gated repo.\nMake sure to have access to it at "
f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}" return resolved_files
) from e
except RepositoryNotFoundError as e:
raise EnvironmentError(
f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token "
"having permission to this repo either by logging in with `huggingface-cli login` or by passing "
"`token=<your_token>`"
) from e
except RevisionNotFoundError as e:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
"for this model name. Check the model page at "
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
) from e
except LocalEntryNotFoundError as e:
resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
if (
resolved_file is not None
or not _raise_exceptions_for_missing_entries
or not _raise_exceptions_for_connection_errors
):
return resolved_file
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
) from e
except EntryNotFoundError as e:
if not _raise_exceptions_for_missing_entries:
return None
if revision is None:
revision = "main"
if filename in ["config.json", f"{subfolder}/config.json"]:
return None
raise EnvironmentError(
f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files."
) from e
except HTTPError as err:
resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
if resolved_file is not None or not _raise_exceptions_for_connection_errors:
return resolved_file
raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
except HFValidationError as e:
raise EnvironmentError(
f"Incorrect path_or_model_id: '{path_or_repo_id}'. Please provide either the path to a local folder or the repo_id of a model on the Hub."
) from e
return resolved_file
# TODO: deprecate `get_file_from_repo` or document it differently? # TODO cyril: Deprecated and should be removed in 4.51
# Docstring is exactly the same as `cached_repo` but behavior is slightly different. If file is missing or if
# there is a connection error, `cached_repo` will return None while `get_file_from_repo` will raise an error.
# IMO we should keep only 1 method and have a single `raise_error` argument (to be discussed).
def get_file_from_repo( def get_file_from_repo(
path_or_repo: Union[str, os.PathLike], *args,
filename: str, **kwargs,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: Optional[bool] = None,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
subfolder: str = "",
**deprecated_kwargs,
): ):
""" """
Tries to locate a file in a local folder and repo, downloads and cache it if necessary. Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
@@ -483,30 +587,15 @@ def get_file_from_repo(
tokenizer_config = get_file_from_repo("FacebookAI/xlm-roberta-base", "tokenizer_config.json") tokenizer_config = get_file_from_repo("FacebookAI/xlm-roberta-base", "tokenizer_config.json")
``` ```
""" """
use_auth_token = deprecated_kwargs.pop("use_auth_token", None) logger.warning(
if use_auth_token is not None: "`get_file_from_repo` is deprecated and will be removed in version 4.51. Use `cached_file` instead."
warnings.warn( )
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
)
if token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
token = use_auth_token
return cached_file( return cached_file(
path_or_repo_id=path_or_repo, *args,
filename=filename,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
_raise_exceptions_for_gated_repo=False, _raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False, _raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False, _raise_exceptions_for_connection_errors=False,
**kwargs,
) )
@@ -1023,45 +1112,22 @@ def get_checkpoint_shard_files(
shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames] shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
return shard_filenames, sharded_metadata return shard_filenames, sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub # At this stage pretrained_model_name_or_path is a model identifier on the Hub. Try to get everything from cache,
cached_filenames = [] # or download the files
# Check if the model is already cached or not. We only try the last checkpoint, this should cover most cases of cached_filenames = cached_files(
# downloaded (if interrupted). pretrained_model_name_or_path,
last_shard = try_to_load_from_cache( shard_filenames,
pretrained_model_name_or_path, shard_filenames[-1], cache_dir=cache_dir, revision=_commit_hash cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=_commit_hash,
) )
show_progress_bar = last_shard is None or force_download
for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar):
try:
# Load from URL
cached_filename = cached_file(
pretrained_model_name_or_path,
shard_filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=_commit_hash,
)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here.
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
"required according to the checkpoint index."
)
except HTTPError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
" again after checking your internet connection."
)
cached_filenames.append(cached_filename)
return cached_filenames, sharded_metadata return cached_filenames, sharded_metadata

View File

@@ -2368,10 +2368,9 @@ class ModelTesterMixin:
safe_save_file(placeholder_dict, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "pt"}) safe_save_file(placeholder_dict, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "pt"})
model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True) model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True)
prefix = f"{model_reloaded.base_model_prefix}."
params = dict(model_reloaded.named_parameters()) params = dict(model_reloaded.named_parameters())
params.update(dict(model_reloaded.named_buffers())) params.update(dict(model_reloaded.named_buffers()))
param_names = {k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys()} param_names = set(params.keys())
missing_keys = set(infos["missing_keys"]) missing_keys = set(infos["missing_keys"])
@@ -2383,9 +2382,8 @@ class ModelTesterMixin:
ptrs[id_tensor_storage(tensor)].append(name) ptrs[id_tensor_storage(tensor)].append(name)
tied_params = [names for _, names in ptrs.items() if len(names) > 1] tied_params = [names for _, names in ptrs.items() if len(names) > 1]
for group in tied_params: for group in tied_params:
group = {k[len(prefix) :] if k.startswith(prefix) else k for k in group}
# We remove the group from extra_missing if not all weights from group are in it # We remove the group from extra_missing if not all weights from group are in it
if len(group - extra_missing) > 0: if len(set(group) - extra_missing) > 0:
extra_missing = extra_missing - set(group) extra_missing = extra_missing - set(group)
self.assertEqual( self.assertEqual(
@@ -2399,15 +2397,14 @@ class ModelTesterMixin:
# Remove nonpersistent buffers from missed_missing # Remove nonpersistent buffers from missed_missing
buffers = [n for n, _ in model_reloaded.named_buffers()] buffers = [n for n, _ in model_reloaded.named_buffers()]
nonpersistent_buffers = {n for n in buffers if n not in model_reloaded.state_dict()} nonpersistent_buffers = {n for n in buffers if n not in model_reloaded.state_dict()}
nonpersistent_buffers = {
k[len(prefix) :] if k.startswith(prefix) else k for k in nonpersistent_buffers
}
missed_missing = missed_missing - nonpersistent_buffers missed_missing = missed_missing - nonpersistent_buffers
if model_reloaded._keys_to_ignore_on_load_missing is None: if model_reloaded._keys_to_ignore_on_load_missing is None:
expected_missing = set() expected_missing = set()
else: else:
expected_missing = set(model_reloaded._keys_to_ignore_on_load_missing) expected_missing = set()
for pattern in model_reloaded._keys_to_ignore_on_load_missing:
expected_missing.update({k for k in param_names if re.search(pattern, k) is not None})
self.assertEqual( self.assertEqual(
missed_missing, missed_missing,
expected_missing, expected_missing,

View File

@@ -28,7 +28,6 @@ from transformers.utils import (
TRANSFORMERS_CACHE, TRANSFORMERS_CACHE,
WEIGHTS_NAME, WEIGHTS_NAME,
cached_file, cached_file,
get_file_from_repo,
has_file, has_file,
) )
@@ -87,14 +86,8 @@ class GetFromCacheTests(unittest.TestCase):
path = cached_file(RANDOM_BERT, "conf", local_files_only=True, _raise_exceptions_for_missing_entries=False) path = cached_file(RANDOM_BERT, "conf", local_files_only=True, _raise_exceptions_for_missing_entries=False)
self.assertIsNone(path) self.assertIsNone(path)
response_mock = mock.Mock() # Under the mock environment, hf_hub_download will always raise an HTTPError
response_mock.status_code = 500 with mock.patch("transformers.utils.hub.hf_hub_download", side_effect=HTTPError) as mock_head:
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
response_mock.json.return_value = {}
# Under the mock environment we get a 500 error when trying to reach the tokenizer.
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
path = cached_file(RANDOM_BERT, "conf", _raise_exceptions_for_connection_errors=False) path = cached_file(RANDOM_BERT, "conf", _raise_exceptions_for_connection_errors=False)
self.assertIsNone(path) self.assertIsNone(path)
# This checks that we did call the fake head request # This checks that we did call the fake head request
@@ -117,18 +110,45 @@ class GetFromCacheTests(unittest.TestCase):
assert has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME, local_files_only=True, cache_dir=tmp_dir) assert has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME, local_files_only=True, cache_dir=tmp_dir)
def test_get_file_from_repo_distant(self): def test_get_file_from_repo_distant(self):
# `get_file_from_repo` returns None if the file does not exist # should return None if the file does not exist
self.assertIsNone(get_file_from_repo("google-bert/bert-base-cased", "ahah.txt")) self.assertIsNone(
cached_file(
"google-bert/bert-base-cased",
"ahah.txt",
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
)
# The function raises if the repository does not exist. # The function raises if the repository does not exist.
with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"): with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
get_file_from_repo("bert-base-case", CONFIG_NAME) cached_file(
"bert-base-case",
CONFIG_NAME,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
# The function raises if the revision does not exist. # The function raises if the revision does not exist.
with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"): with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
get_file_from_repo("google-bert/bert-base-cased", CONFIG_NAME, revision="ahaha") cached_file(
"google-bert/bert-base-cased",
CONFIG_NAME,
revision="ahaha",
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
resolved_file = get_file_from_repo("google-bert/bert-base-cased", CONFIG_NAME) resolved_file = cached_file(
"google-bert/bert-base-cased",
CONFIG_NAME,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
# The name is the cached name which is not very easy to test, so instead we load the content. # The name is the cached name which is not very easy to test, so instead we load the content.
config = json.loads(open(resolved_file, "r").read()) config = json.loads(open(resolved_file, "r").read())
self.assertEqual(config["hidden_size"], 768) self.assertEqual(config["hidden_size"], 768)
@@ -137,9 +157,26 @@ class GetFromCacheTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
filename = Path(tmp_dir) / "a.txt" filename = Path(tmp_dir) / "a.txt"
filename.touch() filename.touch()
self.assertEqual(get_file_from_repo(tmp_dir, "a.txt"), str(filename)) self.assertEqual(
cached_file(
tmp_dir,
"a.txt",
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
),
str(filename),
)
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt")) self.assertIsNone(
cached_file(
tmp_dir,
"b.txt",
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
)
def test_get_file_gated_repo(self): def test_get_file_gated_repo(self):
"""Test download file from a gated repo fails with correct message when not authenticated.""" """Test download file from a gated repo fails with correct message when not authenticated."""

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import copy import copy
import glob import glob
import itertools
import json import json
import os import os
import os.path import os.path
@@ -525,13 +524,12 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(model.vision_tower.dtype, torch.bfloat16) self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16) self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
# TODO @ARTHURZUCKER FIX THIS
# but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
# LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"] LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
# model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto") model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
# self.assertEqual(model.language_model.dtype, torch.float32) self.assertEqual(model.language_model.dtype, torch.float32)
# self.assertEqual(model.vision_tower.dtype, torch.bfloat16) self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
# self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32) self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@@ -540,20 +538,6 @@ class ModelUtilsTest(TestCasePlus):
TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"} TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"}
) )
@require_torch
@unittest.skip("Broken by @arthurzucker because the fix was not correct. Knowing the context is super hard")
def test_model_from_pretrained_meta_device(self):
def is_on_meta(model_id, dtype):
with torch.device("meta"):
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
return all(value.device.type == "meta" for value in model.state_dict().values())
model_ids = ("fxmarty/tiny-llama-fast-tokenizer", "fxmarty/small-llama-testing")
dtypes = (None, "auto", torch.float16)
for model_id, dtype in itertools.product(model_ids, dtypes):
self.assertTrue(is_on_meta(model_id, dtype))
def test_model_from_pretrained_torch_dtype(self): def test_model_from_pretrained_torch_dtype(self):
# test that the model can be instantiated with dtype of either # test that the model can be instantiated with dtype of either
# 1. explicit from_pretrained's torch_dtype argument # 1. explicit from_pretrained's torch_dtype argument