Add TimmWrapper (#34564)
* Add files
* Init
* Add TimmWrapperModel
* Fix up
* Some fixes
* Fix up
* Remove old file
* Sort out import orders
* Fix some model loading
* Compatible with pipeline and trainer
* Fix up
* Delete test_timm_model_1/config.json
* Remove accidentally committed files
* Delete src/transformers/models/modeling_timm_wrapper.py
* Remove empty imports; fix transformations applied
* Tidy up
* Add image classification model to special cases
* Create pretrained model; enable device_map='auto'
* Enable most tests; fix init order
* Sort imports
* [run-slow] timm_wrapper
* Pass num_classes into timm.create_model
* Remove train transforms from image processor
* Update timm creation with pretrained=False
* Fix gamma/beta issue for timm models
* Fixing gamma and beta renaming for timm models
* Simplify config and model creation
* Remove attn_implementation diff
* Fixup
* Docstrings
* Fix warning msg text according to test case
* Fix device_map auto
* Set dtype and device for pixel_values in forward
* Enable output hidden states
* Enable tests for hidden_states and model parallel
* Remove default scriptable arg
* Refactor inner model
* Update timm version
* Fix _find_mismatched_keys function
* Change inheritance for Classification model (fix weights loading with device_map)
* Minor bugfix
* Disable save pretrained for image processor
* Rename hook method for loaded keys correction
* Rename state dict keys on save, remove `timm_model` prefix, make checkpoint compatible with `timm`
* Managing num_labels <-> num_classes attributes
* Enable loading checkpoints in Trainer to resume training
* Update error message for output_hidden_states
* Add output hidden states test
* Decouple base and classification models
* Add more test cases
* Add save-load-to-timm test
* Fix test name
* Fixup
* Add do_pooling
* Add test for do_pooling
* Fix doc
* Add tests for TimmWrapperModel
* Add validation for `num_classes=0` in timm config + test for DINO checkpoint
* Adjust atol for test
* Fix docs
* dev-ci
* dev-ci
* Add tests for image processor
* Update docs
* Update init to new format
* Update docs in configuration
* Fix some docs in image processor
* Improve docs for modeling
* fix for is_timm_checkpoint
* Update code examples
* Fix header
* Fix typehint
* Increase tolerance a bit
* Fix Path
* Fixing model parallel tests
* Disable "parallel" tests
* Add comment for metadata
* Refactor AutoImageProcessor for timm wrapper loading
* Remove custom test_model_outputs_equivalence
* Add require_timm decorator
* Fix comment
* Make image processor work with older timm versions and tensor input
* Save config instead of whole model in image processor tests
* Add docstring for `image_processor_filename`
* Sanitize kwargs for timm image processor
* Fix doc style
* Update check for tensor input
* Update normalize
* Remove _load_timm_model function

---------

Co-authored-by: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
committed by GitHub
parent bcc50cc7ce
commit 5fcf6286bf
@@ -705,6 +705,8 @@
|
|||||||
title: Swin2SR
|
title: Swin2SR
|
||||||
- local: model_doc/table-transformer
|
- local: model_doc/table-transformer
|
||||||
title: Table Transformer
|
title: Table Transformer
|
||||||
|
- local: model_doc/timm_wrapper
|
||||||
|
title: Timm Wrapper
|
||||||
- local: model_doc/upernet
|
- local: model_doc/upernet
|
||||||
title: UperNet
|
title: UperNet
|
||||||
- local: model_doc/van
|
- local: model_doc/van
|
||||||
|
|||||||
@@ -321,6 +321,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
| [TAPEX](model_doc/tapex) | ✅ | ✅ | ✅ |
|
| [TAPEX](model_doc/tapex) | ✅ | ✅ | ✅ |
|
||||||
| [Time Series Transformer](model_doc/time_series_transformer) | ✅ | ❌ | ❌ |
|
| [Time Series Transformer](model_doc/time_series_transformer) | ✅ | ❌ | ❌ |
|
||||||
| [TimeSformer](model_doc/timesformer) | ✅ | ❌ | ❌ |
|
| [TimeSformer](model_doc/timesformer) | ✅ | ❌ | ❌ |
|
||||||
|
| [TimmWrapperModel](model_doc/timm_wrapper) | ✅ | ❌ | ❌ |
|
||||||
| [Trajectory Transformer](model_doc/trajectory_transformer) | ✅ | ❌ | ❌ |
|
| [Trajectory Transformer](model_doc/trajectory_transformer) | ✅ | ❌ | ❌ |
|
||||||
| [Transformer-XL](model_doc/transfo-xl) | ✅ | ✅ | ❌ |
|
| [Transformer-XL](model_doc/transfo-xl) | ✅ | ✅ | ❌ |
|
||||||
| [TrOCR](model_doc/trocr) | ✅ | ❌ | ❌ |
|
| [TrOCR](model_doc/trocr) | ✅ | ❌ | ❌ |
|
||||||
|
|||||||
67
docs/source/en/model_doc/timm_wrapper.md
Normal file
@@ -0,0 +1,67 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# TimmWrapper

## Overview

Helper class that enables loading timm models for use with the transformers library and its autoclasses.

```python
>>> import torch
>>> from PIL import Image
>>> from urllib.request import urlopen
>>> from transformers import AutoModelForImageClassification, AutoImageProcessor

>>> # Load image
>>> image = Image.open(urlopen(
...     'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
... ))

>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()

>>> # Preprocess image
>>> inputs = image_processor(image)

>>> # Forward pass
>>> with torch.no_grad():
...     logits = model(**inputs).logits

>>> # Get top 5 predictions
>>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
```

## TimmWrapperConfig

[[autodoc]] TimmWrapperConfig

## TimmWrapperImageProcessor

[[autodoc]] TimmWrapperImageProcessor
    - preprocess

## TimmWrapperModel

[[autodoc]] TimmWrapperModel
    - forward

## TimmWrapperForImageClassification

[[autodoc]] TimmWrapperForImageClassification
    - forward
@@ -42,6 +42,7 @@ from transformers import (
|
|||||||
AutoImageProcessor,
|
AutoImageProcessor,
|
||||||
AutoModelForImageClassification,
|
AutoModelForImageClassification,
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
|
TimmWrapperImageProcessor,
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
set_seed,
|
set_seed,
|
||||||
@@ -329,15 +330,20 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Define torchvision transforms to be applied to each image.
|
# Define torchvision transforms to be applied to each image.
|
||||||
|
if isinstance(image_processor, TimmWrapperImageProcessor):
|
||||||
|
_train_transforms = image_processor.train_transforms
|
||||||
|
_val_transforms = image_processor.val_transforms
|
||||||
|
else:
|
||||||
if "shortest_edge" in image_processor.size:
|
if "shortest_edge" in image_processor.size:
|
||||||
size = image_processor.size["shortest_edge"]
|
size = image_processor.size["shortest_edge"]
|
||||||
else:
|
else:
|
||||||
size = (image_processor.size["height"], image_processor.size["width"])
|
size = (image_processor.size["height"], image_processor.size["width"])
|
||||||
normalize = (
|
|
||||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
|
# Create normalization transform
|
||||||
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
|
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std"):
|
||||||
else Lambda(lambda x: x)
|
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
|
||||||
)
|
else:
|
||||||
|
normalize = Lambda(lambda x: x)
|
||||||
_train_transforms = Compose(
|
_train_transforms = Compose(
|
||||||
[
|
[
|
||||||
RandomResizedCrop(size),
|
RandomResizedCrop(size),
|
||||||
|
|||||||
@@ -782,6 +782,7 @@ _import_structure = {
|
|||||||
"models.time_series_transformer": ["TimeSeriesTransformerConfig"],
|
"models.time_series_transformer": ["TimeSeriesTransformerConfig"],
|
||||||
"models.timesformer": ["TimesformerConfig"],
|
"models.timesformer": ["TimesformerConfig"],
|
||||||
"models.timm_backbone": ["TimmBackboneConfig"],
|
"models.timm_backbone": ["TimmBackboneConfig"],
|
||||||
|
"models.timm_wrapper": ["TimmWrapperConfig"],
|
||||||
"models.trocr": [
|
"models.trocr": [
|
||||||
"TrOCRConfig",
|
"TrOCRConfig",
|
||||||
"TrOCRProcessor",
|
"TrOCRProcessor",
|
||||||
@@ -1272,6 +1273,18 @@ else:
|
|||||||
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
|
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
|
||||||
_import_structure["models.vit"].append("ViTImageProcessorFast")
|
_import_structure["models.vit"].append("ViTImageProcessorFast")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not is_torchvision_available() and not is_timm_available():
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from .utils import dummy_timm_and_torchvision_objects
|
||||||
|
|
||||||
|
_import_structure["utils.dummy_timm_and_torchvision_objects"] = [
|
||||||
|
name for name in dir(dummy_timm_and_torchvision_objects) if not name.startswith("_")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
_import_structure["models.timm_wrapper"].extend(["TimmWrapperImageProcessor"])
|
||||||
|
|
||||||
# PyTorch-backed objects
|
# PyTorch-backed objects
|
||||||
try:
|
try:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
@@ -3532,6 +3545,9 @@ else:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.timm_backbone"].extend(["TimmBackbone"])
|
_import_structure["models.timm_backbone"].extend(["TimmBackbone"])
|
||||||
|
_import_structure["models.timm_wrapper"].extend(
|
||||||
|
["TimmWrapperForImageClassification", "TimmWrapperModel", "TimmWrapperPreTrainedModel"]
|
||||||
|
)
|
||||||
_import_structure["models.trocr"].extend(
|
_import_structure["models.trocr"].extend(
|
||||||
[
|
[
|
||||||
"TrOCRForCausalLM",
|
"TrOCRForCausalLM",
|
||||||
@@ -5734,6 +5750,7 @@ if TYPE_CHECKING:
|
|||||||
TimesformerConfig,
|
TimesformerConfig,
|
||||||
)
|
)
|
||||||
from .models.timm_backbone import TimmBackboneConfig
|
from .models.timm_backbone import TimmBackboneConfig
|
||||||
|
from .models.timm_wrapper import TimmWrapperConfig
|
||||||
from .models.trocr import (
|
from .models.trocr import (
|
||||||
TrOCRConfig,
|
TrOCRConfig,
|
||||||
TrOCRProcessor,
|
TrOCRProcessor,
|
||||||
@@ -6227,6 +6244,14 @@ if TYPE_CHECKING:
|
|||||||
from .models.rt_detr import RTDetrImageProcessorFast
|
from .models.rt_detr import RTDetrImageProcessorFast
|
||||||
from .models.vit import ViTImageProcessorFast
|
from .models.vit import ViTImageProcessorFast
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not is_torchvision_available() and not is_timm_available():
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from .utils.dummy_timm_and_torchvision_objects import *
|
||||||
|
else:
|
||||||
|
from .models.timm_wrapper import TimmWrapperImageProcessor
|
||||||
|
|
||||||
# Modeling
|
# Modeling
|
||||||
try:
|
try:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
@@ -8037,6 +8062,11 @@ if TYPE_CHECKING:
|
|||||||
TimesformerPreTrainedModel,
|
TimesformerPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.timm_backbone import TimmBackbone
|
from .models.timm_backbone import TimmBackbone
|
||||||
|
from .models.timm_wrapper import (
|
||||||
|
TimmWrapperForImageClassification,
|
||||||
|
TimmWrapperModel,
|
||||||
|
TimmWrapperPreTrainedModel,
|
||||||
|
)
|
||||||
from .models.trocr import (
|
from .models.trocr import (
|
||||||
TrOCRForCausalLM,
|
TrOCRForCausalLM,
|
||||||
TrOCRPreTrainedModel,
|
TrOCRPreTrainedModel,
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from .utils import (
|
|||||||
download_url,
|
download_url,
|
||||||
extract_commit_hash,
|
extract_commit_hash,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
|
is_timm_config_dict,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
@@ -702,6 +703,11 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
config_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
|
config_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
|
||||||
config_dict["custom_pipelines"], pretrained_model_name_or_path
|
config_dict["custom_pipelines"], pretrained_model_name_or_path
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# timm models are not saved with the model_type in the config file
|
||||||
|
if "model_type" not in config_dict and is_timm_config_dict(config_dict):
|
||||||
|
config_dict["model_type"] = "timm_wrapper"
|
||||||
|
|
||||||
return config_dict, kwargs
|
return config_dict, kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
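The `model_type` fallback in the hunk above can be sketched in isolation. This is a minimal illustration, not the library's implementation: `looks_like_timm_config` is a hypothetical stand-in for `is_timm_config_dict`, and the assumption that a timm config can be recognized by a `pretrained_cfg` entry is ours, not something this diff shows.

```python
def looks_like_timm_config(config_dict):
    # Hypothetical stand-in for `is_timm_config_dict`.
    # ASSUMPTION: timm configs can be recognized by a "pretrained_cfg" entry.
    return "pretrained_cfg" in config_dict


def with_model_type(config_dict):
    """Return a copy of the config with `model_type` filled in for timm-style configs."""
    config_dict = dict(config_dict)  # leave the caller's dict untouched
    if "model_type" not in config_dict and looks_like_timm_config(config_dict):
        config_dict["model_type"] = "timm_wrapper"
    return config_dict


# Illustrative configs; the field values are made up for this example.
timm_cfg = {"architecture": "resnet50", "num_classes": 1000, "pretrained_cfg": {}}
hf_cfg = {"model_type": "vit", "hidden_size": 768}

print(with_model_type(timm_cfg)["model_type"])
print(with_model_type(hf_cfg)["model_type"])
```

An explicit `model_type` is never overwritten, so configs saved by transformers keep resolving to their own classes.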
|
|||||||
@@ -285,6 +285,8 @@ class ImageProcessingMixin(PushToHubMixin):
|
|||||||
subfolder (`str`, *optional*, defaults to `""`):
|
subfolder (`str`, *optional*, defaults to `""`):
|
||||||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||||
specify the folder name here.
|
specify the folder name here.
|
||||||
|
image_processor_filename (`str`, *optional*, defaults to `"config.json"`):
|
||||||
|
The name of the file in the model directory to use for the image processor config.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
|
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
|
||||||
@@ -298,6 +300,7 @@ class ImageProcessingMixin(PushToHubMixin):
|
|||||||
local_files_only = kwargs.pop("local_files_only", False)
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
subfolder = kwargs.pop("subfolder", "")
|
subfolder = kwargs.pop("subfolder", "")
|
||||||
|
image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME)
|
||||||
|
|
||||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||||
from_auto_class = kwargs.pop("_from_auto", False)
|
from_auto_class = kwargs.pop("_from_auto", False)
|
||||||
@@ -324,7 +327,7 @@ class ImageProcessingMixin(PushToHubMixin):
|
|||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
image_processor_file = os.path.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)
|
image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename)
|
||||||
if os.path.isfile(pretrained_model_name_or_path):
|
if os.path.isfile(pretrained_model_name_or_path):
|
||||||
resolved_image_processor_file = pretrained_model_name_or_path
|
resolved_image_processor_file = pretrained_model_name_or_path
|
||||||
is_local = True
|
is_local = True
|
||||||
@@ -332,7 +335,7 @@ class ImageProcessingMixin(PushToHubMixin):
|
|||||||
image_processor_file = pretrained_model_name_or_path
|
image_processor_file = pretrained_model_name_or_path
|
||||||
resolved_image_processor_file = download_url(pretrained_model_name_or_path)
|
resolved_image_processor_file = download_url(pretrained_model_name_or_path)
|
||||||
else:
|
else:
|
||||||
image_processor_file = IMAGE_PROCESSOR_NAME
|
image_processor_file = image_processor_filename
|
||||||
try:
|
try:
|
||||||
# Load from local folder or from cache or download from model Hub and cache
|
# Load from local folder or from cache or download from model Hub and cache
|
||||||
resolved_image_processor_file = cached_file(
|
resolved_image_processor_file = cached_file(
|
||||||
@@ -358,7 +361,7 @@ class ImageProcessingMixin(PushToHubMixin):
|
|||||||
f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
|
f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
|
||||||
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||||
f" directory containing a {IMAGE_PROCESSOR_NAME} file"
|
f" directory containing a {image_processor_filename} file"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -503,7 +503,7 @@ def load_state_dict(
|
|||||||
# Check format of the archive
|
# Check format of the archive
|
||||||
with safe_open(checkpoint_file, framework="pt") as f:
|
with safe_open(checkpoint_file, framework="pt") as f:
|
||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
|
if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
|
||||||
raise OSError(
|
raise OSError(
|
||||||
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
|
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
|
||||||
"you save your model with the `save_pretrained` method."
|
"you save your model with the `save_pretrained` method."
|
||||||
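The guard added above matters because a safetensors file written without metadata yields `None` from `f.metadata()`, and calling `.get` on `None` raises `AttributeError`. A minimal sketch of the same check on plain values, with `has_invalid_format` as a hypothetical helper name:

```python
VALID_FORMATS = ["pt", "tf", "flax", "mlx"]


def has_invalid_format(metadata):
    """Mirror the guard in the diff: only reject when metadata exists and names an unknown format."""
    return metadata is not None and metadata.get("format") not in VALID_FORMATS


print(has_invalid_format(None))                # no metadata at all: accepted
print(has_invalid_format({"format": "pt"}))    # known format: accepted
print(has_invalid_format({"format": "onnx"}))  # unknown format: rejected
```

Note that an empty metadata dict is still rejected, since `{}.get("format")` is `None`, which is not in the list of valid formats.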
@@ -652,36 +652,6 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
|
|||||||
|
|
||||||
|
|
||||||
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
|
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
|
||||||
# Convert old format to new format if needed from a PyTorch state_dict
|
|
||||||
old_keys = []
|
|
||||||
new_keys = []
|
|
||||||
renamed_keys = {}
|
|
||||||
renamed_gamma = {}
|
|
||||||
renamed_beta = {}
|
|
||||||
warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
|
|
||||||
for key in state_dict.keys():
|
|
||||||
new_key = None
|
|
||||||
if "gamma" in key:
|
|
||||||
# We add only the first key as an example
|
|
||||||
new_key = key.replace("gamma", "weight")
|
|
||||||
renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
|
|
||||||
if "beta" in key:
|
|
||||||
# We add only the first key as an example
|
|
||||||
new_key = key.replace("beta", "bias")
|
|
||||||
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
|
|
||||||
if new_key:
|
|
||||||
old_keys.append(key)
|
|
||||||
new_keys.append(new_key)
|
|
||||||
renamed_keys = {**renamed_gamma, **renamed_beta}
|
|
||||||
if renamed_keys:
|
|
||||||
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
|
|
||||||
for old_key, new_key in renamed_keys.items():
|
|
||||||
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
|
|
||||||
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
|
|
||||||
logger.info_once(warning_msg)
|
|
||||||
for old_key, new_key in zip(old_keys, new_keys):
|
|
||||||
state_dict[new_key] = state_dict.pop(old_key)
|
|
||||||
|
|
||||||
# copy state_dict so _load_from_state_dict can modify it
|
# copy state_dict so _load_from_state_dict can modify it
|
||||||
metadata = getattr(state_dict, "_metadata", None)
|
metadata = getattr(state_dict, "_metadata", None)
|
||||||
state_dict = state_dict.copy()
|
state_dict = state_dict.copy()
|
||||||
@@ -812,46 +782,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
|
|
||||||
error_msgs = []
|
error_msgs = []
|
||||||
|
|
||||||
old_keys = []
|
|
||||||
new_keys = []
|
|
||||||
renamed_gamma = {}
|
|
||||||
renamed_beta = {}
|
|
||||||
is_quantized = hf_quantizer is not None
|
is_quantized = hf_quantizer is not None
|
||||||
warning_msg = f"This model {type(model)}"
|
|
||||||
for key in state_dict.keys():
|
|
||||||
new_key = None
|
|
||||||
if "gamma" in key:
|
|
||||||
# We add only the first key as an example
|
|
||||||
new_key = key.replace("gamma", "weight")
|
|
||||||
renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
|
|
||||||
if "beta" in key:
|
|
||||||
# We add only the first key as an example
|
|
||||||
new_key = key.replace("beta", "bias")
|
|
||||||
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
|
|
||||||
|
|
||||||
# To reproduce `_load_state_dict_into_model` behaviour, we need to manually rename parametrized weight norm, if necessary.
|
|
||||||
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
|
||||||
if "weight_g" in key:
|
|
||||||
new_key = key.replace("weight_g", "parametrizations.weight.original0")
|
|
||||||
if "weight_v" in key:
|
|
||||||
new_key = key.replace("weight_v", "parametrizations.weight.original1")
|
|
||||||
else:
|
|
||||||
if "parametrizations.weight.original0" in key:
|
|
||||||
new_key = key.replace("parametrizations.weight.original0", "weight_g")
|
|
||||||
if "parametrizations.weight.original1" in key:
|
|
||||||
new_key = key.replace("parametrizations.weight.original1", "weight_v")
|
|
||||||
if new_key:
|
|
||||||
old_keys.append(key)
|
|
||||||
new_keys.append(new_key)
|
|
||||||
renamed_keys = {**renamed_gamma, **renamed_beta}
|
|
||||||
if renamed_keys:
|
|
||||||
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
|
|
||||||
for old_key, new_key in renamed_keys.items():
|
|
||||||
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
|
|
||||||
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
|
|
||||||
logger.info_once(warning_msg)
|
|
||||||
for old_key, new_key in zip(old_keys, new_keys):
|
|
||||||
state_dict[new_key] = state_dict.pop(old_key)
|
|
||||||
|
|
||||||
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
||||||
|
|
||||||
@@ -2888,6 +2819,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
for ignore_key in self._keys_to_ignore_on_save:
|
for ignore_key in self._keys_to_ignore_on_save:
|
||||||
if ignore_key in state_dict.keys():
|
if ignore_key in state_dict.keys():
|
||||||
del state_dict[ignore_key]
|
del state_dict[ignore_key]
|
||||||
|
|
||||||
|
# Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
|
||||||
|
# (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
|
||||||
|
state_dict = self._fix_state_dict_keys_on_save(state_dict)
|
||||||
|
|
||||||
if safe_serialization:
|
if safe_serialization:
|
||||||
# Safetensors does not allow tensor aliasing.
|
# Safetensors does not allow tensor aliasing.
|
||||||
# We're going to remove aliases before saving
|
# We're going to remove aliases before saving
|
||||||
@@ -4010,7 +3946,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
with safe_open(resolved_archive_file, framework="pt") as f:
|
with safe_open(resolved_archive_file, framework="pt") as f:
|
||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
|
|
||||||
if metadata.get("format") == "pt":
|
if metadata is None:
|
||||||
|
# Assume it's a pytorch checkpoint (introduced for timm checkpoints)
|
||||||
|
pass
|
||||||
|
elif metadata.get("format") == "pt":
|
||||||
pass
|
pass
|
||||||
elif metadata.get("format") == "tf":
|
elif metadata.get("format") == "tf":
|
||||||
from_tf = True
|
from_tf = True
|
||||||
@@ -4375,6 +4314,72 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fix_state_dict_key_on_load(key):
|
||||||
|
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
|
||||||
|
|
||||||
|
if "beta" in key:
|
||||||
|
return key.replace("beta", "bias")
|
||||||
|
if "gamma" in key:
|
||||||
|
return key.replace("gamma", "weight")
|
||||||
|
|
||||||
|
# to avoid logging parametrized weight norm renaming
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
if "weight_g" in key:
|
||||||
|
return key.replace("weight_g", "parametrizations.weight.original0")
|
||||||
|
if "weight_v" in key:
|
||||||
|
return key.replace("weight_v", "parametrizations.weight.original1")
|
||||||
|
else:
|
||||||
|
if "parametrizations.weight.original0" in key:
|
||||||
|
return key.replace("parametrizations.weight.original0", "weight_g")
|
||||||
|
if "parametrizations.weight.original1" in key:
|
||||||
|
return key.replace("parametrizations.weight.original1", "weight_v")
|
||||||
|
return key
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _fix_state_dict_keys_on_load(cls, state_dict):
|
||||||
|
"""Fixes state dict keys by replacing legacy parameter names with their modern equivalents.
|
||||||
|
Logs if any parameters have been renamed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
renamed_keys = {}
|
||||||
|
state_dict_keys = list(state_dict.keys())
|
||||||
|
for key in state_dict_keys:
|
||||||
|
new_key = cls._fix_state_dict_key_on_load(key)
|
||||||
|
if new_key != key:
|
||||||
|
state_dict[new_key] = state_dict.pop(key)
|
||||||
|
|
||||||
|
# add it once for logging
|
||||||
|
if "gamma" in key and "gamma" not in renamed_keys:
|
||||||
|
renamed_keys["gamma"] = (key, new_key)
|
||||||
|
if "beta" in key and "beta" not in renamed_keys:
|
||||||
|
renamed_keys["beta"] = (key, new_key)
|
||||||
|
|
||||||
|
if renamed_keys:
|
||||||
|
warning_msg = f"A pretrained model of type `{cls.__name__}` "
|
||||||
|
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
|
||||||
|
for old_key, new_key in renamed_keys.values():
|
||||||
|
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
|
||||||
|
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
|
||||||
|
logger.info_once(warning_msg)
|
||||||
|
|
||||||
|
return state_dict
|
||||||
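The renaming logic of `_fix_state_dict_key_on_load` and `_fix_state_dict_keys_on_load` can be tried on a plain dict. The sketch below is a simplified, standalone version: it omits the weight-norm branch (which depends on `torch`) and returns the logged examples instead of calling the logger, and the function names without the leading underscore are ours.

```python
def fix_key_on_load(key):
    # Legacy names -> modern names, checked in the same order as in the diff.
    if "beta" in key:
        return key.replace("beta", "bias")
    if "gamma" in key:
        return key.replace("gamma", "weight")
    return key


def fix_keys_on_load(state_dict):
    """Rename keys in place; keep one example per legacy name for the log message."""
    renamed_keys = {}
    for key in list(state_dict.keys()):  # snapshot, since the dict is mutated below
        new_key = fix_key_on_load(key)
        if new_key != key:
            state_dict[new_key] = state_dict.pop(key)
            if "gamma" in key and "gamma" not in renamed_keys:
                renamed_keys["gamma"] = (key, new_key)
            if "beta" in key and "beta" not in renamed_keys:
                renamed_keys["beta"] = (key, new_key)
    return state_dict, renamed_keys


# Integers stand in for tensors; only the keys matter here.
sd = {"layer_norm.gamma": 1, "layer_norm.beta": 2, "dense.weight": 3, "ln2.gamma": 4}
fixed, examples = fix_keys_on_load(sd)
print(sorted(fixed))
print(examples)
```

Only the first `gamma` and the first `beta` rename are recorded, which is why the message in the diff says that more renamed parameters may be present in the model.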
|
|
||||||
|
@staticmethod
|
||||||
|
def _fix_state_dict_key_on_save(key):
|
||||||
|
"""
|
||||||
|
Similar to `_fix_state_dict_key_on_load`, this hook allows renaming a state dict key on model save.
Does nothing by default, but can be overridden in particular models.
|
||||||
|
"""
|
||||||
|
return key
|
||||||
|
|
||||||
|
def _fix_state_dict_keys_on_save(self, state_dict):
|
||||||
|
"""
|
||||||
|
Similar to `_fix_state_dict_keys_on_load`, this hook allows renaming state dict keys on model save.
Applies `_fix_state_dict_key_on_save` to all keys in `state_dict`.
|
||||||
|
"""
|
||||||
|
return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()}
|
||||||
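As an illustration of how a model might use the save hook, the sketch below strips a `timm_model.` prefix, which the commit message names as the motivating case. `BaseModel` and `PrefixedWrapper` are hypothetical classes written for this example; they are not the transformers implementation.

```python
class BaseModel:
    @staticmethod
    def _fix_state_dict_key_on_save(key):
        # Default hook: keys are saved unchanged.
        return key

    def _fix_state_dict_keys_on_save(self, state_dict):
        # Apply the per-key hook to every entry.
        return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()}


class PrefixedWrapper(BaseModel):
    @staticmethod
    def _fix_state_dict_key_on_save(key):
        # Drop the wrapper prefix so the saved keys match the inner model's names.
        prefix = "timm_model."
        return key[len(prefix):] if key.startswith(prefix) else key


# Integers stand in for tensors; only the keys matter here.
state_dict = {"timm_model.conv1.weight": 1, "timm_model.fc.bias": 2}
print(BaseModel()._fix_state_dict_keys_on_save(state_dict))
print(PrefixedWrapper()._fix_state_dict_keys_on_save(state_dict))
```

Keeping the per-key hook a `staticmethod` lets a subclass override a single function while the dict-level method stays shared.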
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_pretrained_model(
|
def _load_pretrained_model(
|
||||||
cls,
|
cls,
|
||||||
@@ -4430,27 +4435,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if hf_quantizer is not None:
|
if hf_quantizer is not None:
|
||||||
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
|
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
|
||||||
|
|
||||||
def _fix_key(key):
|
|
||||||
if "beta" in key:
|
|
||||||
return key.replace("beta", "bias")
|
|
||||||
if "gamma" in key:
|
|
||||||
return key.replace("gamma", "weight")
|
|
||||||
|
|
||||||
# to avoid logging parametrized weight norm renaming
|
|
||||||
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
|
||||||
if "weight_g" in key:
|
|
||||||
return key.replace("weight_g", "parametrizations.weight.original0")
|
|
||||||
if "weight_v" in key:
|
|
||||||
return key.replace("weight_v", "parametrizations.weight.original1")
|
|
||||||
else:
|
|
||||||
if "parametrizations.weight.original0" in key:
|
|
||||||
return key.replace("parametrizations.weight.original0", "weight_g")
|
|
||||||
if "parametrizations.weight.original1" in key:
|
|
||||||
return key.replace("parametrizations.weight.original1", "weight_v")
|
|
||||||
return key
|
|
||||||
|
|
||||||
original_loaded_keys = loaded_keys
|
original_loaded_keys = loaded_keys
|
||||||
loaded_keys = [_fix_key(key) for key in loaded_keys]
|
loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys]
|
||||||
|
|
||||||
if len(prefix) > 0:
|
if len(prefix) > 0:
|
||||||
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
|
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
|
||||||
@@ -4615,23 +4601,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
state_dict,
|
state_dict,
|
||||||
model_state_dict,
|
model_state_dict,
|
||||||
loaded_keys,
|
loaded_keys,
|
||||||
|
original_loaded_keys,
|
||||||
add_prefix_to_model,
|
add_prefix_to_model,
|
||||||
remove_prefix_from_model,
|
remove_prefix_from_model,
|
||||||
ignore_mismatched_sizes,
|
ignore_mismatched_sizes,
|
||||||
):
|
):
|
||||||
mismatched_keys = []
|
mismatched_keys = []
|
||||||
if ignore_mismatched_sizes:
|
if ignore_mismatched_sizes:
|
||||||
for checkpoint_key in loaded_keys:
|
for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys):
|
||||||
# If the checkpoint is sharded, we may not have the key here.
|
# If the checkpoint is sharded, we may not have the key here.
|
||||||
if checkpoint_key not in state_dict:
|
if checkpoint_key not in state_dict:
|
||||||
continue
|
continue
|
||||||
model_key = checkpoint_key
|
|
||||||
if remove_prefix_from_model:
|
if remove_prefix_from_model:
|
||||||
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
|
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
|
||||||
model_key = f"{prefix}.{checkpoint_key}"
|
model_key = f"{prefix}.{model_key}"
|
||||||
elif add_prefix_to_model:
|
elif add_prefix_to_model:
|
||||||
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
|
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
|
||||||
model_key = ".".join(checkpoint_key.split(".")[1:])
|
model_key = ".".join(model_key.split(".")[1:])
|
||||||
|
|
||||||
if (
|
if (
|
||||||
model_key in model_state_dict
|
model_key in model_state_dict
|
||||||
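Review note: the loop above now walks the original and renamed key lists in lockstep. The lookup into `state_dict` uses the checkpoint's own key, while the lookup into `model_state_dict` uses the renamed one. The sketch below illustrates that pairing under simplifying assumptions: plain dicts of shapes stand in for tensors, the prefix handling is left out, and `find_shape_mismatches` is a hypothetical name that does not appear in this diff.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    checkpoint_shapes: Dict[str, Shape],
    model_shapes: Dict[str, Shape],
    original_keys: List[str],
    renamed_keys: List[str],
) -> List[Tuple[str, Shape, Shape]]:
    """Pair each checkpoint key with its renamed model key and compare shapes."""
    mismatched = []
    for checkpoint_key, model_key in zip(original_keys, renamed_keys):
        # With a sharded checkpoint, this shard may not hold the key.
        if checkpoint_key not in checkpoint_shapes:
            continue
        if model_key in model_shapes and checkpoint_shapes[checkpoint_key] != model_shapes[model_key]:
            mismatched.append((checkpoint_key, checkpoint_shapes[checkpoint_key], model_shapes[model_key]))
    return mismatched
```

Looking up the checkpoint with a renamed key would skip every renamed entry as "not in this shard", which appears to be the case the `zip` is meant to avoid.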
@@ -4680,6 +4666,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
mismatched_keys = _find_mismatched_keys(
|
mismatched_keys = _find_mismatched_keys(
|
||||||
state_dict,
|
state_dict,
|
||||||
model_state_dict,
|
model_state_dict,
|
||||||
|
loaded_keys,
|
||||||
original_loaded_keys,
|
original_loaded_keys,
|
||||||
add_prefix_to_model,
|
add_prefix_to_model,
|
||||||
remove_prefix_from_model,
|
remove_prefix_from_model,
|
||||||
@@ -4688,9 +4675,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
# For GGUF models `state_dict` is never set to None as the state dict is always small
|
# For GGUF models `state_dict` is never set to None as the state dict is always small
|
||||||
if gguf_path:
|
if gguf_path:
|
||||||
|
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||||
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||||
model_to_load,
|
model_to_load,
|
||||||
state_dict,
|
fixed_state_dict,
|
||||||
start_prefix,
|
start_prefix,
|
||||||
expected_keys,
|
expected_keys,
|
||||||
device_map=device_map,
|
device_map=device_map,
|
||||||
@@ -4709,8 +4697,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
assign_to_params_buffers = check_support_param_buffer_assignment(
|
assign_to_params_buffers = check_support_param_buffer_assignment(
|
||||||
model_to_load, state_dict, start_prefix
|
model_to_load, state_dict, start_prefix
|
||||||
)
|
)
|
||||||
|
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||||
error_msgs = _load_state_dict_into_model(
|
error_msgs = _load_state_dict_into_model(
|
||||||
model_to_load, state_dict, start_prefix, assign_to_params_buffers
|
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -4761,6 +4750,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
mismatched_keys += _find_mismatched_keys(
|
mismatched_keys += _find_mismatched_keys(
|
||||||
state_dict,
|
state_dict,
|
||||||
model_state_dict,
|
model_state_dict,
|
||||||
|
loaded_keys,
|
||||||
original_loaded_keys,
|
original_loaded_keys,
|
||||||
add_prefix_to_model,
|
add_prefix_to_model,
|
||||||
remove_prefix_from_model,
|
remove_prefix_from_model,
|
||||||
@@ -4774,9 +4764,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
|
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||||
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||||
model_to_load,
|
model_to_load,
|
||||||
state_dict,
|
fixed_state_dict,
|
||||||
start_prefix,
|
start_prefix,
|
||||||
expected_keys,
|
expected_keys,
|
||||||
device_map=device_map,
|
device_map=device_map,
|
||||||
@@ -4797,8 +4788,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
assign_to_params_buffers = check_support_param_buffer_assignment(
|
assign_to_params_buffers = check_support_param_buffer_assignment(
|
||||||
model_to_load, state_dict, start_prefix
|
model_to_load, state_dict, start_prefix
|
||||||
)
|
)
|
||||||
|
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||||
error_msgs += _load_state_dict_into_model(
|
error_msgs += _load_state_dict_into_model(
|
||||||
model_to_load, state_dict, start_prefix, assign_to_params_buffers
|
model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
|
||||||
)
|
)
|
||||||
|
|
||||||
# force memory release
|
# force memory release
|
||||||
@@ -4930,9 +4922,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
_move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
|
_move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
|
||||||
state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
|
state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
|
||||||
expected_keys = loaded_state_dict_keys # plug for missing expected_keys. TODO: replace with proper keys
|
expected_keys = loaded_state_dict_keys # plug for missing expected_keys. TODO: replace with proper keys
|
||||||
|
fixed_state_dict = model._fix_state_dict_keys_on_load(state_dict)
|
||||||
error_msgs = _load_state_dict_into_meta_model(
|
error_msgs = _load_state_dict_into_meta_model(
|
||||||
model,
|
model,
|
||||||
state_dict,
|
fixed_state_dict,
|
||||||
start_prefix,
|
start_prefix,
|
||||||
expected_keys=expected_keys,
|
expected_keys=expected_keys,
|
||||||
hf_quantizer=hf_quantizer,
|
hf_quantizer=hf_quantizer,
|
||||||
|
|||||||
@@ -249,6 +249,7 @@ from . import (
|
|||||||
time_series_transformer,
|
time_series_transformer,
|
||||||
timesformer,
|
timesformer,
|
||||||
timm_backbone,
|
timm_backbone,
|
||||||
|
timm_wrapper,
|
||||||
trocr,
|
trocr,
|
||||||
tvp,
|
tvp,
|
||||||
udop,
|
udop,
|
||||||
|
|||||||
@@ -276,6 +276,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
|||||||
("time_series_transformer", "TimeSeriesTransformerConfig"),
|
("time_series_transformer", "TimeSeriesTransformerConfig"),
|
||||||
("timesformer", "TimesformerConfig"),
|
("timesformer", "TimesformerConfig"),
|
||||||
("timm_backbone", "TimmBackboneConfig"),
|
("timm_backbone", "TimmBackboneConfig"),
|
||||||
|
("timm_wrapper", "TimmWrapperConfig"),
|
||||||
("trajectory_transformer", "TrajectoryTransformerConfig"),
|
("trajectory_transformer", "TrajectoryTransformerConfig"),
|
||||||
("transfo-xl", "TransfoXLConfig"),
|
("transfo-xl", "TransfoXLConfig"),
|
||||||
("trocr", "TrOCRConfig"),
|
("trocr", "TrOCRConfig"),
|
||||||
@@ -599,6 +600,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||||||
("time_series_transformer", "Time Series Transformer"),
|
("time_series_transformer", "Time Series Transformer"),
|
||||||
("timesformer", "TimeSformer"),
|
("timesformer", "TimeSformer"),
|
||||||
("timm_backbone", "TimmBackbone"),
|
("timm_backbone", "TimmBackbone"),
|
||||||
|
("timm_wrapper", "TimmWrapperModel"),
|
||||||
("trajectory_transformer", "Trajectory Transformer"),
|
("trajectory_transformer", "Trajectory Transformer"),
|
||||||
("transfo-xl", "Transformer-XL"),
|
("transfo-xl", "Transformer-XL"),
|
||||||
("trocr", "TrOCR"),
|
("trocr", "TrOCR"),
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ from ...utils import (
|
|||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
IMAGE_PROCESSOR_NAME,
|
IMAGE_PROCESSOR_NAME,
|
||||||
get_file_from_repo,
|
get_file_from_repo,
|
||||||
|
is_timm_config_dict,
|
||||||
|
is_timm_local_checkpoint,
|
||||||
is_torchvision_available,
|
is_torchvision_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
logging,
|
logging,
|
||||||
@@ -137,6 +139,7 @@ else:
|
|||||||
("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||||
("table-transformer", ("DetrImageProcessor",)),
|
("table-transformer", ("DetrImageProcessor",)),
|
||||||
("timesformer", ("VideoMAEImageProcessor",)),
|
("timesformer", ("VideoMAEImageProcessor",)),
|
||||||
|
("timm_wrapper", ("TimmWrapperImageProcessor",)),
|
||||||
("tvlt", ("TvltImageProcessor",)),
|
("tvlt", ("TvltImageProcessor",)),
|
||||||
("tvp", ("TvpImageProcessor",)),
|
("tvp", ("TvpImageProcessor",)),
|
||||||
("udop", ("LayoutLMv3ImageProcessor",)),
|
("udop", ("LayoutLMv3ImageProcessor",)),
|
||||||
@@ -376,6 +379,8 @@ class AutoImageProcessor:
|
|||||||
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
||||||
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
||||||
execute code present on the Hub on your local machine.
|
execute code present on the Hub on your local machine.
|
||||||
|
image_processor_filename (`str`, *optional*, defaults to `"config.json"`):
|
||||||
|
The name of the file in the model directory to use for the image processor config.
|
||||||
kwargs (`Dict[str, Any]`, *optional*):
|
kwargs (`Dict[str, Any]`, *optional*):
|
||||||
The values in kwargs of any keys which are image processor attributes will be used to override the
|
The values in kwargs of any keys which are image processor attributes will be used to override the
|
||||||
loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
|
loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
|
||||||
@@ -415,7 +420,37 @@ class AutoImageProcessor:
|
|||||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||||
kwargs["_from_auto"] = True
|
kwargs["_from_auto"] = True
|
||||||
|
|
||||||
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
|
# Resolve the image processor config filename
|
||||||
|
if "image_processor_filename" in kwargs:
|
||||||
|
image_processor_filename = kwargs.pop("image_processor_filename")
|
||||||
|
elif is_timm_local_checkpoint(pretrained_model_name_or_path):
|
||||||
|
image_processor_filename = CONFIG_NAME
|
||||||
|
else:
|
||||||
|
image_processor_filename = IMAGE_PROCESSOR_NAME
|
||||||
|
|
||||||
|
# Load the image processor config
|
||||||
|
try:
|
||||||
|
# Main path for all transformers models and local TimmWrapper checkpoints
|
||||||
|
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
|
||||||
|
pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
|
||||||
|
)
|
||||||
|
except Exception as initial_exception:
|
||||||
|
# Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
|
||||||
|
# instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
|
||||||
|
# except the model name, the only way to check if a remote checkpoint is a timm model is to try to
|
||||||
|
# load `config.json` and if it fails with some error, we raise the initial exception.
|
||||||
|
try:
|
||||||
|
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
|
||||||
|
pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
raise initial_exception
|
||||||
|
|
||||||
|
# In case we have a config_dict, but it's not a timm config dict, we raise the initial exception,
|
||||||
|
# because only timm models have image processing in `config.json`.
|
||||||
|
if not is_timm_config_dict(config_dict):
|
||||||
|
raise initial_exception
|
||||||
|
|
||||||
image_processor_class = config_dict.get("image_processor_type", None)
|
image_processor_class = config_dict.get("image_processor_type", None)
|
||||||
image_processor_auto_map = None
|
image_processor_auto_map = None
|
||||||
if "AutoImageProcessor" in config_dict.get("auto_map", {}):
|
if "AutoImageProcessor" in config_dict.get("auto_map", {}):
|
||||||
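Review note: the hunk above resolves the config filename, tries the main path, and only then falls back to `config.json`, re-raising the first error if the fallback fails or returns a non-timm dict. The sketch below mirrors that control flow without touching the Hub. `load_processor_dict`, `read_json`, and the `is_timm_dict` predicate are hypothetical stand-ins, not the helpers imported in this diff.

```python
from typing import Any, Callable, Dict, Optional

CONFIG_NAME = "config.json"
IMAGE_PROCESSOR_NAME = "preprocessor_config.json"


def load_processor_dict(
    read_json: Callable[[str], Dict[str, Any]],
    is_local_timm: bool,
    is_timm_dict: Callable[[Dict[str, Any]], bool],
    filename: Optional[str] = None,
) -> Dict[str, Any]:
    """Try the resolved filename first, then fall back to `config.json`."""
    if filename is None:
        filename = CONFIG_NAME if is_local_timm else IMAGE_PROCESSOR_NAME
    try:
        return read_json(filename)
    except Exception as initial_exception:
        try:
            config_dict = read_json(CONFIG_NAME)
        except Exception:
            # The fallback failed too: surface the original error.
            raise initial_exception
        if not is_timm_dict(config_dict):
            # A non-timm `config.json` holds no image processing settings.
            raise initial_exception
        return config_dict
```

Re-raising the initial exception keeps the error message pointed at the file the user would expect to be missing, rather than at the speculative fallback.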
|
|||||||
@@ -255,6 +255,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("time_series_transformer", "TimeSeriesTransformerModel"),
|
("time_series_transformer", "TimeSeriesTransformerModel"),
|
||||||
("timesformer", "TimesformerModel"),
|
("timesformer", "TimesformerModel"),
|
||||||
("timm_backbone", "TimmBackbone"),
|
("timm_backbone", "TimmBackbone"),
|
||||||
|
("timm_wrapper", "TimmWrapperModel"),
|
||||||
("trajectory_transformer", "TrajectoryTransformerModel"),
|
("trajectory_transformer", "TrajectoryTransformerModel"),
|
||||||
("transfo-xl", "TransfoXLModel"),
|
("transfo-xl", "TransfoXLModel"),
|
||||||
("tvlt", "TvltModel"),
|
("tvlt", "TvltModel"),
|
||||||
@@ -605,6 +606,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
|
|||||||
("table-transformer", "TableTransformerModel"),
|
("table-transformer", "TableTransformerModel"),
|
||||||
("timesformer", "TimesformerModel"),
|
("timesformer", "TimesformerModel"),
|
||||||
("timm_backbone", "TimmBackbone"),
|
("timm_backbone", "TimmBackbone"),
|
||||||
|
("timm_wrapper", "TimmWrapperModel"),
|
||||||
("van", "VanModel"),
|
("van", "VanModel"),
|
||||||
("videomae", "VideoMAEModel"),
|
("videomae", "VideoMAEModel"),
|
||||||
("vit", "ViTModel"),
|
("vit", "ViTModel"),
|
||||||
@@ -690,6 +692,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
("swiftformer", "SwiftFormerForImageClassification"),
|
("swiftformer", "SwiftFormerForImageClassification"),
|
||||||
("swin", "SwinForImageClassification"),
|
("swin", "SwinForImageClassification"),
|
||||||
("swinv2", "Swinv2ForImageClassification"),
|
("swinv2", "Swinv2ForImageClassification"),
|
||||||
|
("timm_wrapper", "TimmWrapperForImageClassification"),
|
||||||
("van", "VanForImageClassification"),
|
("van", "VanForImageClassification"),
|
||||||
("vit", "ViTForImageClassification"),
|
("vit", "ViTForImageClassification"),
|
||||||
("vit_hybrid", "ViTHybridForImageClassification"),
|
("vit_hybrid", "ViTHybridForImageClassification"),
|
||||||
|
|||||||
28
src/transformers/models/timm_wrapper/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import _LazyModule
|
||||||
|
from ...utils.import_utils import define_import_structure
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .configuration_timm_wrapper import *
|
||||||
|
from .modeling_timm_wrapper import *
|
||||||
|
from .processing_timm_wrapper import *
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
_file = globals()["__file__"]
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Configuration for TimmWrapper models"""
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TimmWrapperConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration for a timm model wrapped by [`TimmWrapperModel`].
|
||||||
|
|
||||||
|
It is used to instantiate a timm model according to the specified arguments, defining the model.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
do_pooling (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
>>> from transformers import TimmWrapperModel
|
||||||
|
|
||||||
|
>>> # Initializing a timm model
|
||||||
|
>>> model = TimmWrapperModel.from_pretrained("timm/resnet18.a1_in1k")
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type = "timm_wrapper"
|
||||||
|
|
||||||
|
def __init__(self, initializer_range: float = 0.02, do_pooling: bool = True, **kwargs):
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.do_pooling = do_pooling
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
|
||||||
|
# timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
|
||||||
|
# We are removing these attributes in order to have the native `transformers` num_labels attribute in config
|
||||||
|
# and to avoid duplicate attributes
|
||||||
|
|
||||||
|
num_labels_in_kwargs = kwargs.pop("num_labels", None)
|
||||||
|
num_labels_in_dict = config_dict.pop("num_classes", None)
|
||||||
|
|
||||||
|
# passed num_labels has priority over num_classes in config_dict
|
||||||
|
kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict
|
||||||
|
|
||||||
|
# pop num_classes from "pretrained_cfg",
|
||||||
|
# it is not needed there; timm only uses the root-level one
|
||||||
|
if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]:
|
||||||
|
config_dict["pretrained_cfg"].pop("num_classes", None)
|
||||||
|
|
||||||
|
return super().from_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
output = super().to_dict()
|
||||||
|
output["num_classes"] = self.num_labels
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["TimmWrapperConfig"]
|
||||||
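Review note: `from_dict` above reconciles timm's `num_classes` with the `num_labels` attribute used in `transformers`. The sketch below reproduces that precedence rule on plain dicts. `resolve_num_labels` is a hypothetical helper, and unlike the method in the diff it copies its input instead of mutating it.

```python
from typing import Any, Dict, Optional, Tuple


def resolve_num_labels(
    config_dict: Dict[str, Any], num_labels: Optional[int] = None
) -> Tuple[Dict[str, Any], Optional[int]]:
    """Return a cleaned config dict and the label count that would be used."""
    cleaned = dict(config_dict)  # shallow copy, the caller's dict is left alone
    num_classes = cleaned.pop("num_classes", None)

    # An explicitly passed `num_labels` wins over the root-level `num_classes`.
    resolved = num_labels or num_classes

    # Drop the nested duplicate so only one source of truth remains.
    pretrained_cfg = cleaned.get("pretrained_cfg")
    if isinstance(pretrained_cfg, dict):
        cleaned["pretrained_cfg"] = {k: v for k, v in pretrained_cfg.items() if k != "num_classes"}

    return cleaned, resolved
```

Because the rule uses `or`, a passed value of `0` is falsy and falls back to the `num_classes` stored in the dict.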
@@ -0,0 +1,138 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature
|
||||||
|
from ...image_transforms import to_pil_image
|
||||||
|
from ...image_utils import ImageInput, make_list_of_images
|
||||||
|
from ...utils import TensorType, logging, requires_backends
|
||||||
|
from ...utils.import_utils import is_timm_available, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_timm_available():
|
||||||
|
import timm
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TimmWrapperImageProcessor(BaseImageProcessor):
|
||||||
|
"""
|
||||||
|
Wrapper class for timm models to be used within transformers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained_cfg (`Dict[str, Any]`):
|
||||||
|
The configuration of the pretrained model used to resolve evaluation and
|
||||||
|
training transforms.
|
||||||
|
architecture (`Optional[str]`, *optional*):
|
||||||
|
Name of the architecture of the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
main_input_name = "pixel_values"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pretrained_cfg: Dict[str, Any],
|
||||||
|
architecture: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
requires_backends(self, "timm")
|
||||||
|
super().__init__(architecture=architecture)
|
||||||
|
|
||||||
|
self.data_config = timm.data.resolve_data_config(pretrained_cfg, model=None, verbose=False)
|
||||||
|
self.val_transforms = timm.data.create_transform(**self.data_config, is_training=False)
|
||||||
|
|
||||||
|
# useful for training, see examples/pytorch/image-classification/run_image_classification.py
|
||||||
|
self.train_transforms = timm.data.create_transform(**self.data_config, is_training=True)
|
||||||
|
|
||||||
|
# If `ToTensor` is in the transforms, then the input should be numpy array or PIL image.
|
||||||
|
# Otherwise, the input can be a tensor. In later timm versions, `MaybeToTensor` is used
|
||||||
|
# which can handle both numpy arrays / PIL images and tensors.
|
||||||
|
self._not_supports_tensor_input = any(
|
||||||
|
transform.__class__.__name__ == "ToTensor" for transform in self.val_transforms.transforms
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Serializes this instance to a Python dictionary.
|
||||||
|
"""
|
||||||
|
output = super().to_dict()
|
||||||
|
output.pop("train_transforms", None)
|
||||||
|
output.pop("val_transforms", None)
|
||||||
|
output.pop("_not_supports_tensor_input", None)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_image_processor_dict(
|
||||||
|
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||||
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get the image processor dict for the model.
|
||||||
|
"""
|
||||||
|
image_processor_filename = kwargs.pop("image_processor_filename", "config.json")
|
||||||
|
return super().get_image_processor_dict(
|
||||||
|
pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
images: ImageInput,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = "pt",
|
||||||
|
) -> BatchFeature:
|
||||||
|
"""
|
||||||
|
Preprocess an image or batch of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`ImageInput`):
|
||||||
|
Image to preprocess. Expects a single image or a batch of images.
|
||||||
|
return_tensors (`str` or `TensorType`, *optional*):
|
||||||
|
The type of tensors to return.
|
||||||
|
"""
|
||||||
|
if return_tensors != "pt":
|
||||||
|
raise ValueError(f"return_tensors for TimmWrapperImageProcessor must be 'pt', but got {return_tensors}")
|
||||||
|
|
||||||
|
if self._not_supports_tensor_input and isinstance(images, torch.Tensor):
|
||||||
|
images = images.cpu().numpy()
|
||||||
|
|
||||||
|
# If the input is a torch tensor, then no conversion is needed
|
||||||
|
# Otherwise, we need to pass in a list of PIL images
|
||||||
|
if isinstance(images, torch.Tensor):
|
||||||
|
images = self.val_transforms(images)
|
||||||
|
# Add batch dimension if a single image
|
||||||
|
images = images.unsqueeze(0) if images.ndim == 3 else images
|
||||||
|
else:
|
||||||
|
images = make_list_of_images(images)
|
||||||
|
images = [to_pil_image(image) for image in images]
|
||||||
|
images = torch.stack([self.val_transforms(image) for image in images])
|
||||||
|
|
||||||
|
return BatchFeature({"pixel_values": images}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def save_pretrained(self, *args, **kwargs):
|
||||||
|
# Disabled to keep the checkpoint layout identical to the one produced by the `timm` library.
|
||||||
|
logger.warning_once(
|
||||||
|
"The `save_pretrained` method is disabled for TimmWrapperImageProcessor. "
|
||||||
|
"The image processor configuration is saved directly in `config.json` when "
|
||||||
|
"`save_pretrained` is called for saving the model."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["TimmWrapperImageProcessor"]
|
||||||
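Review note: the image processor above decides whether tensors must be converted to arrays by checking transform class names rather than importing a transform class. The sketch below shows that check in isolation. The three classes are empty stand-ins defined here for illustration, not the real `timm` or `torchvision` transforms, and `rejects_tensor_input` is a hypothetical name.

```python
from typing import Any, List


class ToTensor:
    """Stand-in for a transform that only accepts arrays or PIL images."""


class MaybeToTensor:
    """Stand-in for a transform that also passes tensors through."""


class Compose:
    """Stand-in pipeline exposing a `transforms` list, which the check iterates."""

    def __init__(self, transforms: List[Any]) -> None:
        self.transforms = transforms


def rejects_tensor_input(pipeline: Compose) -> bool:
    # Exact match on the class name, so `MaybeToTensor` does not trigger it.
    return any(t.__class__.__name__ == "ToTensor" for t in pipeline.transforms)
```

Matching on the name keeps the check independent of which `timm` version is installed, at the cost of also matching any unrelated class that happens to be called `ToTensor`.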
363
src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
|
from ...modeling_outputs import ImageClassifierOutput, ModelOutput
|
||||||
|
from ...modeling_utils import PreTrainedModel
|
||||||
|
from ...utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
is_timm_available,
|
||||||
|
replace_return_docstrings,
|
||||||
|
requires_backends,
|
||||||
|
)
|
||||||
|
from .configuration_timm_wrapper import TimmWrapperConfig
|
||||||
|
|
||||||
|
|
||||||
|
if is_timm_available():
|
||||||
|
import timm
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TimmWrapperModelOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Output class for `TimmWrapperModel`, containing the last hidden state, an optional pooled output,
|
||||||
|
and optional hidden states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
last_hidden_state (`torch.FloatTensor`):
|
||||||
|
The last hidden state of the model, output before applying the classification head.
|
||||||
|
pooler_output (`torch.FloatTensor`, *optional*):
|
||||||
|
The pooled output derived from the last hidden state, if applicable.
|
||||||
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
|
||||||
|
A tuple containing the intermediate hidden states of the model at the output of each layer or specified layers.
|
||||||
|
Returned if `output_hidden_states=True` is set or if `config.output_hidden_states=True`.
|
||||||
|
attentions (`tuple(torch.FloatTensor)`, *optional*):
|
||||||
|
A tuple containing the intermediate attention weights of the model at the output of each layer.
|
||||||
|
Returned if `output_attentions=True` is set or if `config.output_attentions=True`.
|
||||||
|
Note: Currently, Timm models do not support attentions output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
last_hidden_state: torch.FloatTensor
|
||||||
|
pooler_output: Optional[torch.FloatTensor] = None
|
||||||
|
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||||
|
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||||
|
|
||||||
|
|
||||||
|
TIMM_WRAPPER_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||||
|
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`TimmWrapperImageProcessor.preprocess`]
|
||||||
|
for details.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models.
|
||||||
|
output_hidden_states (`bool`, *optional*):
|
||||||
|
Whether or not to return the hidden states of all layers. A list of layer indices can be passed instead to return only the specified layers.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
**kwargs:
|
||||||
|
Additional keyword arguments passed along to the `timm` model forward.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TimmWrapperPreTrainedModel(PreTrainedModel):
|
||||||
|
main_input_name = "pixel_values"
|
||||||
|
config_class = TimmWrapperConfig
|
||||||
|
_no_split_modules = []
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["vision", "timm"])
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fix_state_dict_key_on_load(key):
|
||||||
|
"""
|
||||||
|
Overrides original method that renames `gamma` and `beta` to `weight` and `bias`.
|
||||||
|
We don't want this behavior for timm wrapped models. Instead, this method adds a
|
||||||
|
"timm_model." prefix to enable loading official timm Hub checkpoints.
|
||||||
|
"""
|
||||||
|
if "timm_model." not in key:
|
||||||
|
return f"timm_model.{key}"
|
||||||
|
return key
|
||||||
|
|
||||||
|
def _fix_state_dict_key_on_save(self, key):
|
||||||
|
"""
|
||||||
|
Overrides original method to remove "timm_model." prefix from state_dict keys.
|
||||||
|
Makes the saved checkpoint compatible with the `timm` library.
|
||||||
|
"""
|
||||||
|
return key.replace("timm_model.", "")
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Override original method to fix state_dict keys on load for cases when weights are loaded
|
||||||
|
without using the `from_pretrained` method (e.g., in Trainer to resume from checkpoint).
|
||||||
|
"""
|
||||||
|
state_dict = self._fix_state_dict_keys_on_load(state_dict)
|
||||||
|
return super().load_state_dict(state_dict, *args, **kwargs)
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
"""
|
||||||
|
Initializes the weights of `nn.Linear` layers.
|
||||||
|
Since model architectures may vary, we assume only the classifier requires
|
||||||
|
initialization, while all other weights should be loaded from the checkpoint.
|
||||||
|
"""
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
|
||||||
|
|
||||||
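Review note: the two key hooks in `TimmWrapperPreTrainedModel` above are meant to be inverses, so that a checkpoint saved from the wrapper stays loadable by `timm` and vice versa. The sketch below copies their string logic into free functions to show the round trip. The function names are hypothetical, and the example keys are made up.

```python
PREFIX = "timm_model."


def fix_key_on_load(key: str) -> str:
    """Add the wrapper prefix unless the key already carries it."""
    if PREFIX not in key:
        return f"{PREFIX}{key}"
    return key


def fix_key_on_save(key: str) -> str:
    """Strip the wrapper prefix so the key matches timm's own naming."""
    return key.replace(PREFIX, "")


def round_trip(key: str) -> str:
    # A timm-style key should survive load followed by save unchanged.
    return fix_key_on_save(fix_key_on_load(key))
```

Both helpers test for the substring anywhere in the key, not only at the start, so a key that already contains `timm_model.` in the middle would be left unprefixed on load.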
|
class TimmWrapperModel(TimmWrapperPreTrainedModel):
|
||||||
|
"""
|
||||||
|
Wrapper class for timm models to be used in transformers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: TimmWrapperConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
# using num_classes=0 to avoid creating classification head
|
||||||
|
self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0)
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(TIMM_WRAPPER_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=TimmWrapperModelOutput, config_class=TimmWrapperConfig)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.FloatTensor,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[Union[bool, List[int]]] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
do_pooling: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[TimmWrapperModelOutput, Tuple[Tensor, ...]]:
|
||||||
|
r"""
|
||||||
|
do_pooling (`bool`, *optional*):
|
||||||
|
Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. If `None` is passed, the
|
||||||
|
`do_pooling` value from the config is used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> from urllib.request import urlopen
|
||||||
|
>>> from transformers import AutoModel, AutoImageProcessor
|
||||||
|
|
||||||
|
>>> # Load image
|
||||||
|
>>> image = Image.open(urlopen(
|
||||||
|
... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
|
||||||
|
... ))
|
||||||
|
|
||||||
|
>>> # Load model and image processor
|
||||||
|
>>> checkpoint = "timm/resnet50.a1_in1k"
|
||||||
|
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
|
||||||
|
>>> model = AutoModel.from_pretrained(checkpoint).eval()
|
||||||
|
|
||||||
|
>>> # Preprocess image
|
||||||
|
>>> inputs = image_processor(image)
|
||||||
|
|
||||||
|
>>> # Forward pass
|
||||||
|
>>> with torch.no_grad():
|
||||||
|
... outputs = model(**inputs)
|
||||||
|
|
||||||
|
>>> # Get pooled output
|
||||||
|
>>> pooled_output = outputs.pooler_output
|
||||||
|
|
||||||
|
>>> # Get last hidden state
|
||||||
|
>>> last_hidden_state = outputs.last_hidden_state
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
do_pooling = do_pooling if do_pooling is not None else self.config.do_pooling
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
raise ValueError("Cannot set `output_attentions` for timm models.")
|
||||||
|
|
||||||
|
if output_hidden_states and not hasattr(self.timm_model, "forward_intermediates"):
|
||||||
|
raise ValueError(
|
||||||
|
"The 'output_hidden_states' option cannot be set for this timm model. "
|
||||||
|
"To enable this feature, the 'forward_intermediates' method must be implemented "
|
||||||
|
"in the timm model (available in timm versions > 1.*). Please consider using a "
|
||||||
|
"different architecture or updating the timm package to a compatible version."
|
||||||
|
)
|
||||||
|
|
||||||
|
pixel_values = pixel_values.to(self.device, self.dtype)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
# to enable hidden states selection
|
||||||
|
if isinstance(output_hidden_states, (list, tuple)):
|
||||||
|
kwargs["indices"] = output_hidden_states
|
||||||
|
last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs)
|
||||||
|
else:
|
||||||
|
last_hidden_state = self.timm_model.forward_features(pixel_values, **kwargs)
|
||||||
|
hidden_states = None
|
||||||
|
|
||||||
|
if do_pooling:
|
||||||
|
# classification head is not created, applying pooling only
|
||||||
|
pooler_output = self.timm_model.forward_head(last_hidden_state)
|
||||||
|
else:
|
||||||
|
pooler_output = None
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
outputs = (last_hidden_state, pooler_output, hidden_states)
|
||||||
|
outputs = tuple(output for output in outputs if output is not None)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
return TimmWrapperModelOutput(
|
||||||
|
last_hidden_state=last_hidden_state,
|
||||||
|
pooler_output=pooler_output,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel):
|
||||||
|
"""
|
||||||
|
Wrapper class for timm models to be used in transformers for image classification.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: TimmWrapperConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
if config.num_labels == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"You are trying to load weights into `TimmWrapperForImageClassification` from a checkpoint with no classifier head. "
|
||||||
|
"Please specify the number of classes, e.g. `model = TimmWrapperForImageClassification.from_pretrained(..., num_labels=10)`, "
|
||||||
|
"or use `TimmWrapperModel` for feature extraction."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=config.num_labels)
|
||||||
|
self.num_labels = config.num_labels
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(TIMM_WRAPPER_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=ImageClassifierOutput, config_class=TimmWrapperConfig)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.FloatTensor,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[Union[bool, List[int]]] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[ImageClassifierOutput, Tuple[Tensor, ...]]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
|
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
||||||
|
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
|
||||||
|
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> from urllib.request import urlopen
|
||||||
|
>>> from transformers import AutoModelForImageClassification, AutoImageProcessor
|
||||||
|
|
||||||
|
>>> # Load image
|
||||||
|
>>> image = Image.open(urlopen(
|
||||||
|
... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
|
||||||
|
... ))
|
||||||
|
|
||||||
|
>>> # Load model and image processor
|
||||||
|
>>> checkpoint = "timm/resnet50.a1_in1k"
|
||||||
|
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
|
||||||
|
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
|
||||||
|
|
||||||
|
>>> # Preprocess image
|
||||||
|
>>> inputs = image_processor(image)
|
||||||
|
|
||||||
|
>>> # Forward pass
|
||||||
|
>>> with torch.no_grad():
|
||||||
|
... logits = model(**inputs).logits
|
||||||
|
|
||||||
|
>>> # Get top 5 predictions
|
||||||
|
>>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
raise ValueError("Cannot set `output_attentions` for timm models.")
|
||||||
|
|
||||||
|
if output_hidden_states and not hasattr(self.timm_model, "forward_intermediates"):
|
||||||
|
raise ValueError(
|
||||||
|
"The 'output_hidden_states' option cannot be set for this timm model. "
|
||||||
|
"To enable this feature, the 'forward_intermediates' method must be implemented "
|
||||||
|
"in the timm model (available in timm versions > 1.*). Please consider using a "
|
||||||
|
"different architecture or updating the timm package to a compatible version."
|
||||||
|
)
|
||||||
|
|
||||||
|
pixel_values = pixel_values.to(self.device, self.dtype)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
# to enable hidden states selection
|
||||||
|
if isinstance(output_hidden_states, (list, tuple)):
|
||||||
|
kwargs["indices"] = output_hidden_states
|
||||||
|
last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs)
|
||||||
|
logits = self.timm_model.forward_head(last_hidden_state)
|
||||||
|
else:
|
||||||
|
logits = self.timm_model(pixel_values, **kwargs)
|
||||||
|
hidden_states = None
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
if self.config.problem_type is None:
|
||||||
|
if self.num_labels == 1:
|
||||||
|
self.config.problem_type = "regression"
|
||||||
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
|
self.config.problem_type = "single_label_classification"
|
||||||
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
|
loss_fct = MSELoss()
|
||||||
|
if self.num_labels == 1:
|
||||||
|
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||||
|
else:
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
elif self.config.problem_type == "single_label_classification":
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
outputs = (loss, logits, hidden_states)
|
||||||
|
outputs = tuple(output for output in outputs if output is not None)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
return ImageClassifierOutput(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["TimmWrapperPreTrainedModel", "TimmWrapperModel", "TimmWrapperForImageClassification"]
|
||||||
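The load and save hooks in the model file above mirror each other: one adds the "timm_model." prefix, the other strips it. The following is a minimal sketch of that round trip, assuming standalone copies of the two hooks; the names `fix_key_on_load`, `fix_key_on_save`, and the sample keys are hypothetical and only illustrate the renaming, not the real loading path.

```python
# Illustrative sketch only: standalone copies of the two key-renaming hooks.
# The function names and sample keys below are hypothetical.
PREFIX = "timm_model."


def fix_key_on_load(key: str) -> str:
    """Add the wrapper prefix unless the key already contains it."""
    if PREFIX not in key:
        return f"{PREFIX}{key}"
    return key


def fix_key_on_save(key: str) -> str:
    """Strip the wrapper prefix so the key matches plain timm naming."""
    return key.replace(PREFIX, "")


# Keys as they might appear in a plain timm checkpoint (made-up examples).
timm_keys = ["conv1.weight", "layer1.0.bn1.weight", "fc.bias"]

# Loading adds the prefix; saving removes it again.
wrapped_keys = [fix_key_on_load(k) for k in timm_keys]
restored_keys = [fix_key_on_save(k) for k in wrapped_keys]
```

Applying the load hook to an already prefixed key leaves it unchanged, which is what lets the same model accept both timm checkpoints and checkpoints written during training.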
@@ -55,6 +55,8 @@ from .generic import (
     is_tensor,
     is_tf_symbolic_tensor,
     is_tf_tensor,
+    is_timm_config_dict,
+    is_timm_local_checkpoint,
     is_torch_device,
     is_torch_dtype,
     is_torch_tensor,
@@ -9043,6 +9043,27 @@ class TimmBackbone(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class TimmWrapperForImageClassification(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class TimmWrapperModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class TimmWrapperPreTrainedModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class TrOCRForCausalLM(metaclass=DummyObject):
|
class TrOCRForCausalLM(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||||
|
from ..utils import DummyObject, requires_backends
|
||||||
|
|
||||||
|
|
||||||
|
class TimmWrapperImageProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["timm", "torchvision"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["timm", "torchvision"])
|
||||||
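The dummy classes above stand in for the real ones when a backend is not installed, so that imports succeed and the failure is deferred to instantiation. The following is a simplified sketch of that idea, not the library's actual `DummyObject` or `requires_backends` implementation; `backend_available`, `require_backends_sketch`, `PlaceholderImageProcessor`, and the backend name are all hypothetical.

```python
import importlib.util


def backend_available(name: str) -> bool:
    """Return True if a top-level module with this name can be found."""
    return importlib.util.find_spec(name) is not None


def require_backends_sketch(obj, backends):
    """Raise ImportError naming every backend that is not installed (hypothetical helper)."""
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    missing = [b for b in backends if not backend_available(b)]
    if missing:
        raise ImportError(f"{name} requires the following missing backends: {', '.join(missing)}")


class PlaceholderImageProcessor:
    # A deliberately nonexistent backend name, so the check always fails here.
    _backends = ["surely_not_an_installed_backend_xyz"]

    def __init__(self, *args, **kwargs):
        require_backends_sketch(self, self._backends)


# Importing or defining the class is harmless; only instantiation raises.
try:
    PlaceholderImageProcessor()
    error_message = None
except ImportError as exc:
    error_message = str(exc)
```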
@@ -16,6 +16,8 @@ Generic utilities
 """
 
 import inspect
+import json
+import os
 import tempfile
 import warnings
 from collections import OrderedDict, UserDict
@@ -24,7 +26,7 @@ from contextlib import ExitStack, contextmanager
 from dataclasses import fields, is_dataclass
 from enum import Enum
 from functools import partial, wraps
-from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict
+from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, TypedDict
 
 import numpy as np
 from packaging import version
@@ -867,3 +869,36 @@ class LossKwargs(TypedDict, total=False):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
num_items_in_batch: Optional[int]
|
num_items_in_batch: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
|
def is_timm_config_dict(config_dict: Dict[str, Any]) -> bool:
|
||||||
|
"""Checks whether a config dict is a timm config dict."""
|
||||||
|
return "pretrained_cfg" in config_dict
|
||||||
|
|
||||||
|
|
||||||
|
def is_timm_local_checkpoint(pretrained_model_path: str) -> bool:
|
||||||
|
"""
|
||||||
|
Checks whether a checkpoint is a timm model checkpoint.
|
||||||
|
"""
|
||||||
|
if pretrained_model_path is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# in case it's Path, not str
|
||||||
|
pretrained_model_path = str(pretrained_model_path)
|
||||||
|
|
||||||
|
is_file = os.path.isfile(pretrained_model_path)
|
||||||
|
is_dir = os.path.isdir(pretrained_model_path)
|
||||||
|
|
||||||
|
# pretrained_model_path is a file
|
||||||
|
if is_file and pretrained_model_path.endswith(".json"):
|
||||||
|
with open(pretrained_model_path, "r") as f:
|
||||||
|
config_dict = json.load(f)
|
||||||
|
return is_timm_config_dict(config_dict)
|
||||||
|
|
||||||
|
# pretrained_model_path is a directory with a config.json
|
||||||
|
if is_dir and os.path.exists(os.path.join(pretrained_model_path, "config.json")):
|
||||||
|
with open(os.path.join(pretrained_model_path, "config.json"), "r") as f:
|
||||||
|
config_dict = json.load(f)
|
||||||
|
return is_timm_config_dict(config_dict)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|||||||
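The two helpers added above decide whether a local path points at a timm checkpoint by looking for a "pretrained_cfg" entry in its JSON config. The following sketch reproduces that check in a self-contained form and runs it against two temporary directories; the helper names `looks_like_timm_config` and `looks_like_timm_checkpoint`, the directory names, and the config contents are hypothetical and chosen only for illustration.

```python
import json
import os
import tempfile


def looks_like_timm_config(config_dict):
    """A config is treated as a timm config if it has a 'pretrained_cfg' entry."""
    return "pretrained_cfg" in config_dict


def looks_like_timm_checkpoint(path):
    """Accept a .json file or a directory containing config.json (sketch of the helper above)."""
    if path is None:
        return False
    path = str(path)  # in case a pathlib.Path is passed
    if os.path.isfile(path) and path.endswith(".json"):
        config_file = path
    elif os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")):
        config_file = os.path.join(path, "config.json")
    else:
        return False
    with open(config_file, "r") as f:
        return looks_like_timm_config(json.load(f))


with tempfile.TemporaryDirectory() as tmp:
    timm_dir = os.path.join(tmp, "timm_ckpt")
    other_dir = os.path.join(tmp, "other_ckpt")
    os.makedirs(timm_dir)
    os.makedirs(other_dir)

    # Made-up config contents: only the presence of 'pretrained_cfg' matters.
    with open(os.path.join(timm_dir, "config.json"), "w") as f:
        json.dump({"architecture": "resnet18", "pretrained_cfg": {}}, f)
    with open(os.path.join(other_dir, "config.json"), "w") as f:
        json.dump({"model_type": "bert"}, f)

    results = {
        "timm_dir": looks_like_timm_checkpoint(timm_dir),
        "timm_file": looks_like_timm_checkpoint(os.path.join(timm_dir, "config.json")),
        "other_dir": looks_like_timm_checkpoint(other_dir),
        "missing": looks_like_timm_checkpoint(os.path.join(tmp, "does_not_exist")),
        "none": looks_like_timm_checkpoint(None),
    }
```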
0
tests/models/timm_wrapper/__init__.py
Normal file
103
tests/models/timm_wrapper/test_image_processing_timm_wrapper.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers.testing_utils import require_torch, require_torchvision, require_vision
|
||||||
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from transformers import TimmWrapperConfig, TimmWrapperImageProcessor
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
@require_torchvision
|
||||||
|
class TimmWrapperImageProcessingTest(unittest.TestCase):
|
||||||
|
image_processing_class = TimmWrapperImageProcessor if is_vision_available() else None
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.temp_dir = tempfile.TemporaryDirectory()
|
||||||
|
config = TimmWrapperConfig.from_pretrained("timm/resnet18.a1_in1k")
|
||||||
|
config.save_pretrained(self.temp_dir.name)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.temp_dir.cleanup()
|
||||||
|
|
||||||
|
def test_load_from_hub(self):
|
||||||
|
image_processor = TimmWrapperImageProcessor.from_pretrained("timm/resnet18.a1_in1k")
|
||||||
|
self.assertIsInstance(image_processor, TimmWrapperImageProcessor)
|
||||||
|
|
||||||
|
def test_load_from_local_dir(self):
|
||||||
|
image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
|
||||||
|
self.assertIsInstance(image_processor, TimmWrapperImageProcessor)
|
||||||
|
|
||||||
|
def test_image_processor_properties(self):
|
||||||
|
image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
|
||||||
|
self.assertTrue(hasattr(image_processor, "data_config"))
|
||||||
|
self.assertTrue(hasattr(image_processor, "val_transforms"))
|
||||||
|
self.assertTrue(hasattr(image_processor, "train_transforms"))
|
||||||
|
|
||||||
|
def test_image_processor_call_numpy(self):
|
||||||
|
image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
|
||||||
|
|
||||||
|
single_image = np.random.randint(256, size=(256, 256, 3), dtype=np.uint8)
|
||||||
|
batch_images = [single_image, single_image, single_image]
|
||||||
|
|
||||||
|
# single image
|
||||||
|
pixel_values = image_processor(single_image).pixel_values
|
||||||
|
self.assertEqual(pixel_values.shape, (1, 3, 224, 224))
|
||||||
|
|
||||||
|
# batch images
|
||||||
|
pixel_values = image_processor(batch_images).pixel_values
|
||||||
|
self.assertEqual(pixel_values.shape, (3, 3, 224, 224))
|
||||||
|
|
||||||
|
def test_image_processor_call_pil(self):
|
||||||
|
image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
|
||||||
|
|
||||||
|
single_image = Image.fromarray(np.random.randint(256, size=(256, 256, 3), dtype=np.uint8))
|
||||||
|
batch_images = [single_image, single_image, single_image]
|
||||||
|
|
||||||
|
# single image
|
||||||
|
pixel_values = image_processor(single_image).pixel_values
|
||||||
|
self.assertEqual(pixel_values.shape, (1, 3, 224, 224))
|
||||||
|
|
||||||
|
# batch images
|
||||||
|
pixel_values = image_processor(batch_images).pixel_values
|
||||||
|
self.assertEqual(pixel_values.shape, (3, 3, 224, 224))
|
||||||
|
|
||||||
|
def test_image_processor_call_tensor(self):
|
||||||
|
image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
|
||||||
|
|
||||||
|
single_image = torch.from_numpy(np.random.randint(256, size=(3, 256, 256), dtype=np.uint8)).float()
|
||||||
|
batch_images = [single_image, single_image, single_image]
|
||||||
|
|
||||||
|
# single image
|
||||||
|
pixel_values = image_processor(single_image).pixel_values
|
||||||
|
self.assertEqual(pixel_values.shape, (1, 3, 224, 224))
|
||||||
|
|
||||||
|
# batch images
|
||||||
|
pixel_values = image_processor(batch_images).pixel_values
|
||||||
|
self.assertEqual(pixel_values.shape, (3, 3, 224, 224))
|
||||||
366
tests/models/timm_wrapper/test_modeling_timm_wrapper.py
Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
require_bitsandbytes,
|
||||||
|
require_timm,
|
||||||
|
require_torch,
|
||||||
|
require_vision,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
from transformers.utils.import_utils import is_timm_available, is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
from ...test_configuration_common import ConfigTester
|
||||||
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import TimmWrapperConfig, TimmWrapperForImageClassification, TimmWrapperModel
|
||||||
|
|
||||||
|
|
||||||
|
if is_timm_available():
|
||||||
|
import timm
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from transformers import TimmWrapperImageProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class TimmWrapperModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
model_name="timm/resnet18.a1_in1k",
|
||||||
|
batch_size=3,
|
||||||
|
image_size=32,
|
||||||
|
num_channels=3,
|
||||||
|
is_training=True,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.model_name = model_name
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.image_size = image_size
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.is_training = is_training
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, pixel_values
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return TimmWrapperConfig.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
def create_and_check_model(self, config, pixel_values):
|
||||||
|
model = TimmWrapperModel(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.feature_map[-1].shape,
|
||||||
|
(self.batch_size, model.channels[-1], 14, 14),
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
config, pixel_values = config_and_inputs
|
||||||
|
inputs_dict = {"pixel_values": pixel_values}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_timm
|
||||||
|
class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (TimmWrapperModel, TimmWrapperForImageClassification) if is_torch_available() else ()
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{"image-feature-extraction": TimmWrapperModel, "image-classification": TimmWrapperForImageClassification}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
|
test_resize_embeddings = False
|
||||||
|
test_head_masking = False
|
||||||
|
test_pruning = False
|
||||||
|
has_attentions = False
|
||||||
|
test_model_parallel = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.config_class = TimmWrapperConfig
|
||||||
|
self.model_tester = TimmWrapperModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(
|
||||||
|
self,
|
||||||
|
config_class=self.config_class,
|
||||||
|
has_text_modality=False,
|
||||||
|
common_properties=[],
|
||||||
|
model_name="timm/resnet18.a1_in1k",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_hidden_states_output(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
# check all hidden states
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs_dict, output_hidden_states=True)
|
||||||
|
self.assertTrue(
|
||||||
|
len(outputs.hidden_states) == 5, f"expected 5 hidden states, but got {len(outputs.hidden_states)}"
|
||||||
|
)
|
||||||
|
expected_shapes = [[16, 16], [8, 8], [4, 4], [2, 2], [1, 1]]
|
||||||
|
resulted_shapes = [list(h.shape[2:]) for h in outputs.hidden_states]
|
||||||
|
self.assertListEqual(expected_shapes, resulted_shapes)
|
||||||
|
|
||||||
|
# check we can select hidden states by indices
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs_dict, output_hidden_states=[-2, -1])
|
||||||
|
self.assertTrue(
|
||||||
|
len(outputs.hidden_states) == 2, f"expected 2 hidden states, but got {len(outputs.hidden_states)}"
|
||||||
|
)
|
||||||
|
expected_shapes = [[2, 2], [1, 1]]
|
||||||
|
resulted_shapes = [list(h.shape[2:]) for h in outputs.hidden_states]
|
||||||
|
self.assertListEqual(expected_shapes, resulted_shapes)
|
||||||
|
|
||||||
|
@unittest.skip(reason="TimmWrapper models don't have inputs_embeds")
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="TimmWrapper models don't have inputs_embeds")
|
||||||
|
def test_model_get_set_embeddings(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="TimmWrapper doesn't support output_attentions=True.")
|
||||||
|
def test_torchscript_output_attentions(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="TimmWrapper doesn't support this.")
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="TimmWrapper initialization is managed on the timm side")
|
||||||
|
def test_initialization(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Need to use a timm model and there is no tiny model available.")
|
||||||
|
def test_model_is_small(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_forward_signature(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
signature = inspect.signature(model.forward)
|
||||||
|
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||||
|
arg_names = [*signature.parameters.keys()]
|
||||||
|
|
||||||
|
expected_arg_names = ["pixel_values"]
|
||||||
|
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||||
|
|
||||||
|
def test_do_pooling_option(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.do_pooling = False
|
||||||
|
|
||||||
|
model = TimmWrapperModel._from_config(config)
|
||||||
|
|
||||||
|
# check there is no pooling
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs_dict)
|
||||||
|
self.assertIsNone(output.pooler_output)
|
||||||
|
|
||||||
|
# check there is pooler output
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs_dict, do_pooling=True)
|
||||||
|
self.assertIsNotNone(output.pooler_output)
|
||||||
|
|
||||||
|
|
||||||
|
# We will verify our results on an image of cute cats
|
||||||
|
def prepare_img():
|
||||||
|
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_timm
|
||||||
|
@require_vision
|
||||||
|
class TimmWrapperModelIntegrationTest(unittest.TestCase):
|
||||||
|
# some popular ones
|
||||||
|
model_names_to_test = [
|
||||||
|
"vit_small_patch16_384.augreg_in21k_ft_in1k",
|
||||||
|
"resnet50.a1_in1k",
|
||||||
|
"tf_mobilenetv3_large_minimal_100.in1k",
|
||||||
|
"swin_tiny_patch4_window7_224.ms_in1k",
|
||||||
|
"ese_vovnet19b_dw.ra_in1k",
|
||||||
|
"hrnet_w18.ms_aug_in1k",
|
||||||
|
]
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_image_classification_head(self):
|
||||||
|
checkpoint = "timm/resnet18.a1_in1k"
|
||||||
|
model = TimmWrapperForImageClassification.from_pretrained(checkpoint, device_map=torch_device).eval()
|
||||||
|
image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)
|
||||||
|
|
||||||
|
image = prepare_img()
|
||||||
|
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
# verify the shape and logits
|
||||||
|
expected_shape = torch.Size((1, 1000))
|
||||||
|
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
|
expected_label = 281 # tabby cat
|
||||||
|
self.assertEqual(torch.argmax(outputs.logits).item(), expected_label)
|
||||||
|
|
||||||
|
expected_slice = torch.tensor([-11.2618, -9.6192, -10.3205]).to(torch_device)
|
||||||
|
resulted_slice = outputs.logits[0, :3]
|
||||||
|
is_close = torch.allclose(resulted_slice, expected_slice, atol=1e-3)
|
||||||
|
self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}")
|
||||||
|
|
||||||
|
    @slow
    @require_bitsandbytes
    def test_inference_image_classification_quantized(self):
        from transformers import BitsAndBytesConfig

        checkpoint = "timm/vit_small_patch16_384.augreg_in21k_ft_in1k"

        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        model = TimmWrapperForImageClassification.from_pretrained(
            checkpoint, quantization_config=quantization_config, device_map=torch_device
        ).eval()
        image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)

        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the shape and logits
        expected_shape = torch.Size((1, 1000))
        self.assertEqual(outputs.logits.shape, expected_shape)

        expected_label = 281  # tabby cat
        self.assertEqual(torch.argmax(outputs.logits).item(), expected_label)

        expected_slice = torch.tensor([-2.4043, 1.4492, -0.5127]).to(outputs.logits.dtype)
        resulted_slice = outputs.logits[0, :3].cpu()
        is_close = torch.allclose(resulted_slice, expected_slice, atol=0.1)
        self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}")

    @slow
    def test_transformers_model_for_classification_is_equivalent_to_timm(self):
        # check that wrapper logits are the same as timm model logits

        image = prepare_img()

        for model_name in self.model_names_to_test:
            checkpoint = f"timm/{model_name}"

            with self.subTest(msg=model_name):
                # prepare inputs
                image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)
                pixel_values = image_processor(images=image).pixel_values.to(torch_device)

                # load models
                model = TimmWrapperForImageClassification.from_pretrained(checkpoint, device_map=torch_device).eval()
                timm_model = timm.create_model(model_name, pretrained=True).to(torch_device).eval()

                with torch.inference_mode():
                    outputs = model(pixel_values)
                    timm_outputs = timm_model(pixel_values)

                # check shape is the same
                self.assertEqual(outputs.logits.shape, timm_outputs.shape)

                # check logits are the same
                diff = (outputs.logits - timm_outputs).max().item()
                self.assertLess(diff, 1e-4)

    @slow
    def test_transformers_model_is_equivalent_to_timm(self):
        # check that wrapper logits are the same as timm model logits

        image = prepare_img()

        models_to_test = ["vit_small_patch16_224.dino"] + self.model_names_to_test

        for model_name in models_to_test:
            checkpoint = f"timm/{model_name}"

            with self.subTest(msg=model_name):
                # prepare inputs
                image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)
                pixel_values = image_processor(images=image).pixel_values.to(torch_device)

                # load models
                model = TimmWrapperModel.from_pretrained(checkpoint, device_map=torch_device).eval()
                timm_model = timm.create_model(model_name, pretrained=True, num_classes=0).to(torch_device).eval()

                with torch.inference_mode():
                    outputs = model(pixel_values)
                    timm_outputs = timm_model(pixel_values)

                # check shape is the same
                self.assertEqual(outputs.pooler_output.shape, timm_outputs.shape)

                # check logits are the same
                diff = (outputs.pooler_output - timm_outputs).max().item()
                self.assertLess(diff, 1e-4)

    @slow
    def test_save_load_to_timm(self):
        # test that timm model can be loaded to transformers, saved and then loaded back into timm

        model = TimmWrapperForImageClassification.from_pretrained(
            "timm/resnet18.a1_in1k", num_labels=10, ignore_mismatched_sizes=True
        )

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

            # there is no direct way to load timm model from folder, use the same config + path to weights
            timm_model = timm.create_model(
                "resnet18", num_classes=10, checkpoint_path=f"{tmpdirname}/model.safetensors"
            )

        # check that all weights are the same after reload
        different_weights = []
        for (name1, param1), (name2, param2) in zip(
            model.timm_model.named_parameters(), timm_model.named_parameters()
        ):
            if param1.shape != param2.shape or not torch.equal(param1, param2):
                different_weights.append((name1, name2))

        if different_weights:
            self.fail(f"Found different weights after reloading: {different_weights}")
@@ -3443,6 +3443,7 @@ class ModelTesterMixin:
             "Data2VecAudioForSequenceClassification",
             "UniSpeechForSequenceClassification",
             "PvtForImageClassification",
+            "TimmWrapperForImageClassification",
         ]
         special_param_names = [
             r"^bit\.",
@@ -3463,6 +3464,7 @@ class ModelTesterMixin:
             r"^swiftformer\.",
             r"^swinv2\.",
             r"^transformers\.models\.swiftformer\.",
+            r"^timm_model\.",
             r"^unispeech\.",
             r"^unispeech_sat\.",
             r"^vision_model\.",
@@ -41,6 +41,7 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
     "RagConfig",
     "SpeechEncoderDecoderConfig",
     "TimmBackboneConfig",
+    "TimmWrapperConfig",
     "VisionEncoderDecoderConfig",
     "VisionTextDualEncoderConfig",
     "LlamaConfig",