Add image classifier donut & update loss calculation for all swins (#37224)

* add classifier head to donut

* add to transformers __init__

* add to auto model

* fix typo

* add loss for image classification

* add checkpoint

* remove unneeded import

* reorder imports

* format

* consistency

* add test of classifier

* add doc

* try ignore

* update loss for all swin models
This commit is contained in:
AbdelKarim ELJANDOUBI
2025-04-10 15:00:42 +02:00
committed by GitHub
parent 5ae9b2cac0
commit 7ecc5b88c0
9 changed files with 177 additions and 48 deletions

View File

@@ -226,3 +226,8 @@ print(answer)
[[autodoc]] DonutSwinModel [[autodoc]] DonutSwinModel
- forward - forward
## DonutSwinForImageClassification
[[autodoc]] transformers.DonutSwinForImageClassification
- forward

View File

@@ -2304,6 +2304,7 @@ else:
) )
_import_structure["models.donut"].extend( _import_structure["models.donut"].extend(
[ [
"DonutSwinForImageClassification",
"DonutSwinModel", "DonutSwinModel",
"DonutSwinPreTrainedModel", "DonutSwinPreTrainedModel",
] ]
@@ -7457,6 +7458,7 @@ if TYPE_CHECKING:
DistilBertPreTrainedModel, DistilBertPreTrainedModel,
) )
from .models.donut import ( from .models.donut import (
DonutSwinForImageClassification,
DonutSwinModel, DonutSwinModel,
DonutSwinPreTrainedModel, DonutSwinPreTrainedModel,
) )

View File

