Add image classifier donut & update loss calculation for all swins (#37224)
* add classifier head to donut * add to transformers __init__ * add to auto model * fix typo * add loss for image classification * add checkpoint * remove no needed import * reoder import * format * consistency * add test of classifier * add doc * try ignore * update loss for all swin models
This commit is contained in:
committed by
GitHub
parent
5ae9b2cac0
commit
7ecc5b88c0
@@ -29,7 +29,7 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import DonutSwinModel
|
||||
from transformers import DonutSwinForImageClassification, DonutSwinModel
|
||||
|
||||
|
||||
class DonutSwinModelTester:
|
||||
@@ -129,6 +129,24 @@ class DonutSwinModelTester:
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
|
||||
|
||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||
config.num_labels = self.type_sequence_label_size
|
||||
model = DonutSwinForImageClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values, labels=labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
|
||||
# test greyscale images
|
||||
config.num_channels = 1
|
||||
model = DonutSwinForImageClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -142,8 +160,12 @@ class DonutSwinModelTester:
|
||||
|
||||
@require_torch
|
||||
class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-feature-extraction": DonutSwinModel} if is_torch_available() else {}
|
||||
all_model_classes = (DonutSwinModel, DonutSwinForImageClassification) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"image-feature-extraction": DonutSwinModel, "image-classification": DonutSwinForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
fx_compatible = True
|
||||
|
||||
test_pruning = False
|
||||
@@ -167,6 +189,10 @@ class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_for_image_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="DonutSwin does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user