Add TimmWrapper (#34564)

* Add files

* Init

* Add TimmWrapperModel

* Fix up

* Some fixes

* Fix up

* Remove old file

* Sort out import orders

* Fix some model loading

* Compatible with pipeline and trainer

* Fix up

* Delete test_timm_model_1/config.json

* Remove accidentally committed files

* Delete src/transformers/models/modeling_timm_wrapper.py

* Remove empty imports; fix transformations applied

* Tidy up

* Add image classification model to special cases

* Create pretrained model; enable device_map='auto'

* Enable most tests; fix init order

* Sort imports

* [run-slow] timm_wrapper

* Pass num_classes into timm.create_model

* Remove train transforms from image processor

* Update timm creation with pretrained=False

* Fix gamma/beta issue for timm models

* Fixing gamma and beta renaming for timm models

* Simplify config and model creation

* Remove attn_implementation diff

* Fixup

* Docstrings

* Fix warning msg text according to test case

* Fix device_map auto

* Set dtype and device for pixel_values in forward

* Enable output hidden states

* Enable tests for hidden_states and model parallel

* Remove default scriptable arg

* Refactor inner model

* Update timm version

* Fix _find_mismatched_keys function

* Change inheritance for Classification model (fix weights loading with device_map)

* Minor bugfix

* Disable save pretrained for image processor

* Rename hook method for loaded keys correction

* Rename state dict keys on save, remove `timm_model` prefix, make checkpoint compatible with `timm`

* Managing num_labels <-> num_classes attributes

* Enable loading checkpoints in Trainer to resume training

* Update error message for output_hidden_states

* Add output hidden states test

* Decouple base and classification models

* Add more test cases

* Add save-load-to-timm test

* Fix test name

* Fixup

* Add do_pooling

* Add test for do_pooling

* Fix doc

* Add tests for TimmWrapperModel

* Add validation for `num_classes=0` in timm config + test for DINO checkpoint

* Adjust atol for test

* Fix docs

* dev-ci

* dev-ci

* Add tests for image processor

* Update docs

* Update init to new format

* Update docs in configuration

* Fix some docs in image processor

* Improve docs for modeling

* fix for is_timm_checkpoint

* Update code examples

* Fix header

* Fix typehint

* Increase tolerance a bit

* Fix Path

* Fixing model parallel tests

* Disable "parallel" tests

* Add comment for metadata

* Refactor AutoImageProcessor for timm wrapper loading

* Remove custom test_model_outputs_equivalence

* Add require_timm decorator

* Fix comment

* Make image processor work with older timm versions and tensor input

* Save config instead of whole model in image processor tests

* Add docstring for `image_processor_filename`

* Sanitize kwargs for timm image processor

* Fix doc style

* Update check for tensor input

* Update normalize

* Remove _load_timm_model function

---------

Co-authored-by: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Pavel Iakubovskii
2024-12-11 13:40:30 +01:00
committed by GitHub
parent bcc50cc7ce
commit 5fcf6286bf
25 changed files with 1432 additions and 129 deletions

@@ -705,6 +705,8 @@
title: Swin2SR
- local: model_doc/table-transformer
title: Table Transformer
- local: model_doc/timm_wrapper
title: Timm Wrapper
- local: model_doc/upernet
title: UperNet
- local: model_doc/van

@@ -321,6 +321,7 @@ Flax), PyTorch, and/or TensorFlow.
| [TAPEX](model_doc/tapex) | ✅ | ✅ | ✅ |
| [Time Series Transformer](model_doc/time_series_transformer) | ✅ | ❌ | ❌ |
| [TimeSformer](model_doc/timesformer) | ✅ | ❌ | ❌ |
| [TimmWrapperModel](model_doc/timm_wrapper) | ✅ | ❌ | ❌ |
| [Trajectory Transformer](model_doc/trajectory_transformer) | ✅ | ❌ | ❌ |
| [Transformer-XL](model_doc/transfo-xl) | ✅ | ✅ | ❌ |
| [TrOCR](model_doc/trocr) | ✅ | ❌ | ❌ |

@@ -0,0 +1,67 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# TimmWrapper
## Overview
Helper class to enable loading timm models to be used with the transformers library and its autoclasses.
```python
>>> import torch
>>> from PIL import Image
>>> from urllib.request import urlopen
>>> from transformers import AutoModelForImageClassification, AutoImageProcessor
>>> # Load image
>>> image = Image.open(urlopen(
... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
... ))
>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
>>> # Preprocess image
>>> inputs = image_processor(image)
>>> # Forward pass
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> # Get top 5 predictions
>>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
```
## TimmWrapperConfig
[[autodoc]] TimmWrapperConfig
## TimmWrapperImageProcessor
[[autodoc]] TimmWrapperImageProcessor
- preprocess
## TimmWrapperModel
[[autodoc]] TimmWrapperModel
- forward
## TimmWrapperForImageClassification
[[autodoc]] TimmWrapperForImageClassification
- forward

@@ -42,6 +42,7 @@ from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
HfArgumentParser,
TimmWrapperImageProcessor,
Trainer,
TrainingArguments,
set_seed,
@@ -329,31 +330,36 @@ def main():
)
# Define torchvision transforms to be applied to each image.
if "shortest_edge" in image_processor.size:
size = image_processor.size["shortest_edge"]
if isinstance(image_processor, TimmWrapperImageProcessor):
_train_transforms = image_processor.train_transforms
_val_transforms = image_processor.val_transforms
else:
size = (image_processor.size["height"], image_processor.size["width"])
normalize = (
Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
else Lambda(lambda x: x)
)
_train_transforms = Compose(
[
RandomResizedCrop(size),
RandomHorizontalFlip(),
ToTensor(),
normalize,
]
)
_val_transforms = Compose(
[
Resize(size),
CenterCrop(size),
ToTensor(),
normalize,
]
)
if "shortest_edge" in image_processor.size:
size = image_processor.size["shortest_edge"]
else:
size = (image_processor.size["height"], image_processor.size["width"])
# Create normalization transform
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std"):
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
else:
normalize = Lambda(lambda x: x)
_train_transforms = Compose(
[
RandomResizedCrop(size),
RandomHorizontalFlip(),
ToTensor(),
normalize,
]
)
_val_transforms = Compose(
[
Resize(size),
CenterCrop(size),
ToTensor(),
normalize,
]
)
def train_transforms(example_batch):
"""Apply _train_transforms across a batch."""

@@ -782,6 +782,7 @@ _import_structure = {
"models.time_series_transformer": ["TimeSeriesTransformerConfig"],
"models.timesformer": ["TimesformerConfig"],
"models.timm_backbone": ["TimmBackboneConfig"],
"models.timm_wrapper": ["TimmWrapperConfig"],
"models.trocr": [
"TrOCRConfig",
"TrOCRProcessor",
@@ -1272,6 +1273,18 @@ else:
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
_import_structure["models.vit"].append("ViTImageProcessorFast")
try:
if not is_torchvision_available() and not is_timm_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_timm_and_torchvision_objects
_import_structure["utils.dummy_timm_and_torchvision_objects"] = [
name for name in dir(dummy_timm_and_torchvision_objects) if not name.startswith("_")
]
else:
_import_structure["models.timm_wrapper"].extend(["TimmWrapperImageProcessor"])
# PyTorch-backed objects
try:
if not is_torch_available():
@@ -3532,6 +3545,9 @@ else:
]
)
_import_structure["models.timm_backbone"].extend(["TimmBackbone"])
_import_structure["models.timm_wrapper"].extend(
["TimmWrapperForImageClassification", "TimmWrapperModel", "TimmWrapperPreTrainedModel"]
)
_import_structure["models.trocr"].extend(
[
"TrOCRForCausalLM",
@@ -5734,6 +5750,7 @@ if TYPE_CHECKING:
TimesformerConfig,
)
from .models.timm_backbone import TimmBackboneConfig
from .models.timm_wrapper import TimmWrapperConfig
from .models.trocr import (
TrOCRConfig,
TrOCRProcessor,
@@ -6227,6 +6244,14 @@ if TYPE_CHECKING:
from .models.rt_detr import RTDetrImageProcessorFast
from .models.vit import ViTImageProcessorFast
try:
if not is_torchvision_available() and not is_timm_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_timm_and_torchvision_objects import *
else:
from .models.timm_wrapper import TimmWrapperImageProcessor
# Modeling
try:
if not is_torch_available():
@@ -8037,6 +8062,11 @@ if TYPE_CHECKING:
TimesformerPreTrainedModel,
)
from .models.timm_backbone import TimmBackbone
from .models.timm_wrapper import (
TimmWrapperForImageClassification,
TimmWrapperModel,
TimmWrapperPreTrainedModel,
)
from .models.trocr import (
TrOCRForCausalLM,
TrOCRPreTrainedModel,
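The guarded-import pattern in the hunks above can be tried out on its own. A minimal sketch using only the standard library: `OptionalDependencyNotAvailable` here is a local stand-in for the transformers exception of the same name, and `is_available` and `exported_names` are hypothetical helpers. The sketch treats every listed package as required, which is an assumption of the sketch and not a statement about the condition used in the diff.

```python
import importlib.util
from typing import List


class OptionalDependencyNotAvailable(Exception):
    """Local stand-in for the transformers exception of the same name."""


def is_available(package_name: str) -> bool:
    # A top-level package is considered available if an import spec can be found.
    return importlib.util.find_spec(package_name) is not None


def exported_names(required: List[str], real_names: List[str], dummy_names: List[str]) -> List[str]:
    """Return the real object names if all requirements are met, else the dummy ones."""
    try:
        if not all(is_available(package) for package in required):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # Fall back to placeholder objects that raise a helpful error on use.
        return dummy_names
    else:
        return real_names
```

Calling `exported_names(["timm", "torchvision"], ["TimmWrapperImageProcessor"], ["DummyObject"])` would return the real name only when both packages can be found in the current environment.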

@@ -37,6 +37,7 @@ from .utils import (
download_url,
extract_commit_hash,
is_remote_url,
is_timm_config_dict,
is_torch_available,
logging,
)
@@ -702,6 +703,11 @@ class PretrainedConfig(PushToHubMixin):
config_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
config_dict["custom_pipelines"], pretrained_model_name_or_path
)
# timm models are not saved with the model_type in the config file
if "model_type" not in config_dict and is_timm_config_dict(config_dict):
config_dict["model_type"] = "timm_wrapper"
return config_dict, kwargs
@classmethod
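The `model_type` fallback above can be illustrated in isolation. A minimal sketch, assuming that a timm config is recognisable by its `pretrained_cfg` entry. The real `is_timm_config_dict` helper is not shown in this diff, so `looks_like_timm_config` and `infer_model_type` below are hypothetical stand-ins:

```python
from typing import Any, Dict


def looks_like_timm_config(config_dict: Dict[str, Any]) -> bool:
    # Assumption: timm configs carry a "pretrained_cfg" entry.
    return "pretrained_cfg" in config_dict


def infer_model_type(config_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Fill in model_type only when it is missing and the dict looks like a timm config.
    if "model_type" not in config_dict and looks_like_timm_config(config_dict):
        config_dict["model_type"] = "timm_wrapper"
    return config_dict
```

An explicit `model_type` always wins in this sketch, so existing transformers configs are left untouched.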

@@ -285,6 +285,8 @@ class ImageProcessingMixin(PushToHubMixin):
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
image_processor_filename (`str`, *optional*, defaults to `"preprocessor_config.json"`):
The name of the file in the model directory to use for the image processor config.
Returns:
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
@@ -298,6 +300,7 @@ class ImageProcessingMixin(PushToHubMixin):
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", "")
image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
@@ -324,7 +327,7 @@ class ImageProcessingMixin(PushToHubMixin):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
image_processor_file = os.path.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)
image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename)
if os.path.isfile(pretrained_model_name_or_path):
resolved_image_processor_file = pretrained_model_name_or_path
is_local = True
@@ -332,7 +335,7 @@ class ImageProcessingMixin(PushToHubMixin):
image_processor_file = pretrained_model_name_or_path
resolved_image_processor_file = download_url(pretrained_model_name_or_path)
else:
image_processor_file = IMAGE_PROCESSOR_NAME
image_processor_file = image_processor_filename
try:
# Load from local folder or from cache or download from model Hub and cache
resolved_image_processor_file = cached_file(
@@ -358,7 +361,7 @@ class ImageProcessingMixin(PushToHubMixin):
f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a {IMAGE_PROCESSOR_NAME} file"
f" directory containing a {image_processor_filename} file"
)
try:

@@ -503,7 +503,7 @@ def load_state_dict(
# Check format of the archive
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."
@@ -652,36 +652,6 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
renamed_keys = {}
renamed_gamma = {}
renamed_beta = {}
warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
for key in state_dict.keys():
new_key = None
if "gamma" in key:
# We add only the first key as an example
new_key = key.replace("gamma", "weight")
renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
if "beta" in key:
# We add only the first key as an example
new_key = key.replace("beta", "bias")
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
if new_key:
old_keys.append(key)
new_keys.append(new_key)
renamed_keys = {**renamed_gamma, **renamed_beta}
if renamed_keys:
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
for old_key, new_key in renamed_keys.items():
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
logger.info_once(warning_msg)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
@@ -812,46 +782,7 @@ def _load_state_dict_into_meta_model(
error_msgs = []
old_keys = []
new_keys = []
renamed_gamma = {}
renamed_beta = {}
is_quantized = hf_quantizer is not None
warning_msg = f"This model {type(model)}"
for key in state_dict.keys():
new_key = None
if "gamma" in key:
# We add only the first key as an example
new_key = key.replace("gamma", "weight")
renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
if "beta" in key:
# We add only the first key as an example
new_key = key.replace("beta", "bias")
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
# To reproduce `_load_state_dict_into_model` behaviour, we need to manually rename parametrized weigth norm, if necessary.
if hasattr(nn.utils.parametrizations, "weight_norm"):
if "weight_g" in key:
new_key = key.replace("weight_g", "parametrizations.weight.original0")
if "weight_v" in key:
new_key = key.replace("weight_v", "parametrizations.weight.original1")
else:
if "parametrizations.weight.original0" in key:
new_key = key.replace("parametrizations.weight.original0", "weight_g")
if "parametrizations.weight.original1" in key:
new_key = key.replace("parametrizations.weight.original1", "weight_v")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
renamed_keys = {**renamed_gamma, **renamed_beta}
if renamed_keys:
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
for old_key, new_key in renamed_keys.items():
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
logger.info_once(warning_msg)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
@@ -2888,6 +2819,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
for ignore_key in self._keys_to_ignore_on_save:
if ignore_key in state_dict.keys():
del state_dict[ignore_key]
# Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
# (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
state_dict = self._fix_state_dict_keys_on_save(state_dict)
if safe_serialization:
# Safetensors does not allow tensor aliasing.
# We're going to remove aliases before saving
@@ -4010,7 +3946,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
with safe_open(resolved_archive_file, framework="pt") as f:
metadata = f.metadata()
if metadata.get("format") == "pt":
if metadata is None:
# Assume it's a pytorch checkpoint (introduced for timm checkpoints)
pass
elif metadata.get("format") == "pt":
pass
elif metadata.get("format") == "tf":
from_tf = True
@@ -4375,6 +4314,72 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return model
@staticmethod
def _fix_state_dict_key_on_load(key):
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
if "beta" in key:
return key.replace("beta", "bias")
if "gamma" in key:
return key.replace("gamma", "weight")
# to avoid logging parametrized weight norm renaming
if hasattr(nn.utils.parametrizations, "weight_norm"):
if "weight_g" in key:
return key.replace("weight_g", "parametrizations.weight.original0")
if "weight_v" in key:
return key.replace("weight_v", "parametrizations.weight.original1")
else:
if "parametrizations.weight.original0" in key:
return key.replace("parametrizations.weight.original0", "weight_g")
if "parametrizations.weight.original1" in key:
return key.replace("parametrizations.weight.original1", "weight_v")
return key
@classmethod
def _fix_state_dict_keys_on_load(cls, state_dict):
"""Fixes state dict keys by replacing legacy parameter names with their modern equivalents.
Logs if any parameters have been renamed.
"""
renamed_keys = {}
state_dict_keys = list(state_dict.keys())
for key in state_dict_keys:
new_key = cls._fix_state_dict_key_on_load(key)
if new_key != key:
state_dict[new_key] = state_dict.pop(key)
# add it once for logging
if "gamma" in key and "gamma" not in renamed_keys:
renamed_keys["gamma"] = (key, new_key)
if "beta" in key and "beta" not in renamed_keys:
renamed_keys["beta"] = (key, new_key)
if renamed_keys:
warning_msg = f"A pretrained model of type `{cls.__name__}` "
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
for old_key, new_key in renamed_keys.values():
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
logger.info_once(warning_msg)
return state_dict
@staticmethod
def _fix_state_dict_key_on_save(key):
"""
Similar to `_fix_state_dict_key_on_load`, this hook allows renaming a state dict key on model save.
Does nothing by default, but can be overridden in particular models.
"""
return key
def _fix_state_dict_keys_on_save(self, state_dict):
"""
Similar to `_fix_state_dict_keys_on_load`, this hook allows renaming state dict keys on model save.
Applies `_fix_state_dict_key_on_save` to all keys in `state_dict`.
"""
return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()}
@classmethod
def _load_pretrained_model(
cls,
@@ -4430,27 +4435,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if hf_quantizer is not None:
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
def _fix_key(key):
if "beta" in key:
return key.replace("beta", "bias")
if "gamma" in key:
return key.replace("gamma", "weight")
# to avoid logging parametrized weight norm renaming
if hasattr(nn.utils.parametrizations, "weight_norm"):
if "weight_g" in key:
return key.replace("weight_g", "parametrizations.weight.original0")
if "weight_v" in key:
return key.replace("weight_v", "parametrizations.weight.original1")
else:
if "parametrizations.weight.original0" in key:
return key.replace("parametrizations.weight.original0", "weight_g")
if "parametrizations.weight.original1" in key:
return key.replace("parametrizations.weight.original1", "weight_v")
return key
original_loaded_keys = loaded_keys
loaded_keys = [_fix_key(key) for key in loaded_keys]
loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys]
if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
@@ -4615,23 +4601,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
state_dict,
model_state_dict,
loaded_keys,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys):
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
model_key = checkpoint_key
if remove_prefix_from_model:
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
model_key = f"{prefix}.{checkpoint_key}"
model_key = f"{prefix}.{model_key}"
elif add_prefix_to_model:
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = ".".join(checkpoint_key.split(".")[1:])
model_key = ".".join(model_key.split(".")[1:])
if (
model_key in model_state_dict
@@ -4680,6 +4666,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
@@ -4688,9 +4675,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# For GGUF models `state_dict` is never set to None as the state dict is always small
if gguf_path:
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
fixed_state_dict,
start_prefix,
expected_keys,
device_map=device_map,
@@ -4709,8 +4697,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
error_msgs = _load_state_dict_into_model(
model_to_load, state_dict, start_prefix, assign_to_params_buffers
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
)
else:
@@ -4761,6 +4750,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
@@ -4774,9 +4764,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
fixed_state_dict,
start_prefix,
expected_keys,
device_map=device_map,
@@ -4797,8 +4788,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
error_msgs += _load_state_dict_into_model(
model_to_load, state_dict, start_prefix, assign_to_params_buffers
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
)
# force memory release
@@ -4930,9 +4922,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
_move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
expected_keys = loaded_state_dict_keys # plug for missing expected_keys. TODO: replace with proper keys
fixed_state_dict = model._fix_state_dict_keys_on_load(state_dict)
error_msgs = _load_state_dict_into_meta_model(
model,
state_dict,
fixed_state_dict,
start_prefix,
expected_keys=expected_keys,
hf_quantizer=hf_quantizer,
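The load and save hooks introduced in this file can be sketched without torch. This is a simplified illustration, not the library code: `BaseModelSketch` and `PrefixedModelSketch` are hypothetical classes, the weight-norm branches are omitted, and the load helper returns a new dict instead of mutating in place. The `timm_model.` prefix removal follows the commit message; the actual override in the wrapper model is not shown in this part of the diff.

```python
from typing import Any, Dict


class BaseModelSketch:
    """Hypothetical stand-in for the hook layout added to `PreTrainedModel`."""

    @staticmethod
    def _fix_state_dict_key_on_load(key: str) -> str:
        # Map legacy parameter names to modern ones (weight-norm handling omitted).
        if "beta" in key:
            return key.replace("beta", "bias")
        if "gamma" in key:
            return key.replace("gamma", "weight")
        return key

    @staticmethod
    def _fix_state_dict_key_on_save(key: str) -> str:
        # No-op by default; subclasses may override.
        return key

    @classmethod
    def fix_keys_on_load(cls, state_dict: Dict[str, Any]) -> Dict[str, Any]:
        return {cls._fix_state_dict_key_on_load(k): v for k, v in state_dict.items()}

    def fix_keys_on_save(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
        return {self._fix_state_dict_key_on_save(k): v for k, v in state_dict.items()}


class PrefixedModelSketch(BaseModelSketch):
    """Hypothetical subclass that strips a wrapper prefix when saving."""

    @staticmethod
    def _fix_state_dict_key_on_save(key: str) -> str:
        prefix = "timm_model."
        return key[len(prefix):] if key.startswith(prefix) else key
```

Keeping the per-key function separate from the per-dict function means a subclass only has to override the small static method to change the naming scheme.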

@@ -249,6 +249,7 @@ from . import (
time_series_transformer,
timesformer,
timm_backbone,
timm_wrapper,
trocr,
tvp,
udop,

@@ -276,6 +276,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("time_series_transformer", "TimeSeriesTransformerConfig"),
("timesformer", "TimesformerConfig"),
("timm_backbone", "TimmBackboneConfig"),
("timm_wrapper", "TimmWrapperConfig"),
("trajectory_transformer", "TrajectoryTransformerConfig"),
("transfo-xl", "TransfoXLConfig"),
("trocr", "TrOCRConfig"),
@@ -599,6 +600,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("time_series_transformer", "Time Series Transformer"),
("timesformer", "TimeSformer"),
("timm_backbone", "TimmBackbone"),
("timm_wrapper", "TimmWrapperModel"),
("trajectory_transformer", "Trajectory Transformer"),
("transfo-xl", "Transformer-XL"),
("trocr", "TrOCR"),

@@ -30,6 +30,8 @@ from ...utils import (
CONFIG_NAME,
IMAGE_PROCESSOR_NAME,
get_file_from_repo,
is_timm_config_dict,
is_timm_local_checkpoint,
is_torchvision_available,
is_vision_available,
logging,
@@ -137,6 +139,7 @@ else:
("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
("table-transformer", ("DetrImageProcessor",)),
("timesformer", ("VideoMAEImageProcessor",)),
("timm_wrapper", ("TimmWrapperImageProcessor",)),
("tvlt", ("TvltImageProcessor",)),
("tvp", ("TvpImageProcessor",)),
("udop", ("LayoutLMv3ImageProcessor",)),
@@ -376,6 +379,8 @@ class AutoImageProcessor:
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to `True` for repositories you trust and in which you have read the code, as it will
execute code present on the Hub on your local machine.
image_processor_filename (`str`, *optional*, defaults to `"preprocessor_config.json"`):
The name of the file in the model directory to use for the image processor config. For local timm
checkpoints, `"config.json"` is used when this argument is not set.
kwargs (`Dict[str, Any]`, *optional*):
The values in kwargs of any keys which are image processor attributes will be used to override the
loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
@@ -415,7 +420,37 @@ class AutoImageProcessor:
trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
# Resolve the image processor config filename
if "image_processor_filename" in kwargs:
image_processor_filename = kwargs.pop("image_processor_filename")
elif is_timm_local_checkpoint(pretrained_model_name_or_path):
image_processor_filename = CONFIG_NAME
else:
image_processor_filename = IMAGE_PROCESSOR_NAME
# Load the image processor config
try:
# Main path for all transformers models and local TimmWrapper checkpoints
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
)
except Exception as initial_exception:
# Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
# instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
# except the model name, the only way to check if a remote checkpoint is a timm model is to try to
# load `config.json` and if it fails with some error, we raise the initial exception.
try:
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
)
except Exception:
raise initial_exception
# In case we have a config_dict, but it's not a timm config dict, we raise the initial exception,
# because only timm models have image processing in `config.json`.
if not is_timm_config_dict(config_dict):
raise initial_exception
image_processor_class = config_dict.get("image_processor_type", None)
image_processor_auto_map = None
if "AutoImageProcessor" in config_dict.get("auto_map", {}):
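The try-then-fall-back flow described in the comments above can be sketched with an injected loader. This is a simplified illustration under stated assumptions: `load_file` is a hypothetical callable standing in for `ImageProcessingMixin.get_image_processor_dict`, and `looks_like_timm_config` assumes a timm config is recognisable by its `pretrained_cfg` entry.

```python
from typing import Any, Callable, Dict

PRIMARY_NAME = "preprocessor_config.json"  # usual image processor config file
FALLBACK_NAME = "config.json"  # where timm checkpoints keep preprocessing info


def looks_like_timm_config(config_dict: Dict[str, Any]) -> bool:
    # Assumption: timm configs carry a "pretrained_cfg" entry.
    return "pretrained_cfg" in config_dict


def load_processor_config(
    load_file: Callable[[str, str], Dict[str, Any]], checkpoint: str
) -> Dict[str, Any]:
    try:
        # Main path: the regular image processor config file.
        return load_file(checkpoint, PRIMARY_NAME)
    except Exception as initial_exception:
        try:
            config_dict = load_file(checkpoint, FALLBACK_NAME)
        except Exception:
            # Neither file could be loaded: report the original failure.
            raise initial_exception
        if not looks_like_timm_config(config_dict):
            # A config.json exists but is not a timm config: report the original failure.
            raise initial_exception
        return config_dict
```

Re-raising the first exception keeps the error message pointed at the file users normally expect, even when the fallback attempt also fails.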

@@ -255,6 +255,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("time_series_transformer", "TimeSeriesTransformerModel"),
("timesformer", "TimesformerModel"),
("timm_backbone", "TimmBackbone"),
("timm_wrapper", "TimmWrapperModel"),
("trajectory_transformer", "TrajectoryTransformerModel"),
("transfo-xl", "TransfoXLModel"),
("tvlt", "TvltModel"),
@@ -605,6 +606,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
("table-transformer", "TableTransformerModel"),
("timesformer", "TimesformerModel"),
("timm_backbone", "TimmBackbone"),
("timm_wrapper", "TimmWrapperModel"),
("van", "VanModel"),
("videomae", "VideoMAEModel"),
("vit", "ViTModel"),
@@ -690,6 +692,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("swiftformer", "SwiftFormerForImageClassification"),
("swin", "SwinForImageClassification"),
("swinv2", "Swinv2ForImageClassification"),
("timm_wrapper", "TimmWrapperForImageClassification"),
("van", "VanForImageClassification"),
("vit", "ViTForImageClassification"),
("vit_hybrid", "ViTHybridForImageClassification"),

@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_timm_wrapper import *
from .modeling_timm_wrapper import *
from .processing_timm_wrapper import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

@@ -0,0 +1,86 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for TimmWrapper models"""
from typing import Any, Dict

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class TimmWrapperConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`].

    It is used to instantiate a timm model according to the specified arguments, defining the model.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        do_pooling (`bool`, *optional*, defaults to `True`):
            Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not.

    Example:
    ```python
    >>> from transformers import TimmWrapperModel

    >>> # Initializing a timm model
    >>> model = TimmWrapperModel.from_pretrained("timm/resnet18.a1_in1k")

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "timm_wrapper"

    def __init__(self, initializer_range: float = 0.02, do_pooling: bool = True, **kwargs):
        self.initializer_range = initializer_range
        self.do_pooling = do_pooling
        super().__init__(**kwargs)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        # timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
        # We are removing these attributes in order to have the native `transformers` num_labels attribute in config
        # and to avoid duplicate attributes
        num_labels_in_kwargs = kwargs.pop("num_labels", None)
        num_labels_in_dict = config_dict.pop("num_classes", None)

        # passed num_labels has priority over num_classes in config_dict
        kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict

        # pop num_classes from "pretrained_cfg",
        # it is not necessary to have it, only root one is used in timm
        if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]:
            config_dict["pretrained_cfg"].pop("num_classes", None)

        return super().from_dict(config_dict, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        output = super().to_dict()
        output["num_classes"] = self.num_labels
        return output


__all__ = ["TimmWrapperConfig"]
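The `num_classes` / `num_labels` handling in `from_dict` and `to_dict` can be sketched without `transformers` installed. The helper names `normalize_timm_config` and `export_timm_config` are hypothetical and only mirror the logic of the two methods above; they are not part of the library, and unlike `from_dict` the sketch copies its input instead of mutating it.

```python
from typing import Any, Dict, Tuple


def normalize_timm_config(config_dict: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical stand-in for the body of `TimmWrapperConfig.from_dict`."""
    config_dict = dict(config_dict)  # shallow copy so the caller's dict stays intact
    num_labels_in_kwargs = kwargs.pop("num_labels", None)
    num_labels_in_dict = config_dict.pop("num_classes", None)
    # An explicitly passed num_labels wins over the checkpoint's num_classes.
    kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict
    # Drop the duplicate copy stored under "pretrained_cfg".
    if "pretrained_cfg" in config_dict:
        pretrained_cfg = dict(config_dict["pretrained_cfg"])
        pretrained_cfg.pop("num_classes", None)
        config_dict["pretrained_cfg"] = pretrained_cfg
    return config_dict, kwargs


def export_timm_config(config_dict: Dict[str, Any], num_labels: int) -> Dict[str, Any]:
    """Hypothetical stand-in for `to_dict`: write num_labels back as timm's num_classes."""
    output = dict(config_dict)
    output["num_classes"] = num_labels
    return output


# Illustrative config; the values are made up for the example.
timm_config = {
    "architecture": "resnet18",
    "num_classes": 1000,
    "pretrained_cfg": {"num_classes": 1000, "input_size": [3, 224, 224]},
}
cleaned, extra = normalize_timm_config(timm_config, num_labels=10)
saved = export_timm_config(cleaned, extra["num_labels"])
_, default_extra = normalize_timm_config(timm_config)
```

Because the merge uses `or`, an explicit `num_labels=0` is falsy and falls back to the value stored in the checkpoint config.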

@@ -0,0 +1,138 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, Optional, Tuple, Union

import torch

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import to_pil_image
from ...image_utils import ImageInput, make_list_of_images
from ...utils import TensorType, logging, requires_backends
from ...utils.import_utils import is_timm_available, is_torch_available


if is_timm_available():
    import timm

if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


class TimmWrapperImageProcessor(BaseImageProcessor):
    """
    Wrapper class for timm models to be used within transformers.

    Args:
        pretrained_cfg (`Dict[str, Any]`):
            The configuration of the pretrained model used to resolve evaluation and
            training transforms.
        architecture (`Optional[str]`, *optional*):
            Name of the architecture of the model.
    """

    main_input_name = "pixel_values"

    def __init__(
        self,
        pretrained_cfg: Dict[str, Any],
        architecture: Optional[str] = None,
        **kwargs,
    ):
        requires_backends(self, "timm")
        super().__init__(architecture=architecture)

        self.data_config = timm.data.resolve_data_config(pretrained_cfg, model=None, verbose=False)
        self.val_transforms = timm.data.create_transform(**self.data_config, is_training=False)

        # useful for training, see examples/pytorch/image-classification/run_image_classification.py
        self.train_transforms = timm.data.create_transform(**self.data_config, is_training=True)

        # If `ToTensor` is in the transforms, then the input should be numpy array or PIL image.
        # Otherwise, the input can be a tensor. In later timm versions, `MaybeToTensor` is used
        # which can handle both numpy arrays / PIL images and tensors.
        self._not_supports_tensor_input = any(
            transform.__class__.__name__ == "ToTensor" for transform in self.val_transforms.transforms
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.
        """
        output = super().to_dict()
        output.pop("train_transforms", None)
        output.pop("val_transforms", None)
        output.pop("_not_supports_tensor_input", None)
        return output

    @classmethod
    def get_image_processor_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Get the image processor dict for the model.
        """
        image_processor_filename = kwargs.pop("image_processor_filename", "config.json")
        return super().get_image_processor_dict(
            pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
        )

    def preprocess(
        self,
        images: ImageInput,
        return_tensors: Optional[Union[str, TensorType]] = "pt",
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single image or a batch of images.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return.
        """
        if return_tensors != "pt":
            raise ValueError(f"return_tensors for TimmWrapperImageProcessor must be 'pt', but got {return_tensors}")

        if self._not_supports_tensor_input and isinstance(images, torch.Tensor):
            images = images.cpu().numpy()

        # If the input is a torch tensor, then no conversion is needed
        # Otherwise, we need to pass in a list of PIL images
        if isinstance(images, torch.Tensor):
            images = self.val_transforms(images)
            # Add batch dimension if a single image
            images = images.unsqueeze(0) if images.ndim == 3 else images
        else:
            images = make_list_of_images(images)
            images = [to_pil_image(image) for image in images]
            images = torch.stack([self.val_transforms(image) for image in images])

        return BatchFeature({"pixel_values": images}, tensor_type=return_tensors)

    def save_pretrained(self, *args, **kwargs):
        # disable it to make checkpoint the same as in `timm` library.
        logger.warning_once(
            "The `save_pretrained` method is disabled for TimmWrapperImageProcessor. "
            "The image processor configuration is saved directly in `config.json` when "
            "`save_pretrained` is called for saving the model."
        )


__all__ = ["TimmWrapperImageProcessor"]
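The `_not_supports_tensor_input` check in `__init__` matches transforms by class name. A minimal sketch of that idea follows, using stand-in classes so it runs without `timm` or `torchvision`; `ToTensor`, `MaybeToTensor`, and `Compose` here are hypothetical placeholders defined in the example, not the real library classes.

```python
class ToTensor:
    """Stand-in for a transform that only accepts PIL images or numpy arrays."""


class MaybeToTensor:
    """Stand-in for a transform that also lets tensors pass through."""


class Compose:
    """Minimal stand-in for a transform pipeline exposing `.transforms`."""

    def __init__(self, transforms):
        self.transforms = transforms


def rejects_tensor_input(pipeline: Compose) -> bool:
    # Same test as `_not_supports_tensor_input`: look for a transform whose
    # class is named exactly "ToTensor" anywhere in the pipeline.
    return any(t.__class__.__name__ == "ToTensor" for t in pipeline.transforms)


old_style_rejects = rejects_tensor_input(Compose([ToTensor()]))
new_style_rejects = rejects_tensor_input(Compose([MaybeToTensor()]))
```

When the check is true, `preprocess` converts an incoming tensor to a numpy array first and then takes the PIL path.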

@@ -0,0 +1,363 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...modeling_outputs import ImageClassifierOutput, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_start_docstrings_to_model_forward,
    is_timm_available,
    replace_return_docstrings,
    requires_backends,
)
from .configuration_timm_wrapper import TimmWrapperConfig


if is_timm_available():
    import timm


@dataclass
class TimmWrapperModelOutput(ModelOutput):
    """
    Output class for `TimmWrapperModel`, containing the last hidden states, an optional pooled output,
    and optional hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor`):
            The last hidden state of the model, output before applying the classification head.
        pooler_output (`torch.FloatTensor`, *optional*):
            The pooled output derived from the last hidden state, if applicable.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            A tuple containing the intermediate hidden states of the model at the output of each layer or specified layers.
            Returned if `output_hidden_states=True` is set or if `config.output_hidden_states=True`.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            A tuple containing the intermediate attention weights of the model at the output of each layer.
            Returned if `output_attentions=True` is set or if `config.output_attentions=True`.
            Note: Currently, Timm models do not support attentions output.
    """

    last_hidden_state: torch.FloatTensor
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


TIMM_WRAPPER_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`TimmWrapperImageProcessor.preprocess`]
            for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. Only supported when the wrapped timm model
            implements `forward_intermediates`; a list of layer indices can be passed to select specific hidden states.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        **kwargs:
            Additional keyword arguments passed along to the `timm` model forward.
"""
class TimmWrapperPreTrainedModel(PreTrainedModel):
    main_input_name = "pixel_values"
    config_class = TimmWrapperConfig
    _no_split_modules = []

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision", "timm"])
        super().__init__(*args, **kwargs)

    @staticmethod
    def _fix_state_dict_key_on_load(key):
        """
        Overrides original method that renames `gamma` and `beta` to `weight` and `bias`.
        We don't want this behavior for timm wrapped models. Instead, this method adds a
        "timm_model." prefix to enable loading official timm Hub checkpoints.
        """
        if "timm_model." not in key:
            return f"timm_model.{key}"
        return key

    def _fix_state_dict_key_on_save(self, key):
        """
        Overrides original method to remove "timm_model." prefix from state_dict keys.
        Makes the saved checkpoint compatible with the `timm` library.
        """
        return key.replace("timm_model.", "")

    def load_state_dict(self, state_dict, *args, **kwargs):
        """
        Override original method to fix state_dict keys on load for cases when weights are loaded
        without using the `from_pretrained` method (e.g., in Trainer to resume from checkpoint).
        """
        state_dict = self._fix_state_dict_keys_on_load(state_dict)
        return super().load_state_dict(state_dict, *args, **kwargs)

    def _init_weights(self, module):
        """
        Initialize weights function to properly initialize Linear layer weights.
        Since model architectures may vary, we assume only the classifier requires
        initialization, while all other weights should be loaded from the checkpoint.
        """
        if isinstance(module, (nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()

class TimmWrapperModel(TimmWrapperPreTrainedModel):
    """
    Wrapper class for timm models to be used in transformers.
    """

    def __init__(self, config: TimmWrapperConfig):
        super().__init__(config)
        # using num_classes=0 to avoid creating classification head
        self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0)
        self.post_init()

    @add_start_docstrings_to_model_forward(TIMM_WRAPPER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TimmWrapperModelOutput, config_class=TimmWrapperConfig)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[Union[bool, List[int]]] = None,
        return_dict: Optional[bool] = None,
        do_pooling: Optional[bool] = None,
        **kwargs,
    ) -> Union[TimmWrapperModelOutput, Tuple[Tensor, ...]]:
        r"""
        do_pooling (`bool`, *optional*):
            Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. If `None` is passed, the
            `do_pooling` value from the config is used.

        Returns:

        Examples:
        ```python
        >>> import torch
        >>> from PIL import Image
        >>> from urllib.request import urlopen
        >>> from transformers import AutoModel, AutoImageProcessor

        >>> # Load image
        >>> image = Image.open(urlopen(
        ...     'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
        ... ))

        >>> # Load model and image processor
        >>> checkpoint = "timm/resnet50.a1_in1k"
        >>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
        >>> model = AutoModel.from_pretrained(checkpoint).eval()

        >>> # Preprocess image
        >>> inputs = image_processor(image)

        >>> # Forward pass
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> # Get pooled output
        >>> pooled_output = outputs.pooler_output

        >>> # Get last hidden state
        >>> last_hidden_state = outputs.last_hidden_state
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        do_pooling = do_pooling if do_pooling is not None else self.config.do_pooling

        if output_attentions:
            raise ValueError("Cannot set `output_attentions` for timm models.")

        if output_hidden_states and not hasattr(self.timm_model, "forward_intermediates"):
            raise ValueError(
                "The 'output_hidden_states' option cannot be set for this timm model. "
                "To enable this feature, the 'forward_intermediates' method must be implemented "
                "in the timm model (available in timm versions > 1.*). Please consider using a "
                "different architecture or updating the timm package to a compatible version."
            )

        pixel_values = pixel_values.to(self.device, self.dtype)

        if output_hidden_states:
            # to enable hidden states selection
            if isinstance(output_hidden_states, (list, tuple)):
                kwargs["indices"] = output_hidden_states
            last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs)
        else:
            last_hidden_state = self.timm_model.forward_features(pixel_values, **kwargs)
            hidden_states = None

        if do_pooling:
            # classification head is not created, applying pooling only
            pooler_output = self.timm_model.forward_head(last_hidden_state)
        else:
            pooler_output = None

        if not return_dict:
            outputs = (last_hidden_state, pooler_output, hidden_states)
            outputs = tuple(output for output in outputs if output is not None)
            return outputs

        return TimmWrapperModelOutput(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
            hidden_states=hidden_states,
        )

class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel):
    """
    Wrapper class for timm models to be used in transformers for image classification.
    """

    def __init__(self, config: TimmWrapperConfig):
        super().__init__(config)

        if config.num_labels == 0:
            raise ValueError(
                "You are trying to load weights into `TimmWrapperForImageClassification` from a checkpoint with no classifier head. "
                "Please specify the number of classes, e.g. `model = TimmWrapperForImageClassification.from_pretrained(..., num_labels=10)`, "
                "or use `TimmWrapperModel` for feature extraction."
            )

        self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=config.num_labels)
        self.num_labels = config.num_labels
        self.post_init()

    @add_start_docstrings_to_model_forward(TIMM_WRAPPER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=TimmWrapperConfig)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[Union[bool, List[int]]] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[ImageClassifierOutput, Tuple[Tensor, ...]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:
        ```python
        >>> import torch
        >>> from PIL import Image
        >>> from urllib.request import urlopen
        >>> from transformers import AutoModelForImageClassification, AutoImageProcessor

        >>> # Load image
        >>> image = Image.open(urlopen(
        ...     'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
        ... ))

        >>> # Load model and image processor
        >>> checkpoint = "timm/resnet50.a1_in1k"
        >>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
        >>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()

        >>> # Preprocess image
        >>> inputs = image_processor(image)

        >>> # Forward pass
        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits

        >>> # Get top 5 predictions
        >>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        if output_attentions:
            raise ValueError("Cannot set `output_attentions` for timm models.")

        if output_hidden_states and not hasattr(self.timm_model, "forward_intermediates"):
            raise ValueError(
                "The 'output_hidden_states' option cannot be set for this timm model. "
                "To enable this feature, the 'forward_intermediates' method must be implemented "
                "in the timm model (available in timm versions > 1.*). Please consider using a "
                "different architecture or updating the timm package to a compatible version."
            )

        pixel_values = pixel_values.to(self.device, self.dtype)

        if output_hidden_states:
            # to enable hidden states selection
            if isinstance(output_hidden_states, (list, tuple)):
                kwargs["indices"] = output_hidden_states
            last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs)
            logits = self.timm_model.forward_head(last_hidden_state)
        else:
            logits = self.timm_model(pixel_values, **kwargs)
            hidden_states = None

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            outputs = (loss, logits, hidden_states)
            outputs = tuple(output for output in outputs if output is not None)
            return outputs

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
        )


__all__ = ["TimmWrapperPreTrainedModel", "TimmWrapperModel", "TimmWrapperForImageClassification"]
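The key renaming done by `_fix_state_dict_key_on_load` and `_fix_state_dict_key_on_save` amounts to a prefix round trip. The sketch below reproduces it on plain strings; the function names and the sample keys are illustrative assumptions, not library API or real checkpoint contents.

```python
PREFIX = "timm_model."  # attribute under which the wrapper stores the timm module


def fix_key_on_load(key: str) -> str:
    """Add the wrapper prefix unless the key already contains it."""
    return key if PREFIX in key else f"{PREFIX}{key}"


def fix_key_on_save(key: str) -> str:
    """Strip the wrapper prefix so the saved keys match timm's own names."""
    return key.replace(PREFIX, "")


# Illustrative key names in timm's native layout.
timm_keys = ["conv1.weight", "bn1.gamma", "fc.bias"]
wrapped = [fix_key_on_load(k) for k in timm_keys]
restored = [fix_key_on_save(k) for k in wrapped]
```

Note that a key such as `bn1.gamma` keeps its name: only the prefix changes, which is the point of overriding the default `gamma`/`beta` renaming.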

@@ -55,6 +55,8 @@ from .generic import (
    is_tensor,
    is_tf_symbolic_tensor,
    is_tf_tensor,
    is_timm_config_dict,
    is_timm_local_checkpoint,
    is_torch_device,
    is_torch_dtype,
    is_torch_tensor,

@@ -9043,6 +9043,27 @@ class TimmBackbone(metaclass=DummyObject):
requires_backends(self, ["torch"])
class TimmWrapperForImageClassification(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class TimmWrapperModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class TimmWrapperPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class TrOCRForCausalLM(metaclass=DummyObject):
    _backends = ["torch"]

@@ -0,0 +1,9 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


class TimmWrapperImageProcessor(metaclass=DummyObject):
    _backends = ["timm", "torchvision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "torchvision"])

@@ -16,6 +16,8 @@ Generic utilities
"""
import inspect
import json
import os
import tempfile
import warnings
from collections import OrderedDict, UserDict
@@ -24,7 +26,7 @@ from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass
from enum import Enum
from functools import partial, wraps
from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict
from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, TypedDict
import numpy as np
from packaging import version
@@ -867,3 +869,36 @@ class LossKwargs(TypedDict, total=False):
"""
num_items_in_batch: Optional[int]


def is_timm_config_dict(config_dict: Dict[str, Any]) -> bool:
    """Checks whether a config dict is a timm config dict."""
    return "pretrained_cfg" in config_dict


def is_timm_local_checkpoint(pretrained_model_path: str) -> bool:
    """
    Checks whether a checkpoint is a timm model checkpoint.
    """
    if pretrained_model_path is None:
        return False

    # in case it's Path, not str
    pretrained_model_path = str(pretrained_model_path)

    is_file = os.path.isfile(pretrained_model_path)
    is_dir = os.path.isdir(pretrained_model_path)

    # pretrained_model_path is a file
    if is_file and pretrained_model_path.endswith(".json"):
        with open(pretrained_model_path, "r") as f:
            config_dict = json.load(f)
        return is_timm_config_dict(config_dict)

    # pretrained_model_path is a directory with a config.json
    if is_dir and os.path.exists(os.path.join(pretrained_model_path, "config.json")):
        with open(os.path.join(pretrained_model_path, "config.json"), "r") as f:
            config_dict = json.load(f)
        return is_timm_config_dict(config_dict)

    return False
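A rough way to see how `is_timm_local_checkpoint` behaves is to re-create its logic with the standard library and run it against temporary directories. The names `looks_like_timm_config`, `looks_like_timm_checkpoint`, and `write_config` are hypothetical, and the config contents are made up for the example.

```python
import json
import os
import tempfile
from typing import Any, Dict, Optional


def looks_like_timm_config(config_dict: Dict[str, Any]) -> bool:
    # The presence of "pretrained_cfg" is the only signal the check relies on.
    return "pretrained_cfg" in config_dict


def looks_like_timm_checkpoint(path: Optional[str]) -> bool:
    if path is None:
        return False
    path = str(path)  # also accept pathlib.Path
    if os.path.isfile(path) and path.endswith(".json"):
        config_path = path
    elif os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")):
        config_path = os.path.join(path, "config.json")
    else:
        return False
    with open(config_path, "r") as f:
        return looks_like_timm_config(json.load(f))


def write_config(directory: str, config: Dict[str, Any]) -> str:
    """Write `config` to <directory>/config.json and return the file path."""
    os.makedirs(directory, exist_ok=True)
    config_path = os.path.join(directory, "config.json")
    with open(config_path, "w") as f:
        json.dump(config, f)
    return config_path


with tempfile.TemporaryDirectory() as tmp:
    timm_dir = os.path.join(tmp, "timm_ckpt")
    other_dir = os.path.join(tmp, "other_ckpt")
    timm_file = write_config(timm_dir, {"architecture": "resnet18", "pretrained_cfg": {}})
    write_config(other_dir, {"model_type": "vit"})
    results = {
        "timm_dir": looks_like_timm_checkpoint(timm_dir),
        "timm_file": looks_like_timm_checkpoint(timm_file),
        "other_dir": looks_like_timm_checkpoint(other_dir),
        "missing": looks_like_timm_checkpoint(os.path.join(tmp, "nope")),
        "none": looks_like_timm_checkpoint(None),
    }
```

Only the directory and the JSON file that contain a `pretrained_cfg` entry are classified as timm checkpoints; a missing path or `None` yields `False` rather than an error.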

@@ -0,0 +1,103 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest

import numpy as np

from transformers.testing_utils import require_torch, require_torchvision, require_vision
from transformers.utils import is_torch_available, is_vision_available


if is_torch_available():
    import torch

if is_vision_available():
    from PIL import Image

    from transformers import TimmWrapperConfig, TimmWrapperImageProcessor


@require_torch
@require_vision
@require_torchvision
class TimmWrapperImageProcessingTest(unittest.TestCase):
    image_processing_class = TimmWrapperImageProcessor if is_vision_available() else None

    def setUp(self):
        super().setUp()
        self.temp_dir = tempfile.TemporaryDirectory()
        config = TimmWrapperConfig.from_pretrained("timm/resnet18.a1_in1k")
        config.save_pretrained(self.temp_dir.name)

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_load_from_hub(self):
        image_processor = TimmWrapperImageProcessor.from_pretrained("timm/resnet18.a1_in1k")
        self.assertIsInstance(image_processor, TimmWrapperImageProcessor)

    def test_load_from_local_dir(self):
        image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
        self.assertIsInstance(image_processor, TimmWrapperImageProcessor)

    def test_image_processor_properties(self):
        image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
        self.assertTrue(hasattr(image_processor, "data_config"))
        self.assertTrue(hasattr(image_processor, "val_transforms"))
        self.assertTrue(hasattr(image_processor, "train_transforms"))

    def test_image_processor_call_numpy(self):
        image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)

        single_image = np.random.randint(256, size=(256, 256, 3), dtype=np.uint8)
        batch_images = [single_image, single_image, single_image]

        # single image
        pixel_values = image_processor(single_image).pixel_values
        self.assertEqual(pixel_values.shape, (1, 3, 224, 224))

        # batch images
        pixel_values = image_processor(batch_images).pixel_values
        self.assertEqual(pixel_values.shape, (3, 3, 224, 224))

    def test_image_processor_call_pil(self):
        image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)

        single_image = Image.fromarray(np.random.randint(256, size=(256, 256, 3), dtype=np.uint8))
        batch_images = [single_image, single_image, single_image]

        # single image
        pixel_values = image_processor(single_image).pixel_values
        self.assertEqual(pixel_values.shape, (1, 3, 224, 224))

        # batch images
        pixel_values = image_processor(batch_images).pixel_values
        self.assertEqual(pixel_values.shape, (3, 3, 224, 224))

    def test_image_processor_call_tensor(self):
        image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)

        single_image = torch.from_numpy(np.random.randint(256, size=(3, 256, 256), dtype=np.uint8)).float()
        batch_images = [single_image, single_image, single_image]

        # single image
        pixel_values = image_processor(single_image).pixel_values
        self.assertEqual(pixel_values.shape, (1, 3, 224, 224))

        # batch images
        pixel_values = image_processor(batch_images).pixel_values
        self.assertEqual(pixel_values.shape, (3, 3, 224, 224))

@@ -0,0 +1,366 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import tempfile
import unittest

from transformers.testing_utils import (
    require_bitsandbytes,
    require_timm,
    require_torch,
    require_vision,
    slow,
    torch_device,
)
from transformers.utils.import_utils import is_timm_available, is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import TimmWrapperConfig, TimmWrapperForImageClassification, TimmWrapperModel


if is_timm_available():
    import timm


if is_vision_available():
    from PIL import Image

    from transformers import TimmWrapperImageProcessor


class TimmWrapperModelTester:
    def __init__(
        self,
        parent,
        model_name="timm/resnet18.a1_in1k",
        batch_size=3,
        image_size=32,
        num_channels=3,
        is_training=True,
    ):
        self.parent = parent
        self.model_name = model_name
        self.batch_size = batch_size
        self.image_size = image_size
        self.num_channels = num_channels
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        return TimmWrapperConfig.from_pretrained(self.model_name)

    def create_and_check_model(self, config, pixel_values):
        model = TimmWrapperModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        self.parent.assertEqual(
            result.feature_map[-1].shape,
            (self.batch_size, model.channels[-1], 14, 14),
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict

@require_torch
@require_timm
class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (TimmWrapperModel, TimmWrapperForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"image-feature-extraction": TimmWrapperModel, "image-classification": TimmWrapperForImageClassification}
if is_torch_available()
else {}
)
test_resize_embeddings = False
test_head_masking = False
test_pruning = False
has_attentions = False
test_model_parallel = False
def setUp(self):
self.config_class = TimmWrapperConfig
self.model_tester = TimmWrapperModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=self.config_class,
has_text_modality=False,
common_properties=[],
model_name="timm/resnet18.a1_in1k",
)
def test_config(self):
self.config_tester.run_common_tests()
def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
# check all hidden states
with torch.no_grad():
outputs = model(**inputs_dict, output_hidden_states=True)
self.assertTrue(
len(outputs.hidden_states) == 5, f"expected 5 hidden states, but got {len(outputs.hidden_states)}"
)
expected_shapes = [[16, 16], [8, 8], [4, 4], [2, 2], [1, 1]]
resulted_shapes = [list(h.shape[2:]) for h in outputs.hidden_states]
self.assertListEqual(expected_shapes, resulted_shapes)
# check we can select hidden states by indices
with torch.no_grad():
outputs = model(**inputs_dict, output_hidden_states=[-2, -1])
self.assertTrue(
len(outputs.hidden_states) == 2, f"expected 2 hidden states, but got {len(outputs.hidden_states)}"
)
expected_shapes = [[2, 2], [1, 1]]
resulted_shapes = [list(h.shape[2:]) for h in outputs.hidden_states]
self.assertListEqual(expected_shapes, resulted_shapes)
@unittest.skip(reason="TimmWrapper models doesn't have inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="TimmWrapper models doesn't have inputs_embeds")
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="TimmWrapper doesn't support output_attentions=True.")
def test_torchscript_output_attentions(self):
pass
@unittest.skip(reason="TimmWrapper doesn't support this.")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="TimmWrapper initialization is managed on the timm side")
def test_initialization(self):
pass
@unittest.skip(reason="Need to use a timm model and there is no tiny model available.")
def test_model_is_small(self):
pass

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_do_pooling_option(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.do_pooling = False

        model = TimmWrapperModel._from_config(config)

        # check there is no pooling
        with torch.no_grad():
            output = model(**inputs_dict)
        self.assertIsNone(output.pooler_output)

        # check there is pooler output
        with torch.no_grad():
            output = model(**inputs_dict, do_pooling=True)
        self.assertIsNotNone(output.pooler_output)


# We will verify our results on an image of cute cats
def prepare_img():
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image


@require_torch
@require_timm
@require_vision
class TimmWrapperModelIntegrationTest(unittest.TestCase):
    # some popular ones
    model_names_to_test = [
        "vit_small_patch16_384.augreg_in21k_ft_in1k",
        "resnet50.a1_in1k",
        "tf_mobilenetv3_large_minimal_100.in1k",
        "swin_tiny_patch4_window7_224.ms_in1k",
        "ese_vovnet19b_dw.ra_in1k",
        "hrnet_w18.ms_aug_in1k",
    ]

    @slow
    def test_inference_image_classification_head(self):
        checkpoint = "timm/resnet18.a1_in1k"
        model = TimmWrapperForImageClassification.from_pretrained(checkpoint, device_map=torch_device).eval()
        image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)

        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the shape and logits
        expected_shape = torch.Size((1, 1000))
        self.assertEqual(outputs.logits.shape, expected_shape)

        expected_label = 281  # tabby cat
        self.assertEqual(torch.argmax(outputs.logits).item(), expected_label)

        expected_slice = torch.tensor([-11.2618, -9.6192, -10.3205]).to(torch_device)
        resulted_slice = outputs.logits[0, :3]
        is_close = torch.allclose(resulted_slice, expected_slice, atol=1e-3)
        self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}")

    @slow
    @require_bitsandbytes
    def test_inference_image_classification_quantized(self):
        from transformers import BitsAndBytesConfig

        checkpoint = "timm/vit_small_patch16_384.augreg_in21k_ft_in1k"

        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        model = TimmWrapperForImageClassification.from_pretrained(
            checkpoint, quantization_config=quantization_config, device_map=torch_device
        ).eval()
        image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)

        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the shape and logits
        expected_shape = torch.Size((1, 1000))
        self.assertEqual(outputs.logits.shape, expected_shape)

        expected_label = 281  # tabby cat
        self.assertEqual(torch.argmax(outputs.logits).item(), expected_label)

        expected_slice = torch.tensor([-2.4043, 1.4492, -0.5127]).to(outputs.logits.dtype)
        resulted_slice = outputs.logits[0, :3].cpu()
        is_close = torch.allclose(resulted_slice, expected_slice, atol=0.1)
        self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}")

    @slow
    def test_transformers_model_for_classification_is_equivalent_to_timm(self):
        # check that wrapper logits are the same as timm model logits

        image = prepare_img()

        for model_name in self.model_names_to_test:
            checkpoint = f"timm/{model_name}"

            with self.subTest(msg=model_name):
                # prepare inputs
                image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)
                pixel_values = image_processor(images=image).pixel_values.to(torch_device)

                # load models
                model = TimmWrapperForImageClassification.from_pretrained(checkpoint, device_map=torch_device).eval()
                timm_model = timm.create_model(model_name, pretrained=True).to(torch_device).eval()

                with torch.inference_mode():
                    outputs = model(pixel_values)
                    timm_outputs = timm_model(pixel_values)

                # check shape is the same
                self.assertEqual(outputs.logits.shape, timm_outputs.shape)

                # check logits are the same; compare the largest absolute difference,
                # since a signed max would let large negative deviations pass
                diff = (outputs.logits - timm_outputs).abs().max().item()
                self.assertLess(diff, 1e-4)

    @slow
    def test_transformers_model_is_equivalent_to_timm(self):
        # check that wrapper pooled outputs are the same as timm model outputs (created with num_classes=0)

        image = prepare_img()

        models_to_test = ["vit_small_patch16_224.dino"] + self.model_names_to_test

        for model_name in models_to_test:
            checkpoint = f"timm/{model_name}"

            with self.subTest(msg=model_name):
                # prepare inputs
                image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)
                pixel_values = image_processor(images=image).pixel_values.to(torch_device)

                # load models
                model = TimmWrapperModel.from_pretrained(checkpoint, device_map=torch_device).eval()
                timm_model = timm.create_model(model_name, pretrained=True, num_classes=0).to(torch_device).eval()

                with torch.inference_mode():
                    outputs = model(pixel_values)
                    timm_outputs = timm_model(pixel_values)

                # check shape is the same
                self.assertEqual(outputs.pooler_output.shape, timm_outputs.shape)

                # check pooled outputs are the same; compare the largest absolute difference,
                # since a signed max would let large negative deviations pass
                diff = (outputs.pooler_output - timm_outputs).abs().max().item()
                self.assertLess(diff, 1e-4)

    @slow
    def test_save_load_to_timm(self):
        # test that timm model can be loaded to transformers, saved and then loaded back into timm

        model = TimmWrapperForImageClassification.from_pretrained(
            "timm/resnet18.a1_in1k", num_labels=10, ignore_mismatched_sizes=True
        )

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

            # there is no direct way to load timm model from folder, use the same config + path to weights
            timm_model = timm.create_model(
                "resnet18", num_classes=10, checkpoint_path=f"{tmpdirname}/model.safetensors"
            )

        # check that all weights are the same after reload
        different_weights = []
        for (name1, param1), (name2, param2) in zip(
            model.timm_model.named_parameters(), timm_model.named_parameters()
        ):
            if param1.shape != param2.shape or not torch.equal(param1, param2):
                different_weights.append((name1, name2))

        if different_weights:
            self.fail(f"Found different weights after reloading: {different_weights}")


@@ -3443,6 +3443,7 @@ class ModelTesterMixin:
"Data2VecAudioForSequenceClassification",
"UniSpeechForSequenceClassification",
"PvtForImageClassification",
"TimmWrapperForImageClassification",
]
special_param_names = [
r"^bit\.",
@@ -3463,6 +3464,7 @@ class ModelTesterMixin:
r"^swiftformer\.",
r"^swinv2\.",
r"^transformers\.models\.swiftformer\.",
r"^timm_model\.",
r"^unispeech\.",
r"^unispeech_sat\.",
r"^vision_model\.",


@@ -41,6 +41,7 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
"RagConfig",
"SpeechEncoderDecoderConfig",
"TimmBackboneConfig",
"TimmWrapperConfig",
"VisionEncoderDecoderConfig",
"VisionTextDualEncoderConfig",
"LlamaConfig",