Add TimmWrapper (#34564)
* Add files * Init * Add TimmWrapperModel * Fix up * Some fixes * Fix up * Remove old file * Sort out import orders * Fix some model loading * Compatible with pipeline and trainer * Fix up * Delete test_timm_model_1/config.json * Remove accidentally commited files * Delete src/transformers/models/modeling_timm_wrapper.py * Remove empty imports; fix transformations applied * Tidy up * Add image classifcation model to special cases * Create pretrained model; enable device_map='auto' * Enable most tests; fix init order * Sort imports * [run-slow] timm_wrapper * Pass num_classes into timm.create_model * Remove train transforms from image processor * Update timm creation with pretrained=False * Fix gamma/beta issue for timm models * Fixing gamma and beta renaming for timm models * Simplify config and model creation * Remove attn_implementation diff * Fixup * Docstrings * Fix warning msg text according to test case * Fix device_map auto * Set dtype and device for pixel_values in forward * Enable output hidden states * Enable tests for hidden_states and model parallel * Remove default scriptable arg * Refactor inner model * Update timm version * Fix _find_mismatched_keys function * Change inheritance for Classification model (fix weights loading with device_map) * Minor bugfix * Disable save pretrained for image processor * Rename hook method for loaded keys correction * Rename state dict keys on save, remove `timm_model` prefix, make checkpoint compatible with `timm` * Managing num_labels <-> num_classes attributes * Enable loading checkpoints in Trainer to resume training * Update error message for output_hidden_states * Add output hidden states test * Decouple base and classification models * Add more test cases * Add save-load-to-timm test * Fix test name * Fixup * Add do_pooling * Add test for do_pooling * Fix doc * Add tests for TimmWrapperModel * Add validation for `num_classes=0` in timm config + test for DINO checkpoint * Adjust atol for test * Fix docs * dev-ci * dev-ci * Add tests for image processor * Update docs * Update init to new format * Update docs in configuration * Fix some docs in image processor * Improve docs for modeling * fix for is_timm_checkpoint * Update code examples * Fix header * Fix typehint * Increase tolerance a bit * Fix Path * Fixing model parallel tests * Disable "parallel" tests * Add comment for metadata * Refactor AutoImageProcessor for timm wrapper loading * Remove custom test_model_outputs_equivalence * Add require_timm decorator * Fix comment * Make image processor work with older timm versions and tensor input * Save config instead of whole model in image processor tests * Add docstring for `image_processor_filename` * Sanitize kwargs for timm image processor * Fix doc style * Update check for tensor input * Update normalize * Remove _load_timm_model function --------- Co-authored-by: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
bcc50cc7ce
commit
5fcf6286bf
@@ -705,6 +705,8 @@
|
||||
title: Swin2SR
|
||||
- local: model_doc/table-transformer
|
||||
title: Table Transformer
|
||||
- local: model_doc/timm_wrapper
|
||||
title: Timm Wrapper
|
||||
- local: model_doc/upernet
|
||||
title: UperNet
|
||||
- local: model_doc/van
|
||||
|
||||
@@ -321,6 +321,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [TAPEX](model_doc/tapex) | ✅ | ✅ | ✅ |
|
||||
| [Time Series Transformer](model_doc/time_series_transformer) | ✅ | ❌ | ❌ |
|
||||
| [TimeSformer](model_doc/timesformer) | ✅ | ❌ | ❌ |
|
||||
| [TimmWrapperModel](model_doc/timm_wrapper) | ✅ | ❌ | ❌ |
|
||||
| [Trajectory Transformer](model_doc/trajectory_transformer) | ✅ | ❌ | ❌ |
|
||||
| [Transformer-XL](model_doc/transfo-xl) | ✅ | ✅ | ❌ |
|
||||
| [TrOCR](model_doc/trocr) | ✅ | ❌ | ❌ |
|
||||
|
||||
67
docs/source/en/model_doc/timm_wrapper.md
Normal file
67
docs/source/en/model_doc/timm_wrapper.md
Normal file
@@ -0,0 +1,67 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# TimmWrapper
|
||||
|
||||
## Overview
|
||||
|
||||
Helper class to enable loading timm models to be used with the transformers library and its autoclasses.
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> from urllib.request import urlopen
|
||||
>>> from transformers import AutoModelForImageClassification, AutoImageProcessor
|
||||
|
||||
>>> # Load image
|
||||
>>> image = Image.open(urlopen(
|
||||
... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
|
||||
... ))
|
||||
|
||||
>>> # Load model and image processor
|
||||
>>> checkpoint = "timm/resnet50.a1_in1k"
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
|
||||
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
|
||||
|
||||
>>> # Preprocess image
|
||||
>>> inputs = image_processor(image)
|
||||
|
||||
>>> # Forward pass
|
||||
>>> with torch.no_grad():
|
||||
... logits = model(**inputs).logits
|
||||
|
||||
>>> # Get top 5 predictions
|
||||
>>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
|
||||
```
|
||||
|
||||
## TimmWrapperConfig
|
||||
|
||||
[[autodoc]] TimmWrapperConfig
|
||||
|
||||
## TimmWrapperImageProcessor
|
||||
|
||||
[[autodoc]] TimmWrapperImageProcessor
|
||||
- preprocess
|
||||
|
||||
## TimmWrapperModel
|
||||
|
||||
[[autodoc]] TimmWrapperModel
|
||||
- forward
|
||||
|
||||
## TimmWrapperForImageClassification
|
||||
|
||||
[[autodoc]] TimmWrapperForImageClassification
|
||||
- forward
|
||||
@@ -42,6 +42,7 @@ from transformers import (
|
||||
AutoImageProcessor,
|
||||
AutoModelForImageClassification,
|
||||
HfArgumentParser,
|
||||
TimmWrapperImageProcessor,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
@@ -329,31 +330,36 @@ def main():
|
||||
)
|
||||
|
||||
# Define torchvision transforms to be applied to each image.
|
||||
if "shortest_edge" in image_processor.size:
|
||||
size = image_processor.size["shortest_edge"]
|
||||
if isinstance(image_processor, TimmWrapperImageProcessor):
|
||||
_train_transforms = image_processor.train_transforms
|
||||
_val_transforms = image_processor.val_transforms
|
||||
else:
|
||||
size = (image_processor.size["height"], image_processor.size["width"])
|
||||
normalize = (
|
||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
|
||||
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
|
||||
else Lambda(lambda x: x)
|
||||
)
|
||||
_train_transforms = Compose(
|
||||
[
|
||||
RandomResizedCrop(size),
|
||||
RandomHorizontalFlip(),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
)
|
||||
_val_transforms = Compose(
|
||||
[
|
||||
Resize(size),
|
||||
CenterCrop(size),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
)
|
||||
if "shortest_edge" in image_processor.size:
|
||||
size = image_processor.size["shortest_edge"]
|
||||
else:
|
||||
size = (image_processor.size["height"], image_processor.size["width"])
|
||||
|
||||
# Create normalization transform
|
||||
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std"):
|
||||
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
|
||||
else:
|
||||
normalize = Lambda(lambda x: x)
|
||||
_train_transforms = Compose(
|
||||
[
|
||||
RandomResizedCrop(size),
|
||||
RandomHorizontalFlip(),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
)
|
||||
_val_transforms = Compose(
|
||||
[
|
||||
Resize(size),
|
||||
CenterCrop(size),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
)
|
||||
|
||||
def train_transforms(example_batch):
|
||||
"""Apply _train_transforms across a batch."""
|
||||
|
||||
@@ -782,6 +782,7 @@ _import_structure = {
|
||||
"models.time_series_transformer": ["TimeSeriesTransformerConfig"],
|
||||
"models.timesformer": ["TimesformerConfig"],
|
||||
"models.timm_backbone": ["TimmBackboneConfig"],
|
||||
"models.timm_wrapper": ["TimmWrapperConfig"],
|
||||
"models.trocr": [
|
||||
"TrOCRConfig",
|
||||
"TrOCRProcessor",
|
||||
@@ -1272,6 +1273,18 @@ else:
|
||||
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
|
||||
_import_structure["models.vit"].append("ViTImageProcessorFast")
|
||||
|
||||
try:
|
||||
if not is_torchvision_available() and not is_timm_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_timm_and_torchvision_objects
|
||||
|
||||
_import_structure["utils.dummy_timm_and_torchvision_objects"] = [
|
||||
name for name in dir(dummy_timm_and_torchvision_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["models.timm_wrapper"].extend(["TimmWrapperImageProcessor"])
|
||||
|
||||
# PyTorch-backed objects
|
||||
try:
|
||||
if not is_torch_available():
|
||||
@@ -3532,6 +3545,9 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["models.timm_backbone"].extend(["TimmBackbone"])
|
||||
_import_structure["models.timm_wrapper"].extend(
|
||||
["TimmWrapperForImageClassification", "TimmWrapperModel", "TimmWrapperPreTrainedModel"]
|
||||
)
|
||||
_import_structure["models.trocr"].extend(
|
||||
[
|
||||
"TrOCRForCausalLM",
|
||||
@@ -5734,6 +5750,7 @@ if TYPE_CHECKING:
|
||||
TimesformerConfig,
|
||||
)
|
||||
from .models.timm_backbone import TimmBackboneConfig
|
||||
from .models.timm_wrapper import TimmWrapperConfig
|
||||
from .models.trocr import (
|
||||
TrOCRConfig,
|
||||
TrOCRProcessor,
|
||||
@@ -6227,6 +6244,14 @@ if TYPE_CHECKING:
|
||||
from .models.rt_detr import RTDetrImageProcessorFast
|
||||
from .models.vit import ViTImageProcessorFast
|
||||
|
||||
try:
|
||||
if not is_torchvision_available() and not is_timm_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_timm_and_torchvision_objects import *
|
||||
else:
|
||||
from .models.timm_wrapper import TimmWrapperImageProcessor
|
||||
|
||||
# Modeling
|
||||
try:
|
||||
if not is_torch_available():
|
||||
@@ -8037,6 +8062,11 @@ if TYPE_CHECKING:
|
||||
TimesformerPreTrainedModel,
|
||||
)
|
||||
from .models.timm_backbone import TimmBackbone
|
||||
from .models.timm_wrapper import (
|
||||
TimmWrapperForImageClassification,
|
||||
TimmWrapperModel,
|
||||
TimmWrapperPreTrainedModel,
|
||||
)
|
||||
from .models.trocr import (
|
||||
TrOCRForCausalLM,
|
||||
TrOCRPreTrainedModel,
|
||||
|
||||
@@ -37,6 +37,7 @@ from .utils import (
|
||||
download_url,
|
||||
extract_commit_hash,
|
||||
is_remote_url,
|
||||
is_timm_config_dict,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
@@ -702,6 +703,11 @@ class PretrainedConfig(PushToHubMixin):
|
||||
config_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
|
||||
config_dict["custom_pipelines"], pretrained_model_name_or_path
|
||||
)
|
||||
|
||||
# timm models are not saved with the model_type in the config file
|
||||
if "model_type" not in config_dict and is_timm_config_dict(config_dict):
|
||||
config_dict["model_type"] = "timm_wrapper"
|
||||
|
||||
return config_dict, kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -285,6 +285,8 @@ class ImageProcessingMixin(PushToHubMixin):
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||
specify the folder name here.
|
||||
image_processor_filename (`str`, *optional*, defaults to `"config.json"`):
|
||||
The name of the file in the model directory to use for the image processor config.
|
||||
|
||||
Returns:
|
||||
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
|
||||
@@ -298,6 +300,7 @@ class ImageProcessingMixin(PushToHubMixin):
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME)
|
||||
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
@@ -324,7 +327,7 @@ class ImageProcessingMixin(PushToHubMixin):
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
image_processor_file = os.path.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)
|
||||
image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename)
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
resolved_image_processor_file = pretrained_model_name_or_path
|
||||
is_local = True
|
||||
@@ -332,7 +335,7 @@ class ImageProcessingMixin(PushToHubMixin):
|
||||
image_processor_file = pretrained_model_name_or_path
|
||||
resolved_image_processor_file = download_url(pretrained_model_name_or_path)
|
||||
else:
|
||||
image_processor_file = IMAGE_PROCESSOR_NAME
|
||||
image_processor_file = image_processor_filename
|
||||
try:
|
||||
# Load from local folder or from cache or download from model Hub and cache
|
||||
resolved_image_processor_file = cached_file(
|
||||
@@ -358,7 +361,7 @@ class ImageProcessingMixin(PushToHubMixin):
|
||||
f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
|
||||
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||
f" directory containing a {IMAGE_PROCESSOR_NAME} file"
|
||||
f" directory containing a {image_processor_filename} file"
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -503,7 +503,7 @@ def load_state_dict(
|
||||
# Check format of the archive
|
||||
with safe_open(checkpoint_file, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
|
||||
if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
|
||||
raise OSError(
|
||||
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
|
||||
"you save your model with the `save_pretrained` method."
|
||||
@@ -652,36 +652,6 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
|
||||
|
||||
|
||||
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
|
||||
# Convert old format to new format if needed from a PyTorch state_dict
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
renamed_keys = {}
|
||||
renamed_gamma = {}
|
||||
renamed_beta = {}
|
||||
warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if "gamma" in key:
|
||||
# We add only the first key as an example
|
||||
new_key = key.replace("gamma", "weight")
|
||||
renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
|
||||
if "beta" in key:
|
||||
# We add only the first key as an example
|
||||
new_key = key.replace("beta", "bias")
|
||||
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
renamed_keys = {**renamed_gamma, **renamed_beta}
|
||||
if renamed_keys:
|
||||
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
|
||||
for old_key, new_key in renamed_keys.items():
|
||||
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
|
||||
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
|
||||
logger.info_once(warning_msg)
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
state_dict = state_dict.copy()
|
||||
@@ -812,46 +782,7 @@ def _load_state_dict_into_meta_model(
|
||||
|
||||
error_msgs = []
|
||||
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
renamed_gamma = {}
|
||||
renamed_beta = {}
|
||||
is_quantized = hf_quantizer is not None
|
||||
warning_msg = f"This model {type(model)}"
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if "gamma" in key:
|
||||
# We add only the first key as an example
|
||||
new_key = key.replace("gamma", "weight")
|
||||
renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
|
||||
if "beta" in key:
|
||||
# We add only the first key as an example
|
||||
new_key = key.replace("beta", "bias")
|
||||
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
|
||||
|
||||
# To reproduce `_load_state_dict_into_model` behaviour, we need to manually rename parametrized weigth norm, if necessary.
|
||||
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||
if "weight_g" in key:
|
||||
new_key = key.replace("weight_g", "parametrizations.weight.original0")
|
||||
if "weight_v" in key:
|
||||
new_key = key.replace("weight_v", "parametrizations.weight.original1")
|
||||
else:
|
||||
if "parametrizations.weight.original0" in key:
|
||||
new_key = key.replace("parametrizations.weight.original0", "weight_g")
|
||||
if "parametrizations.weight.original1" in key:
|
||||
new_key = key.replace("parametrizations.weight.original1", "weight_v")
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
renamed_keys = {**renamed_gamma, **renamed_beta}
|
||||
if renamed_keys:
|
||||
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
|
||||
for old_key, new_key in renamed_keys.items():
|
||||
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
|
||||
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
|
||||
logger.info_once(warning_msg)
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
||||
|
||||
@@ -2888,6 +2819,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
for ignore_key in self._keys_to_ignore_on_save:
|
||||
if ignore_key in state_dict.keys():
|
||||
del state_dict[ignore_key]
|
||||
|
||||
# Rename state_dict keys before saving to file. Do nothing unless overriden in a particular model.
|
||||
# (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
|
||||
state_dict = self._fix_state_dict_keys_on_save(state_dict)
|
||||
|
||||
if safe_serialization:
|
||||
# Safetensors does not allow tensor aliasing.
|
||||
# We're going to remove aliases before saving
|
||||
@@ -4010,7 +3946,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
with safe_open(resolved_archive_file, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
|
||||
if metadata.get("format") == "pt":
|
||||
if metadata is None:
|
||||
# Assume it's a pytorch checkpoint (introduced for timm checkpoints)
|
||||
pass
|
||||
elif metadata.get("format") == "pt":
|
||||
pass
|
||||
elif metadata.get("format") == "tf":
|
||||
from_tf = True
|
||||
@@ -4375,6 +4314,72 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _fix_state_dict_key_on_load(key):
|
||||
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
|
||||
|
||||
if "beta" in key:
|
||||
return key.replace("beta", "bias")
|
||||
if "gamma" in key:
|
||||
return key.replace("gamma", "weight")
|
||||
|
||||
# to avoid logging parametrized weight norm renaming
|
||||
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||
if "weight_g" in key:
|
||||
return key.replace("weight_g", "parametrizations.weight.original0")
|
||||
if "weight_v" in key:
|
||||
return key.replace("weight_v", "parametrizations.weight.original1")
|
||||
else:
|
||||
if "parametrizations.weight.original0" in key:
|
||||
return key.replace("parametrizations.weight.original0", "weight_g")
|
||||
if "parametrizations.weight.original1" in key:
|
||||
return key.replace("parametrizations.weight.original1", "weight_v")
|
||||
return key
|
||||
|
||||
@classmethod
|
||||
def _fix_state_dict_keys_on_load(cls, state_dict):
|
||||
"""Fixes state dict keys by replacing legacy parameter names with their modern equivalents.
|
||||
Logs if any parameters have been renamed.
|
||||
"""
|
||||
|
||||
renamed_keys = {}
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
for key in state_dict_keys:
|
||||
new_key = cls._fix_state_dict_key_on_load(key)
|
||||
if new_key != key:
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
# add it once for logging
|
||||
if "gamma" in key and "gamma" not in renamed_keys:
|
||||
renamed_keys["gamma"] = (key, new_key)
|
||||
if "beta" in key and "beta" not in renamed_keys:
|
||||
renamed_keys["beta"] = (key, new_key)
|
||||
|
||||
if renamed_keys:
|
||||
warning_msg = f"A pretrained model of type `{cls.__name__}` "
|
||||
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
|
||||
for old_key, new_key in renamed_keys.values():
|
||||
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
|
||||
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
|
||||
logger.info_once(warning_msg)
|
||||
|
||||
return state_dict
|
||||
|
||||
@staticmethod
|
||||
def _fix_state_dict_key_on_save(key):
|
||||
"""
|
||||
Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save.
|
||||
Do nothing by default, but can be overriden in particular models.
|
||||
"""
|
||||
return key
|
||||
|
||||
def _fix_state_dict_keys_on_save(self, state_dict):
|
||||
"""
|
||||
Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save.
|
||||
Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`.
|
||||
"""
|
||||
return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()}
|
||||
|
||||
@classmethod
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
@@ -4430,27 +4435,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if hf_quantizer is not None:
|
||||
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
|
||||
|
||||
def _fix_key(key):
|
||||
if "beta" in key:
|
||||
return key.replace("beta", "bias")
|
||||
if "gamma" in key:
|
||||
return key.replace("gamma", "weight")
|
||||
|
||||
# to avoid logging parametrized weight norm renaming
|
||||
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||
if "weight_g" in key:
|
||||
return key.replace("weight_g", "parametrizations.weight.original0")
|
||||
if "weight_v" in key:
|
||||
return key.replace("weight_v", "parametrizations.weight.original1")
|
||||
else:
|
||||
if "parametrizations.weight.original0" in key:
|
||||
return key.replace("parametrizations.weight.original0", "weight_g")
|
||||
if "parametrizations.weight.original1" in key:
|
||||
return key.replace("parametrizations.weight.original1", "weight_v")
|
||||
return key
|
||||
|
||||
original_loaded_keys = loaded_keys
|
||||
loaded_keys = [_fix_key(key) for key in loaded_keys]
|
||||
loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys]
|
||||
|
||||
if len(prefix) > 0:
|
||||
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
|
||||
@@ -4615,23 +4601,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
original_loaded_keys,
|
||||
add_prefix_to_model,
|
||||
remove_prefix_from_model,
|
||||
ignore_mismatched_sizes,
|
||||
):
|
||||
mismatched_keys = []
|
||||
if ignore_mismatched_sizes:
|
||||
for checkpoint_key in loaded_keys:
|
||||
for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys):
|
||||
# If the checkpoint is sharded, we may not have the key here.
|
||||
if checkpoint_key not in state_dict:
|
||||
continue
|
||||
model_key = checkpoint_key
|
||||
if remove_prefix_from_model:
|
||||
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
|
||||
model_key = f"{prefix}.{checkpoint_key}"
|
||||
model_key = f"{prefix}.{model_key}"
|
||||
elif add_prefix_to_model:
|
||||
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
|
||||
model_key = ".".join(checkpoint_key.split(".")[1:])
|
||||
model_key = ".".join(model_key.split(".")[1:])
|
||||
|
||||
if (
|
||||
model_key in model_state_dict
|
||||
@@ -4680,6 +4666,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
mismatched_keys = _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
original_loaded_keys,
|
||||
add_prefix_to_model,
|
||||
remove_prefix_from_model,
|
||||
@@ -4688,9 +4675,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
# For GGUF models `state_dict` is never set to None as the state dict is always small
|
||||
if gguf_path:
|
||||
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
state_dict,
|
||||
fixed_state_dict,
|
||||
start_prefix,
|
||||
expected_keys,
|
||||
device_map=device_map,
|
||||
@@ -4709,8 +4697,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
assign_to_params_buffers = check_support_param_buffer_assignment(
|
||||
model_to_load, state_dict, start_prefix
|
||||
)
|
||||
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||
error_msgs = _load_state_dict_into_model(
|
||||
model_to_load, state_dict, start_prefix, assign_to_params_buffers
|
||||
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -4761,6 +4750,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
mismatched_keys += _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
original_loaded_keys,
|
||||
add_prefix_to_model,
|
||||
remove_prefix_from_model,
|
||||
@@ -4774,9 +4764,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
|
||||
)
|
||||
else:
|
||||
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
state_dict,
|
||||
fixed_state_dict,
|
||||
start_prefix,
|
||||
expected_keys,
|
||||
device_map=device_map,
|
||||
@@ -4797,8 +4788,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
assign_to_params_buffers = check_support_param_buffer_assignment(
|
||||
model_to_load, state_dict, start_prefix
|
||||
)
|
||||
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||
error_msgs += _load_state_dict_into_model(
|
||||
model_to_load, state_dict, start_prefix, assign_to_params_buffers
|
||||
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
|
||||
)
|
||||
|
||||
# force memory release
|
||||
@@ -4930,9 +4922,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
_move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
|
||||
state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
|
||||
expected_keys = loaded_state_dict_keys # plug for missing expected_keys. TODO: replace with proper keys
|
||||
fixed_state_dict = model._fix_state_dict_keys_on_load(state_dict)
|
||||
error_msgs = _load_state_dict_into_meta_model(
|
||||
model,
|
||||
state_dict,
|
||||
fixed_state_dict,
|
||||
start_prefix,
|
||||
expected_keys=expected_keys,
|
||||
hf_quantizer=hf_quantizer,
|
||||
|
||||
@@ -249,6 +249,7 @@ from . import (
|
||||
time_series_transformer,
|
||||
timesformer,
|
||||
timm_backbone,
|
||||
timm_wrapper,
|
||||
trocr,
|
||||
tvp,
|
||||
udop,
|
||||
|
||||
@@ -276,6 +276,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("time_series_transformer", "TimeSeriesTransformerConfig"),
|
||||
("timesformer", "TimesformerConfig"),
|
||||
("timm_backbone", "TimmBackboneConfig"),
|
||||
("timm_wrapper", "TimmWrapperConfig"),
|
||||
("trajectory_transformer", "TrajectoryTransformerConfig"),
|
||||
("transfo-xl", "TransfoXLConfig"),
|
||||
("trocr", "TrOCRConfig"),
|
||||
@@ -599,6 +600,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("time_series_transformer", "Time Series Transformer"),
|
||||
("timesformer", "TimeSformer"),
|
||||
("timm_backbone", "TimmBackbone"),
|
||||
("timm_wrapper", "TimmWrapperModel"),
|
||||
("trajectory_transformer", "Trajectory Transformer"),
|
||||
("transfo-xl", "Transformer-XL"),
|
||||
("trocr", "TrOCR"),
|
||||
|
||||
@@ -30,6 +30,8 @@ from ...utils import (
|
||||
CONFIG_NAME,
|
||||
IMAGE_PROCESSOR_NAME,
|
||||
get_file_from_repo,
|
||||
is_timm_config_dict,
|
||||
is_timm_local_checkpoint,
|
||||
is_torchvision_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
@@ -137,6 +139,7 @@ else:
|
||||
("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("table-transformer", ("DetrImageProcessor",)),
|
||||
("timesformer", ("VideoMAEImageProcessor",)),
|
||||
("timm_wrapper", ("TimmWrapperImageProcessor",)),
|
||||
("tvlt", ("TvltImageProcessor",)),
|
||||
("tvp", ("TvpImageProcessor",)),
|
||||
("udop", ("LayoutLMv3ImageProcessor",)),
|
||||
@@ -376,6 +379,8 @@ class AutoImageProcessor:
|
||||
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
||||
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
||||
execute code present on the Hub on your local machine.
|
||||
image_processor_filename (`str`, *optional*, defaults to `"config.json"`):
|
||||
The name of the file in the model directory to use for the image processor config.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
The values in kwargs of any keys which are image processor attributes will be used to override the
|
||||
loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
|
||||
@@ -415,7 +420,37 @@ class AutoImageProcessor:
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
kwargs["_from_auto"] = True
|
||||
|
||||
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
|
||||
# Resolve the image processor config filename
|
||||
if "image_processor_filename" in kwargs:
|
||||
image_processor_filename = kwargs.pop("image_processor_filename")
|
||||
elif is_timm_local_checkpoint(pretrained_model_name_or_path):
|
||||
image_processor_filename = CONFIG_NAME
|
||||
else:
|
||||
image_processor_filename = IMAGE_PROCESSOR_NAME
|
||||
|
||||
# Load the image processor config
|
||||
try:
|
||||
# Main path for all transformers models and local TimmWrapper checkpoints
|
||||
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
|
||||
pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
|
||||
)
|
||||
except Exception as initial_exception:
|
||||
# Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
|
||||
# instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
|
||||
# except the model name, the only way to check if a remote checkpoint is a timm model is to try to
|
||||
# load `config.json` and if it fails with some error, we raise the initial exception.
|
||||
try:
|
||||
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
|
||||
pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
|
||||
)
|
||||
except Exception:
|
||||
raise initial_exception
|
||||
|
||||
# In case we have a config_dict, but it's not a timm config dict, we raise the initial exception,
|
||||
# because only timm models have image processing in `config.json`.
|
||||
if not is_timm_config_dict(config_dict):
|
||||
raise initial_exception
|
||||
|
||||
image_processor_class = config_dict.get("image_processor_type", None)
|
||||
image_processor_auto_map = None
|
||||
if "AutoImageProcessor" in config_dict.get("auto_map", {}):
|
||||
|
||||
@@ -255,6 +255,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("time_series_transformer", "TimeSeriesTransformerModel"),
|
||||
("timesformer", "TimesformerModel"),
|
||||
("timm_backbone", "TimmBackbone"),
|
||||
("timm_wrapper", "TimmWrapperModel"),
|
||||
("trajectory_transformer", "TrajectoryTransformerModel"),
|
||||
("transfo-xl", "TransfoXLModel"),
|
||||
("tvlt", "TvltModel"),
|
||||
@@ -605,6 +606,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
|
||||
("table-transformer", "TableTransformerModel"),
|
||||
("timesformer", "TimesformerModel"),
|
||||
("timm_backbone", "TimmBackbone"),
|
||||
("timm_wrapper", "TimmWrapperModel"),
|
||||
("van", "VanModel"),
|
||||
("videomae", "VideoMAEModel"),
|
||||
("vit", "ViTModel"),
|
||||
@@ -690,6 +692,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("swiftformer", "SwiftFormerForImageClassification"),
|
||||
("swin", "SwinForImageClassification"),
|
||||
("swinv2", "Swinv2ForImageClassification"),
|
||||
("timm_wrapper", "TimmWrapperForImageClassification"),
|
||||
("van", "VanForImageClassification"),
|
||||
("vit", "ViTForImageClassification"),
|
||||
("vit_hybrid", "ViTHybridForImageClassification"),
|
||||
|
||||
28
src/transformers/models/timm_wrapper/__init__.py
Normal file
28
src/transformers/models/timm_wrapper/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_timm_wrapper import *
|
||||
from .modeling_timm_wrapper import *
|
||||
from .processing_timm_wrapper import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
||||
@@ -0,0 +1,86 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration for TimmWrapper models"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class TimmWrapperConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`].
|
||||
|
||||
It is used to instantiate a timm model according to the specified arguments, defining the model.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
do_pooling (`bool`, *optional*, defaults to `True`):
|
||||
Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from transformers import TimmWrapperModel
|
||||
|
||||
>>> # Initializing a timm model
|
||||
>>> model = TimmWrapperModel.from_pretrained("timm/resnet18.a1_in1k")
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```
|
||||
"""
|
||||
|
||||
model_type = "timm_wrapper"
|
||||
|
||||
def __init__(self, initializer_range: float = 0.02, do_pooling: bool = True, **kwargs):
|
||||
self.initializer_range = initializer_range
|
||||
self.do_pooling = do_pooling
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
|
||||
# timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
|
||||
# We are removing these attributes in order to have the native `transformers` num_labels attribute in config
|
||||
# and to avoid duplicate attributes
|
||||
|
||||
num_labels_in_kwargs = kwargs.pop("num_labels", None)
|
||||
num_labels_in_dict = config_dict.pop("num_classes", None)
|
||||
|
||||
# passed num_labels has priority over num_classes in config_dict
|
||||
kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict
|
||||
|
||||
# pop num_classes from "pretrained_cfg",
|
||||
# it is not necessary to have it, only root one is used in timm
|
||||
if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]:
|
||||
config_dict["pretrained_cfg"].pop("num_classes", None)
|
||||
|
||||
return super().from_dict(config_dict, **kwargs)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
output = super().to_dict()
|
||||
output["num_classes"] = self.num_labels
|
||||
return output
|
||||
|
||||
|
||||
__all__ = ["TimmWrapperConfig"]
|
||||
@@ -0,0 +1,138 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature
|
||||
from ...image_transforms import to_pil_image
|
||||
from ...image_utils import ImageInput, make_list_of_images
|
||||
from ...utils import TensorType, logging, requires_backends
|
||||
from ...utils.import_utils import is_timm_available, is_torch_available
|
||||
|
||||
|
||||
if is_timm_available():
|
||||
import timm
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class TimmWrapperImageProcessor(BaseImageProcessor):
|
||||
"""
|
||||
Wrapper class for timm models to be used within transformers.
|
||||
|
||||
Args:
|
||||
pretrained_cfg (`Dict[str, Any]`):
|
||||
The configuration of the pretrained model used to resolve evaluation and
|
||||
training transforms.
|
||||
architecture (`Optional[str]`, *optional*):
|
||||
Name of the architecture of the model.
|
||||
"""
|
||||
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pretrained_cfg: Dict[str, Any],
|
||||
architecture: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
requires_backends(self, "timm")
|
||||
super().__init__(architecture=architecture)
|
||||
|
||||
self.data_config = timm.data.resolve_data_config(pretrained_cfg, model=None, verbose=False)
|
||||
self.val_transforms = timm.data.create_transform(**self.data_config, is_training=False)
|
||||
|
||||
# useful for training, see examples/pytorch/image-classification/run_image_classification.py
|
||||
self.train_transforms = timm.data.create_transform(**self.data_config, is_training=True)
|
||||
|
||||
# If `ToTensor` is in the transforms, then the input should be numpy array or PIL image.
|
||||
# Otherwise, the input can be a tensor. In later timm versions, `MaybeToTensor` is used
|
||||
# which can handle both numpy arrays / PIL images and tensors.
|
||||
self._not_supports_tensor_input = any(
|
||||
transform.__class__.__name__ == "ToTensor" for transform in self.val_transforms.transforms
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes this instance to a Python dictionary.
|
||||
"""
|
||||
output = super().to_dict()
|
||||
output.pop("train_transforms", None)
|
||||
output.pop("val_transforms", None)
|
||||
output.pop("_not_supports_tensor_input", None)
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
def get_image_processor_dict(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""
|
||||
Get the image processor dict for the model.
|
||||
"""
|
||||
image_processor_filename = kwargs.pop("image_processor_filename", "config.json")
|
||||
return super().get_image_processor_dict(
|
||||
pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
return_tensors: Optional[Union[str, TensorType]] = "pt",
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return.
|
||||
"""
|
||||
if return_tensors != "pt":
|
||||
raise ValueError(f"return_tensors for TimmWrapperImageProcessor must be 'pt', but got {return_tensors}")
|
||||
|
||||
if self._not_supports_tensor_input and isinstance(images, torch.Tensor):
|
||||
images = images.cpu().numpy()
|
||||
|
||||
# If the input is a torch tensor, then no conversion is needed
|
||||
# Otherwise, we need to pass in a list of PIL images
|
||||
if isinstance(images, torch.Tensor):
|
||||
images = self.val_transforms(images)
|
||||
# Add batch dimension if a single image
|
||||
images = images.unsqueeze(0) if images.ndim == 3 else images
|
||||
else:
|
||||
images = make_list_of_images(images)
|
||||
images = [to_pil_image(image) for image in images]
|
||||
images = torch.stack([self.val_transforms(image) for image in images])
|
||||
|
||||
return BatchFeature({"pixel_values": images}, tensor_type=return_tensors)
|
||||
|
||||
def save_pretrained(self, *args, **kwargs):
|
||||
# disable it to make checkpoint the same as in `timm` library.
|
||||
logger.warning_once(
|
||||
"The `save_pretrained` method is disabled for TimmWrapperImageProcessor. "
|
||||
"The image processor configuration is saved directly in `config.json` when "
|
||||
"`save_pretrained` is called for saving the model."
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["TimmWrapperImageProcessor"]
|
||||
363
src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
Normal file
363
src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...modeling_outputs import ImageClassifierOutput, ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_timm_available,
|
||||
replace_return_docstrings,
|
||||
requires_backends,
|
||||
)
|
||||
from .configuration_timm_wrapper import TimmWrapperConfig
|
||||
|
||||
|
||||
if is_timm_available():
|
||||
import timm
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimmWrapperModelOutput(ModelOutput):
|
||||
"""
|
||||
Output class for models TimmWrapperModel, containing the last hidden states, an optional pooled output,
|
||||
and optional hidden states.
|
||||
|
||||
Args:
|
||||
last_hidden_state (`torch.FloatTensor`):
|
||||
The last hidden state of the model, output before applying the classification head.
|
||||
pooler_output (`torch.FloatTensor`, *optional*):
|
||||
The pooled output derived from the last hidden state, if applicable.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
|
||||
A tuple containing the intermediate hidden states of the model at the output of each layer or specified layers.
|
||||
Returned if `output_hidden_states=True` is set or if `config.output_hidden_states=True`.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*):
|
||||
A tuple containing the intermediate attention weights of the model at the output of each layer.
|
||||
Returned if `output_attentions=True` is set or if `config.output_attentions=True`.
|
||||
Note: Currently, Timm models do not support attentions output.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor
|
||||
pooler_output: Optional[torch.FloatTensor] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
|
||||
TIMM_WRAPPER_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`TimmWrapperImageProcessor.preprocess`]
|
||||
for details.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. Not compatible with timm wrapped models.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
**kwargs:
|
||||
Additional keyword arguments passed along to the `timm` model forward.
|
||||
"""
|
||||
|
||||
|
||||
class TimmWrapperPreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
config_class = TimmWrapperConfig
|
||||
_no_split_modules = []
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision", "timm"])
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _fix_state_dict_key_on_load(key):
|
||||
"""
|
||||
Overrides original method that renames `gamma` and `beta` to `weight` and `bias`.
|
||||
We don't want this behavior for timm wrapped models. Instead, this method adds a
|
||||
"timm_model." prefix to enable loading official timm Hub checkpoints.
|
||||
"""
|
||||
if "timm_model." not in key:
|
||||
return f"timm_model.{key}"
|
||||
return key
|
||||
|
||||
def _fix_state_dict_key_on_save(self, key):
|
||||
"""
|
||||
Overrides original method to remove "timm_model." prefix from state_dict keys.
|
||||
Makes the saved checkpoint compatible with the `timm` library.
|
||||
"""
|
||||
return key.replace("timm_model.", "")
|
||||
|
||||
def load_state_dict(self, state_dict, *args, **kwargs):
|
||||
"""
|
||||
Override original method to fix state_dict keys on load for cases when weights are loaded
|
||||
without using the `from_pretrained` method (e.g., in Trainer to resume from checkpoint).
|
||||
"""
|
||||
state_dict = self._fix_state_dict_keys_on_load(state_dict)
|
||||
return super().load_state_dict(state_dict, *args, **kwargs)
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""
|
||||
Initialize weights function to properly initialize Linear layer weights.
|
||||
Since model architectures may vary, we assume only the classifier requires
|
||||
initialization, while all other weights should be loaded from the checkpoint.
|
||||
"""
|
||||
if isinstance(module, (nn.Linear)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
class TimmWrapperModel(TimmWrapperPreTrainedModel):
|
||||
"""
|
||||
Wrapper class for timm models to be used in transformers.
|
||||
"""
|
||||
|
||||
def __init__(self, config: TimmWrapperConfig):
|
||||
super().__init__(config)
|
||||
# using num_classes=0 to avoid creating classification head
|
||||
self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0)
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(TIMM_WRAPPER_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TimmWrapperModelOutput, config_class=TimmWrapperConfig)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[Union[bool, List[int]]] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
do_pooling: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[TimmWrapperModelOutput, Tuple[Tensor, ...]]:
|
||||
r"""
|
||||
do_pooling (`bool`, *optional*):
|
||||
Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. If `None` is passed, the
|
||||
`do_pooling` value from the config is used.
|
||||
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> from urllib.request import urlopen
|
||||
>>> from transformers import AutoModel, AutoImageProcessor
|
||||
|
||||
>>> # Load image
|
||||
>>> image = Image.open(urlopen(
|
||||
... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
|
||||
... ))
|
||||
|
||||
>>> # Load model and image processor
|
||||
>>> checkpoint = "timm/resnet50.a1_in1k"
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
|
||||
>>> model = AutoModel.from_pretrained(checkpoint).eval()
|
||||
|
||||
>>> # Preprocess image
|
||||
>>> inputs = image_processor(image)
|
||||
|
||||
>>> # Forward pass
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
|
||||
>>> # Get pooled output
|
||||
>>> pooled_output = outputs.pooler_output
|
||||
|
||||
>>> # Get last hidden state
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
```
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
do_pooling = do_pooling if do_pooling is not None else self.config.do_pooling
|
||||
|
||||
if output_attentions:
|
||||
raise ValueError("Cannot set `output_attentions` for timm models.")
|
||||
|
||||
if output_hidden_states and not hasattr(self.timm_model, "forward_intermediates"):
|
||||
raise ValueError(
|
||||
"The 'output_hidden_states' option cannot be set for this timm model. "
|
||||
"To enable this feature, the 'forward_intermediates' method must be implemented "
|
||||
"in the timm model (available in timm versions > 1.*). Please consider using a "
|
||||
"different architecture or updating the timm package to a compatible version."
|
||||
)
|
||||
|
||||
pixel_values = pixel_values.to(self.device, self.dtype)
|
||||
|
||||
if output_hidden_states:
|
||||
# to enable hidden states selection
|
||||
if isinstance(output_hidden_states, (list, tuple)):
|
||||
kwargs["indices"] = output_hidden_states
|
||||
last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs)
|
||||
else:
|
||||
last_hidden_state = self.timm_model.forward_features(pixel_values, **kwargs)
|
||||
hidden_states = None
|
||||
|
||||
if do_pooling:
|
||||
# classification head is not created, applying pooling only
|
||||
pooler_output = self.timm_model.forward_head(last_hidden_state)
|
||||
else:
|
||||
pooler_output = None
|
||||
|
||||
if not return_dict:
|
||||
outputs = (last_hidden_state, pooler_output, hidden_states)
|
||||
outputs = tuple(output for output in outputs if output is not None)
|
||||
return outputs
|
||||
|
||||
return TimmWrapperModelOutput(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooler_output,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
|
||||
class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel):
|
||||
"""
|
||||
Wrapper class for timm models to be used in transformers for image classification.
|
||||
"""
|
||||
|
||||
def __init__(self, config: TimmWrapperConfig):
|
||||
super().__init__(config)
|
||||
|
||||
if config.num_labels == 0:
|
||||
raise ValueError(
|
||||
"You are trying to load weights into `TimmWrapperForImageClassification` from a checkpoint with no classifier head. "
|
||||
"Please specify the number of classes, e.g. `model = TimmWrapperForImageClassification.from_pretrained(..., num_labels=10)`, "
|
||||
"or use `TimmWrapperModel` for feature extraction."
|
||||
)
|
||||
|
||||
self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=config.num_labels)
|
||||
self.num_labels = config.num_labels
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(TIMM_WRAPPER_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=ImageClassifierOutput, config_class=TimmWrapperConfig)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[Union[bool, List[int]]] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[ImageClassifierOutput, Tuple[Tensor, ...]]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> from urllib.request import urlopen
|
||||
>>> from transformers import AutoModelForImageClassification, AutoImageProcessor
|
||||
|
||||
>>> # Load image
|
||||
>>> image = Image.open(urlopen(
|
||||
... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
|
||||
... ))
|
||||
|
||||
>>> # Load model and image processor
|
||||
>>> checkpoint = "timm/resnet50.a1_in1k"
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
|
||||
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
|
||||
|
||||
>>> # Preprocess image
|
||||
>>> inputs = image_processor(image)
|
||||
|
||||
>>> # Forward pass
|
||||
>>> with torch.no_grad():
|
||||
... logits = model(**inputs).logits
|
||||
|
||||
>>> # Get top 5 predictions
|
||||
>>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
|
||||
```
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
if output_attentions:
|
||||
raise ValueError("Cannot set `output_attentions` for timm models.")
|
||||
|
||||
if output_hidden_states and not hasattr(self.timm_model, "forward_intermediates"):
|
||||
raise ValueError(
|
||||
"The 'output_hidden_states' option cannot be set for this timm model. "
|
||||
"To enable this feature, the 'forward_intermediates' method must be implemented "
|
||||
"in the timm model (available in timm versions > 1.*). Please consider using a "
|
||||
"different architecture or updating the timm package to a compatible version."
|
||||
)
|
||||
|
||||
pixel_values = pixel_values.to(self.device, self.dtype)
|
||||
|
||||
if output_hidden_states:
|
||||
# to enable hidden states selection
|
||||
if isinstance(output_hidden_states, (list, tuple)):
|
||||
kwargs["indices"] = output_hidden_states
|
||||
last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs)
|
||||
logits = self.timm_model.forward_head(last_hidden_state)
|
||||
else:
|
||||
logits = self.timm_model(pixel_values, **kwargs)
|
||||
hidden_states = None
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
outputs = (loss, logits, hidden_states)
|
||||
outputs = tuple(output for output in outputs if output is not None)
|
||||
return outputs
|
||||
|
||||
return ImageClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["TimmWrapperPreTrainedModel", "TimmWrapperModel", "TimmWrapperForImageClassification"]
|
||||
@@ -55,6 +55,8 @@ from .generic import (
|
||||
is_tensor,
|
||||
is_tf_symbolic_tensor,
|
||||
is_tf_tensor,
|
||||
is_timm_config_dict,
|
||||
is_timm_local_checkpoint,
|
||||
is_torch_device,
|
||||
is_torch_dtype,
|
||||
is_torch_tensor,
|
||||
|
||||
@@ -9043,6 +9043,27 @@ class TimmBackbone(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class TimmWrapperForImageClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class TimmWrapperModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class TimmWrapperPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class TrOCRForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class TimmWrapperImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["timm", "torchvision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["timm", "torchvision"])
|
||||
@@ -16,6 +16,8 @@ Generic utilities
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import OrderedDict, UserDict
|
||||
@@ -24,7 +26,7 @@ from contextlib import ExitStack, contextmanager
|
||||
from dataclasses import fields, is_dataclass
|
||||
from enum import Enum
|
||||
from functools import partial, wraps
|
||||
from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict
|
||||
from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, TypedDict
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
@@ -867,3 +869,36 @@ class LossKwargs(TypedDict, total=False):
|
||||
"""
|
||||
|
||||
num_items_in_batch: Optional[int]
|
||||
|
||||
|
||||
def is_timm_config_dict(config_dict: Dict[str, Any]) -> bool:
|
||||
"""Checks whether a config dict is a timm config dict."""
|
||||
return "pretrained_cfg" in config_dict
|
||||
|
||||
|
||||
def is_timm_local_checkpoint(pretrained_model_path: str) -> bool:
|
||||
"""
|
||||
Checks whether a checkpoint is a timm model checkpoint.
|
||||
"""
|
||||
if pretrained_model_path is None:
|
||||
return False
|
||||
|
||||
# in case it's Path, not str
|
||||
pretrained_model_path = str(pretrained_model_path)
|
||||
|
||||
is_file = os.path.isfile(pretrained_model_path)
|
||||
is_dir = os.path.isdir(pretrained_model_path)
|
||||
|
||||
# pretrained_model_path is a file
|
||||
if is_file and pretrained_model_path.endswith(".json"):
|
||||
with open(pretrained_model_path, "r") as f:
|
||||
config_dict = json.load(f)
|
||||
return is_timm_config_dict(config_dict)
|
||||
|
||||
# pretrained_model_path is a directory with a config.json
|
||||
if is_dir and os.path.exists(os.path.join(pretrained_model_path, "config.json")):
|
||||
with open(os.path.join(pretrained_model_path, "config.json"), "r") as f:
|
||||
config_dict = json.load(f)
|
||||
return is_timm_config_dict(config_dict)
|
||||
|
||||
return False
|
||||
|
||||
0
tests/models/timm_wrapper/__init__.py
Normal file
0
tests/models/timm_wrapper/__init__.py
Normal file
103
tests/models/timm_wrapper/test_image_processing_timm_wrapper.py
Normal file
103
tests/models/timm_wrapper/test_image_processing_timm_wrapper.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_torchvision, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import TimmWrapperConfig, TimmWrapperImageProcessor
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
@require_torchvision
|
||||
class TimmWrapperImageProcessingTest(unittest.TestCase):
|
||||
image_processing_class = TimmWrapperImageProcessor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
config = TimmWrapperConfig.from_pretrained("timm/resnet18.a1_in1k")
|
||||
config.save_pretrained(self.temp_dir.name)
|
||||
|
||||
def tearDown(self):
|
||||
self.temp_dir.cleanup()
|
||||
|
||||
def test_load_from_hub(self):
|
||||
image_processor = TimmWrapperImageProcessor.from_pretrained("timm/resnet18.a1_in1k")
|
||||
self.assertIsInstance(image_processor, TimmWrapperImageProcessor)
|
||||
|
||||
def test_load_from_local_dir(self):
|
||||
image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
|
||||
self.assertIsInstance(image_processor, TimmWrapperImageProcessor)
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
|
||||
self.assertTrue(hasattr(image_processor, "data_config"))
|
||||
self.assertTrue(hasattr(image_processor, "val_transforms"))
|
||||
self.assertTrue(hasattr(image_processor, "train_transforms"))
|
||||
|
||||
def test_image_processor_call_numpy(self):
|
||||
image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
|
||||
|
||||
single_image = np.random.randint(256, size=(256, 256, 3), dtype=np.uint8)
|
||||
batch_images = [single_image, single_image, single_image]
|
||||
|
||||
# single image
|
||||
pixel_values = image_processor(single_image).pixel_values
|
||||
self.assertEqual(pixel_values.shape, (1, 3, 224, 224))
|
||||
|
||||
# batch images
|
||||
pixel_values = image_processor(batch_images).pixel_values
|
||||
self.assertEqual(pixel_values.shape, (3, 3, 224, 224))
|
||||
|
||||
def test_image_processor_call_pil(self):
|
||||
image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
|
||||
|
||||
single_image = Image.fromarray(np.random.randint(256, size=(256, 256, 3), dtype=np.uint8))
|
||||
batch_images = [single_image, single_image, single_image]
|
||||
|
||||
# single image
|
||||
pixel_values = image_processor(single_image).pixel_values
|
||||
self.assertEqual(pixel_values.shape, (1, 3, 224, 224))
|
||||
|
||||
# batch images
|
||||
pixel_values = image_processor(batch_images).pixel_values
|
||||
self.assertEqual(pixel_values.shape, (3, 3, 224, 224))
|
||||
|
||||
def test_image_processor_call_tensor(self):
|
||||
image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
|
||||
|
||||
single_image = torch.from_numpy(np.random.randint(256, size=(3, 256, 256), dtype=np.uint8)).float()
|
||||
batch_images = [single_image, single_image, single_image]
|
||||
|
||||
# single image
|
||||
pixel_values = image_processor(single_image).pixel_values
|
||||
self.assertEqual(pixel_values.shape, (1, 3, 224, 224))
|
||||
|
||||
# batch images
|
||||
pixel_values = image_processor(batch_images).pixel_values
|
||||
self.assertEqual(pixel_values.shape, (3, 3, 224, 224))
|
||||
366
tests/models/timm_wrapper/test_modeling_timm_wrapper.py
Normal file
366
tests/models/timm_wrapper/test_modeling_timm_wrapper.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
require_timm,
|
||||
require_torch,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils.import_utils import is_timm_available, is_torch_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import TimmWrapperConfig, TimmWrapperForImageClassification, TimmWrapperModel
|
||||
|
||||
|
||||
if is_timm_available():
|
||||
import timm
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import TimmWrapperImageProcessor
|
||||
|
||||
|
||||
class TimmWrapperModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
model_name="timm/resnet18.a1_in1k",
|
||||
batch_size=3,
|
||||
image_size=32,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
):
|
||||
self.parent = parent
|
||||
self.model_name = model_name
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
return TimmWrapperConfig.from_pretrained(self.model_name)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values):
|
||||
model = TimmWrapperModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(
|
||||
result.feature_map[-1].shape,
|
||||
(self.batch_size, model.channels[-1], 14, 14),
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_timm
|
||||
class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TimmWrapperModel, TimmWrapperForImageClassification) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"image-feature-extraction": TimmWrapperModel, "image-classification": TimmWrapperForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
has_attentions = False
|
||||
test_model_parallel = False
|
||||
|
||||
def setUp(self):
|
||||
self.config_class = TimmWrapperConfig
|
||||
self.model_tester = TimmWrapperModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
self,
|
||||
config_class=self.config_class,
|
||||
has_text_modality=False,
|
||||
common_properties=[],
|
||||
model_name="timm/resnet18.a1_in1k",
|
||||
)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
# check all hidden states
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_dict, output_hidden_states=True)
|
||||
self.assertTrue(
|
||||
len(outputs.hidden_states) == 5, f"expected 5 hidden states, but got {len(outputs.hidden_states)}"
|
||||
)
|
||||
expected_shapes = [[16, 16], [8, 8], [4, 4], [2, 2], [1, 1]]
|
||||
resulted_shapes = [list(h.shape[2:]) for h in outputs.hidden_states]
|
||||
self.assertListEqual(expected_shapes, resulted_shapes)
|
||||
|
||||
# check we can select hidden states by indices
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_dict, output_hidden_states=[-2, -1])
|
||||
self.assertTrue(
|
||||
len(outputs.hidden_states) == 2, f"expected 2 hidden states, but got {len(outputs.hidden_states)}"
|
||||
)
|
||||
expected_shapes = [[2, 2], [1, 1]]
|
||||
resulted_shapes = [list(h.shape[2:]) for h in outputs.hidden_states]
|
||||
self.assertListEqual(expected_shapes, resulted_shapes)
|
||||
|
||||
@unittest.skip(reason="TimmWrapper models doesn't have inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="TimmWrapper models doesn't have inputs_embeds")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="TimmWrapper doesn't support output_attentions=True.")
|
||||
def test_torchscript_output_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="TimmWrapper doesn't support this.")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="TimmWrapper initialization is managed on the timm side")
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Need to use a timm model and there is no tiny model available.")
|
||||
def test_model_is_small(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_do_pooling_option(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.do_pooling = False
|
||||
|
||||
model = TimmWrapperModel._from_config(config)
|
||||
|
||||
# check there is no pooling
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
self.assertIsNone(output.pooler_output)
|
||||
|
||||
# check there is pooler output
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict, do_pooling=True)
|
||||
self.assertIsNotNone(output.pooler_output)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
return image
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_timm
|
||||
@require_vision
|
||||
class TimmWrapperModelIntegrationTest(unittest.TestCase):
|
||||
# some popular ones
|
||||
model_names_to_test = [
|
||||
"vit_small_patch16_384.augreg_in21k_ft_in1k",
|
||||
"resnet50.a1_in1k",
|
||||
"tf_mobilenetv3_large_minimal_100.in1k",
|
||||
"swin_tiny_patch4_window7_224.ms_in1k",
|
||||
"ese_vovnet19b_dw.ra_in1k",
|
||||
"hrnet_w18.ms_aug_in1k",
|
||||
]
|
||||
|
||||
@slow
|
||||
def test_inference_image_classification_head(self):
|
||||
checkpoint = "timm/resnet18.a1_in1k"
|
||||
model = TimmWrapperForImageClassification.from_pretrained(checkpoint, device_map=torch_device).eval()
|
||||
image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)
|
||||
|
||||
image = prepare_img()
|
||||
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify the shape and logits
|
||||
expected_shape = torch.Size((1, 1000))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_label = 281 # tabby cat
|
||||
self.assertEqual(torch.argmax(outputs.logits).item(), expected_label)
|
||||
|
||||
expected_slice = torch.tensor([-11.2618, -9.6192, -10.3205]).to(torch_device)
|
||||
resulted_slice = outputs.logits[0, :3]
|
||||
is_close = torch.allclose(resulted_slice, expected_slice, atol=1e-3)
|
||||
self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}")
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_inference_image_classification_quantized(self):
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
checkpoint = "timm/vit_small_patch16_384.augreg_in21k_ft_in1k"
|
||||
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
model = TimmWrapperForImageClassification.from_pretrained(
|
||||
checkpoint, quantization_config=quantization_config, device_map=torch_device
|
||||
).eval()
|
||||
image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)
|
||||
|
||||
image = prepare_img()
|
||||
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify the shape and logits
|
||||
expected_shape = torch.Size((1, 1000))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_label = 281 # tabby cat
|
||||
self.assertEqual(torch.argmax(outputs.logits).item(), expected_label)
|
||||
|
||||
expected_slice = torch.tensor([-2.4043, 1.4492, -0.5127]).to(outputs.logits.dtype)
|
||||
resulted_slice = outputs.logits[0, :3].cpu()
|
||||
is_close = torch.allclose(resulted_slice, expected_slice, atol=0.1)
|
||||
self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}")
|
||||
|
||||
@slow
|
||||
def test_transformers_model_for_classification_is_equivalent_to_timm(self):
|
||||
# check that wrapper logits are the same as timm model logits
|
||||
|
||||
image = prepare_img()
|
||||
|
||||
for model_name in self.model_names_to_test:
|
||||
checkpoint = f"timm/{model_name}"
|
||||
|
||||
with self.subTest(msg=model_name):
|
||||
# prepare inputs
|
||||
image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)
|
||||
pixel_values = image_processor(images=image).pixel_values.to(torch_device)
|
||||
|
||||
# load models
|
||||
model = TimmWrapperForImageClassification.from_pretrained(checkpoint, device_map=torch_device).eval()
|
||||
timm_model = timm.create_model(model_name, pretrained=True).to(torch_device).eval()
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(pixel_values)
|
||||
timm_outputs = timm_model(pixel_values)
|
||||
|
||||
# check shape is the same
|
||||
self.assertEqual(outputs.logits.shape, timm_outputs.shape)
|
||||
|
||||
# check logits are the same
|
||||
diff = (outputs.logits - timm_outputs).max().item()
|
||||
self.assertLess(diff, 1e-4)
|
||||
|
||||
@slow
|
||||
def test_transformers_model_is_equivalent_to_timm(self):
|
||||
# check that wrapper logits are the same as timm model logits
|
||||
|
||||
image = prepare_img()
|
||||
|
||||
models_to_test = ["vit_small_patch16_224.dino"] + self.model_names_to_test
|
||||
|
||||
for model_name in models_to_test:
|
||||
checkpoint = f"timm/{model_name}"
|
||||
|
||||
with self.subTest(msg=model_name):
|
||||
# prepare inputs
|
||||
image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)
|
||||
pixel_values = image_processor(images=image).pixel_values.to(torch_device)
|
||||
|
||||
# load models
|
||||
model = TimmWrapperModel.from_pretrained(checkpoint, device_map=torch_device).eval()
|
||||
timm_model = timm.create_model(model_name, pretrained=True, num_classes=0).to(torch_device).eval()
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(pixel_values)
|
||||
timm_outputs = timm_model(pixel_values)
|
||||
|
||||
# check shape is the same
|
||||
self.assertEqual(outputs.pooler_output.shape, timm_outputs.shape)
|
||||
|
||||
# check logits are the same
|
||||
diff = (outputs.pooler_output - timm_outputs).max().item()
|
||||
self.assertLess(diff, 1e-4)
|
||||
|
||||
@slow
|
||||
def test_save_load_to_timm(self):
|
||||
# test that timm model can be loaded to transformers, saved and then loaded back into timm
|
||||
|
||||
model = TimmWrapperForImageClassification.from_pretrained(
|
||||
"timm/resnet18.a1_in1k", num_labels=10, ignore_mismatched_sizes=True
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
# there is no direct way to load timm model from folder, use the same config + path to weights
|
||||
timm_model = timm.create_model(
|
||||
"resnet18", num_classes=10, checkpoint_path=f"{tmpdirname}/model.safetensors"
|
||||
)
|
||||
|
||||
# check that all weights are the same after reload
|
||||
different_weights = []
|
||||
for (name1, param1), (name2, param2) in zip(
|
||||
model.timm_model.named_parameters(), timm_model.named_parameters()
|
||||
):
|
||||
if param1.shape != param2.shape or not torch.equal(param1, param2):
|
||||
different_weights.append((name1, name2))
|
||||
|
||||
if different_weights:
|
||||
self.fail(f"Found different weights after reloading: {different_weights}")
|
||||
@@ -3443,6 +3443,7 @@ class ModelTesterMixin:
|
||||
"Data2VecAudioForSequenceClassification",
|
||||
"UniSpeechForSequenceClassification",
|
||||
"PvtForImageClassification",
|
||||
"TimmWrapperForImageClassification",
|
||||
]
|
||||
special_param_names = [
|
||||
r"^bit\.",
|
||||
@@ -3463,6 +3464,7 @@ class ModelTesterMixin:
|
||||
r"^swiftformer\.",
|
||||
r"^swinv2\.",
|
||||
r"^transformers\.models\.swiftformer\.",
|
||||
r"^timm_model\.",
|
||||
r"^unispeech\.",
|
||||
r"^unispeech_sat\.",
|
||||
r"^vision_model\.",
|
||||
|
||||
@@ -41,6 +41,7 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
|
||||
"RagConfig",
|
||||
"SpeechEncoderDecoderConfig",
|
||||
"TimmBackboneConfig",
|
||||
"TimmWrapperConfig",
|
||||
"VisionEncoderDecoderConfig",
|
||||
"VisionTextDualEncoderConfig",
|
||||
"LlamaConfig",
|
||||
|
||||
Reference in New Issue
Block a user