Add TimmWrapper (#34564)

* Add files

* Init

* Add TimmWrapperModel

* Fix up

* Some fixes

* Fix up

* Remove old file

* Sort out import orders

* Fix some model loading

* Compatible with pipeline and trainer

* Fix up

* Delete test_timm_model_1/config.json

* Remove accidentally commited files

* Delete src/transformers/models/modeling_timm_wrapper.py

* Remove empty imports; fix transformations applied

* Tidy up

* Add image classifcation model to special cases

* Create pretrained model; enable device_map='auto'

* Enable most tests; fix init order

* Sort imports

* [run-slow] timm_wrapper

* Pass num_classes into timm.create_model

* Remove train transforms from image processor

* Update timm creation with pretrained=False

* Fix gamma/beta issue for timm models

* Fixing gamma and beta renaming for timm models

* Simplify config and model creation

* Remove attn_implementation diff

* Fixup

* Docstrings

* Fix warning msg text according to test case

* Fix device_map auto

* Set dtype and device for pixel_values in forward

* Enable output hidden states

* Enable tests for hidden_states and model parallel

* Remove default scriptable arg

* Refactor inner model

* Update timm version

* Fix _find_mismatched_keys function

* Change inheritance for Classification model (fix weights loading with device_map)

* Minor bugfix

* Disable save pretrained for image processor

* Rename hook method for loaded keys correction

* Rename state dict keys on save, remove `timm_model` prefix, make checkpoint compatible with `timm`

* Managing num_labels <-> num_classes attributes

* Enable loading checkpoints in Trainer to resume training

* Update error message for output_hidden_states

* Add output hidden states test

* Decouple base and classification models

* Add more test cases

* Add save-load-to-timm test

* Fix test name

* Fixup

* Add do_pooling

* Add test for do_pooling

* Fix doc

* Add tests for TimmWrapperModel

* Add validation for `num_classes=0` in timm config + test for DINO checkpoint

* Adjust atol for test

* Fix docs

* dev-ci

* dev-ci

* Add tests for image processor

* Update docs

* Update init to new format

* Update docs in configuration

* Fix some docs in image processor

* Improve docs for modeling

* fix for is_timm_checkpoint

* Update code examples

* Fix header

* Fix typehint

* Increase tolerance a bit

* Fix Path

* Fixing model parallel tests

* Disable "parallel" tests

* Add comment for metadata

* Refactor AutoImageProcessor for timm wrapper loading

* Remove custom test_model_outputs_equivalence

* Add require_timm decorator

* Fix comment

* Make image processor work with older timm versions and tensor input

* Save config instead of whole model in image processor tests

* Add docstring for `image_processor_filename`

* Sanitize kwargs for timm image processor

* Fix doc style

* Update check for tensor input

* Update normalize

* Remove _load_timm_model function

---------

Co-authored-by: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Pavel Iakubovskii
2024-12-11 13:40:30 +01:00
committed by GitHub
parent bcc50cc7ce
commit 5fcf6286bf
25 changed files with 1432 additions and 129 deletions

View File

@@ -42,6 +42,7 @@ from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
HfArgumentParser,
TimmWrapperImageProcessor,
Trainer,
TrainingArguments,
set_seed,
@@ -329,31 +330,36 @@ def main():
)
# Define torchvision transforms to be applied to each image.
if "shortest_edge" in image_processor.size:
size = image_processor.size["shortest_edge"]
if isinstance(image_processor, TimmWrapperImageProcessor):
_train_transforms = image_processor.train_transforms
_val_transforms = image_processor.val_transforms
else:
size = (image_processor.size["height"], image_processor.size["width"])
normalize = (
Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
else Lambda(lambda x: x)
)
_train_transforms = Compose(
[
RandomResizedCrop(size),
RandomHorizontalFlip(),
ToTensor(),
normalize,
]
)
_val_transforms = Compose(
[
Resize(size),
CenterCrop(size),
ToTensor(),
normalize,
]
)
if "shortest_edge" in image_processor.size:
size = image_processor.size["shortest_edge"]
else:
size = (image_processor.size["height"], image_processor.size["width"])
# Create normalization transform
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std"):
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
else:
normalize = Lambda(lambda x: x)
_train_transforms = Compose(
[
RandomResizedCrop(size),
RandomHorizontalFlip(),
ToTensor(),
normalize,
]
)
_val_transforms = Compose(
[
Resize(size),
CenterCrop(size),
ToTensor(),
normalize,
]
)
def train_transforms(example_batch):
"""Apply _train_transforms across a batch."""