Add TimmWrapper (#34564)
* Add files
* Init
* Add TimmWrapperModel
* Fix up
* Some fixes
* Fix up
* Remove old file
* Sort out import orders
* Fix some model loading
* Compatible with pipeline and trainer
* Fix up
* Delete test_timm_model_1/config.json
* Remove accidentally committed files
* Delete src/transformers/models/modeling_timm_wrapper.py
* Remove empty imports; fix transformations applied
* Tidy up
* Add image classification model to special cases
* Create pretrained model; enable device_map='auto'
* Enable most tests; fix init order
* Sort imports
* [run-slow] timm_wrapper
* Pass num_classes into timm.create_model
* Remove train transforms from image processor
* Update timm creation with pretrained=False
* Fix gamma/beta issue for timm models
* Fixing gamma and beta renaming for timm models
* Simplify config and model creation
* Remove attn_implementation diff
* Fixup
* Docstrings
* Fix warning msg text according to test case
* Fix device_map auto
* Set dtype and device for pixel_values in forward
* Enable output hidden states
* Enable tests for hidden_states and model parallel
* Remove default scriptable arg
* Refactor inner model
* Update timm version
* Fix _find_mismatched_keys function
* Change inheritance for Classification model (fix weights loading with device_map)
* Minor bugfix
* Disable save pretrained for image processor
* Rename hook method for loaded keys correction
* Rename state dict keys on save, remove `timm_model` prefix, make checkpoint compatible with `timm`
* Managing num_labels <-> num_classes attributes
* Enable loading checkpoints in Trainer to resume training
* Update error message for output_hidden_states
* Add output hidden states test
* Decouple base and classification models
* Add more test cases
* Add save-load-to-timm test
* Fix test name
* Fixup
* Add do_pooling
* Add test for do_pooling
* Fix doc
* Add tests for TimmWrapperModel
* Add validation for `num_classes=0` in timm config + test for DINO checkpoint
* Adjust atol for test
* Fix docs
* dev-ci
* dev-ci
* Add tests for image processor
* Update docs
* Update init to new format
* Update docs in configuration
* Fix some docs in image processor
* Improve docs for modeling
* fix for is_timm_checkpoint
* Update code examples
* Fix header
* Fix typehint
* Increase tolerance a bit
* Fix Path
* Fixing model parallel tests
* Disable "parallel" tests
* Add comment for metadata
* Refactor AutoImageProcessor for timm wrapper loading
* Remove custom test_model_outputs_equivalence
* Add require_timm decorator
* Fix comment
* Make image processor work with older timm versions and tensor input
* Save config instead of whole model in image processor tests
* Add docstring for `image_processor_filename`
* Sanitize kwargs for timm image processor
* Fix doc style
* Update check for tensor input
* Update normalize
* Remove _load_timm_model function

---------

Co-authored-by: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Committed by GitHub
Parent: bcc50cc7ce
Commit: 5fcf6286bf
@@ -42,6 +42,7 @@ from transformers import (
     AutoImageProcessor,
     AutoModelForImageClassification,
     HfArgumentParser,
+    TimmWrapperImageProcessor,
     Trainer,
     TrainingArguments,
     set_seed,
@@ -329,31 +330,36 @@ def main():
     )

     # Define torchvision transforms to be applied to each image.
-    if "shortest_edge" in image_processor.size:
-        size = image_processor.size["shortest_edge"]
+    if isinstance(image_processor, TimmWrapperImageProcessor):
+        _train_transforms = image_processor.train_transforms
+        _val_transforms = image_processor.val_transforms
     else:
-        size = (image_processor.size["height"], image_processor.size["width"])
-    normalize = (
-        Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
-        if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
-        else Lambda(lambda x: x)
-    )
-    _train_transforms = Compose(
-        [
-            RandomResizedCrop(size),
-            RandomHorizontalFlip(),
-            ToTensor(),
-            normalize,
-        ]
-    )
-    _val_transforms = Compose(
-        [
-            Resize(size),
-            CenterCrop(size),
-            ToTensor(),
-            normalize,
-        ]
-    )
+        if "shortest_edge" in image_processor.size:
+            size = image_processor.size["shortest_edge"]
+        else:
+            size = (image_processor.size["height"], image_processor.size["width"])
+
+        # Create normalization transform
+        if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std"):
+            normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
+        else:
+            normalize = Lambda(lambda x: x)
+        _train_transforms = Compose(
+            [
+                RandomResizedCrop(size),
+                RandomHorizontalFlip(),
+                ToTensor(),
+                normalize,
+            ]
+        )
+        _val_transforms = Compose(
+            [
+                Resize(size),
+                CenterCrop(size),
+                ToTensor(),
+                normalize,
+            ]
+        )

     def train_transforms(example_batch):
         """Apply _train_transforms across a batch."""
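
The branch added in the second hunk can be read as "use the processor's own transforms when it provides them, otherwise build the torchvision defaults". Below is a minimal, dependency-free sketch of that selection logic, not the script's actual code. `resolve_size`, `pick_transforms`, `describe`, and the `SimpleNamespace` stand-ins are hypothetical names introduced for illustration. Attribute checks replace the script's `isinstance(image_processor, TimmWrapperImageProcessor)` test so the example runs without `transformers`, `timm`, or `torchvision` installed.

```python
from types import SimpleNamespace


def resolve_size(image_processor):
    """Return the shortest-edge int if present, else a (height, width) tuple."""
    size = image_processor.size
    if "shortest_edge" in size:
        return size["shortest_edge"]
    return (size["height"], size["width"])


def pick_transforms(image_processor, build_default):
    """Use the processor's own transforms when it provides them.

    `build_default` is a hypothetical callable taking (size, mean, std) and
    returning a (train, val) pair; in the script this role is played by the
    torchvision Compose pipelines.
    """
    # Assumption: attribute checks stand in for the script's
    # isinstance(image_processor, TimmWrapperImageProcessor) test.
    if hasattr(image_processor, "train_transforms") and hasattr(image_processor, "val_transforms"):
        return image_processor.train_transforms, image_processor.val_transforms

    # Mean/std may be absent, mirroring the hasattr guard around Normalize.
    mean = getattr(image_processor, "image_mean", None)
    std = getattr(image_processor, "image_std", None)
    return build_default(resolve_size(image_processor), mean, std)


def describe(size, mean, std):
    """Toy builder: returns labels instead of real transform objects."""
    norm = "normalize" if mean is not None and std is not None else "identity"
    return (f"train[{size},{norm}]", f"val[{size},{norm}]")


# Stand-ins for a timm-backed processor, a regular processor, and one
# without normalization statistics.
timm_like = SimpleNamespace(train_transforms="timm-train", val_transforms="timm-val")
hf_like = SimpleNamespace(size={"height": 224, "width": 224}, image_mean=[0.5] * 3, image_std=[0.5] * 3)
bare = SimpleNamespace(size={"shortest_edge": 256})

timm_pair = pick_transforms(timm_like, describe)
hf_pair = pick_transforms(hf_like, describe)
bare_pair = pick_transforms(bare, describe)
```

In the real script the fallback builds `Compose([...])` pipelines and the timm branch is selected by type rather than by attribute presence. The sketch only shows the order of the decisions.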