@@ -145,6 +145,7 @@ LOSS_MAPPING = {
"ForMaskedLM": ForMaskedLMLoss, "ForMaskedLM": ForMaskedLMLoss,
"ForQuestionAnswering": ForQuestionAnsweringLoss, "ForQuestionAnswering": ForQuestionAnsweringLoss,
"ForSequenceClassification": ForSequenceClassificationLoss, "ForSequenceClassification": ForSequenceClassificationLoss,
"ForImageClassification": ForSequenceClassificationLoss,
"ForTokenClassification": ForTokenClassification, "ForTokenClassification": ForTokenClassification,
"ForSegmentation": ForSegmentationLoss, "ForSegmentation": ForSegmentationLoss,
"ForObjectDetection": ForObjectDetectionLoss, "ForObjectDetection": ForObjectDetectionLoss,

View File

@@ -707,6 +707,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("dinat", "DinatForImageClassification"), ("dinat", "DinatForImageClassification"),
("dinov2", "Dinov2ForImageClassification"), ("dinov2", "Dinov2ForImageClassification"),
("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"), ("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
("donut-swin", "DonutSwinForImageClassification"),
( (
"efficientformer", "efficientformer",
( (

View File

@@ -49,6 +49,10 @@ _CONFIG_FOR_DOC = "DonutSwinConfig"
_CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base" _CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768] _EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "eljandoubi/donut-base-encoder"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
@dataclass @dataclass
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin # Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
@@ -121,6 +125,43 @@ class DonutSwinModelOutput(ModelOutput):
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->DonutSwin
class DonutSwinImageClassifierOutput(ModelOutput):
    """
    DonutSwin outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # Every field defaults to None, so each one is annotated as Optional.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
# Copied from transformers.models.swin.modeling_swin.window_partition # Copied from transformers.models.swin.modeling_swin.window_partition
def window_partition(input_feature, window_size): def window_partition(input_feature, window_size):
""" """
@@ -845,7 +886,7 @@ class DonutSwinEncoder(nn.Module):
) )
# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin # Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin,swin->donut
class DonutSwinPreTrainedModel(PreTrainedModel): class DonutSwinPreTrainedModel(PreTrainedModel):
""" """
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -853,7 +894,7 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
""" """
config_class = DonutSwinConfig config_class = DonutSwinConfig
base_model_prefix = "swin" base_model_prefix = "donut"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["DonutSwinStage"] _no_split_modules = ["DonutSwinStage"]
@@ -1015,4 +1056,90 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
) )
__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel"] @add_start_docstrings(
"""
DonutSwin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
the [CLS] token) e.g. for ImageNet.
<Tip>
Note that it's possible to fine-tune DonutSwin on higher resolution images than the ones it has been trained on, by
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
position embeddings to the higher resolution.
</Tip>
""",
SWIN_START_DOCSTRING,
)
# Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with Swin->DonutSwin,swin->donut
class DonutSwinForImageClassification(DonutSwinPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.donut = DonutSwinModel(config)

        # Classifier head: a linear projection from the backbone's feature width to the label
        # count, or a pass-through when the config requests no labels.
        if config.num_labels > 0:
            self.classifier = nn.Linear(self.donut.num_features, config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=DonutSwinImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, DonutSwinImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # Fall back to the config-level preference when the caller did not choose an output format.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.donut(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        # Index 1 of the backbone output is what the classifier head consumes.
        logits = self.classifier(backbone_outputs[1])

        # The loss is only computed when labels are supplied.
        if labels is None:
            loss = None
        else:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)

        if return_dict:
            return DonutSwinImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=backbone_outputs.hidden_states,
                attentions=backbone_outputs.attentions,
                reshaped_hidden_states=backbone_outputs.reshaped_hidden_states,
            )

        # Tuple output: logits followed by the remaining backbone entries, with the loss
        # prepended only when one was computed.
        tuple_output = (logits,) + backbone_outputs[2:]
        if loss is None:
            return tuple_output
        return (loss,) + tuple_output
__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel", "DonutSwinForImageClassification"]

View File

@@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput from ...modeling_outputs import BackboneOutput
@@ -1285,26 +1284,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
if self.config.problem_type is None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]

View File

@@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import Tensor, nn from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput from ...modeling_outputs import BackboneOutput
@@ -1339,26 +1338,7 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
if self.config.problem_type is None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]

View File

@@ -3829,6 +3829,13 @@ class DistilBertPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class DonutSwinForImageClassification(metaclass=DummyObject):
    # Placeholder built with the DummyObject metaclass; lists "torch" as its backend.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Accepts any arguments and only calls `requires_backends`; presumably that helper
        # raises when torch is unavailable — confirm against its definition.
        requires_backends(self, ["torch"])
class DonutSwinModel(metaclass=DummyObject): class DonutSwinModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]

View File

@@ -29,7 +29,7 @@ if is_torch_available():
import torch import torch
from torch import nn from torch import nn
from transformers import DonutSwinModel from transformers import DonutSwinForImageClassification, DonutSwinModel
class DonutSwinModelTester: class DonutSwinModelTester:
@@ -129,6 +129,24 @@ class DonutSwinModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim)) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
    """Check the classifier's logits shape on the given inputs and on single-channel inputs."""
    expected_logits_shape = (self.batch_size, self.type_sequence_label_size)
    config.num_labels = self.type_sequence_label_size

    # First pass: the supplied pixel values, with labels.
    classifier = DonutSwinForImageClassification(config)
    classifier.to(torch_device)
    classifier.eval()
    outputs = classifier(pixel_values, labels=labels)
    self.parent.assertEqual(outputs.logits.shape, expected_logits_shape)

    # test greyscale images
    config.num_channels = 1
    classifier = DonutSwinForImageClassification(config)
    classifier.to(torch_device)
    classifier.eval()

    grey_pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
    outputs = classifier(grey_pixel_values)
    self.parent.assertEqual(outputs.logits.shape, expected_logits_shape)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
( (
@@ -142,8 +160,12 @@ class DonutSwinModelTester:
@require_torch @require_torch
class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DonutSwinModel,) if is_torch_available() else () all_model_classes = (DonutSwinModel, DonutSwinForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = {"image-feature-extraction": DonutSwinModel} if is_torch_available() else {} pipeline_model_mapping = (
{"image-feature-extraction": DonutSwinModel, "image-classification": DonutSwinForImageClassification}
if is_torch_available()
else {}
)
fx_compatible = True fx_compatible = True
test_pruning = False test_pruning = False
@@ -167,6 +189,10 @@ class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_image_classification(self):
    """Run the tester's image-classification check on freshly prepared config and inputs."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_for_image_classification(*prepared)
@unittest.skip(reason="DonutSwin does not use inputs_embeds") @unittest.skip(reason="DonutSwin does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass