Add image classifier donut & update loss calculation for all swins (#37224)

* add classifier head to donut

* add to transformers __init__

* add to auto model

* fix typo

* add loss for image classification

* add checkpoint

* remove no needed import

* reoder import

* format

* consistency

* add test of classifier

* add doc

* try ignore

* update loss for all swin models
This commit is contained in:
AbdelKarim ELJANDOUBI
2025-04-10 15:00:42 +02:00
committed by GitHub
parent 5ae9b2cac0
commit 7ecc5b88c0
9 changed files with 177 additions and 48 deletions

View File

@@ -29,7 +29,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import DonutSwinModel
from transformers import DonutSwinForImageClassification, DonutSwinModel
class DonutSwinModelTester:
@@ -129,6 +129,24 @@ class DonutSwinModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = DonutSwinForImageClassification(config)
model.to(torch_device)
model.eval()
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# test greyscale images
config.num_channels = 1
model = DonutSwinForImageClassification(config)
model.to(torch_device)
model.eval()
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -142,8 +160,12 @@ class DonutSwinModelTester:
@require_torch
class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
pipeline_model_mapping = {"image-feature-extraction": DonutSwinModel} if is_torch_available() else {}
all_model_classes = (DonutSwinModel, DonutSwinForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"image-feature-extraction": DonutSwinModel, "image-classification": DonutSwinForImageClassification}
if is_torch_available()
else {}
)
fx_compatible = True
test_pruning = False
@@ -167,6 +189,10 @@ class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@unittest.skip(reason="DonutSwin does not use inputs_embeds")
def test_inputs_embeds(self):
pass