Add image classifier donut & update loss calculation for all swins (#37224)
* add classifier head to donut * add to transformers __init__ * add to auto model * fix typo * add loss for image classification * add checkpoint * remove unneeded import * reorder import * format * consistency * add test of classifier * add doc * try ignore * update loss for all swin models
This commit is contained in:
committed by
GitHub
parent
5ae9b2cac0
commit
7ecc5b88c0
@@ -226,3 +226,8 @@ print(answer)
|
|||||||
|
|
||||||
[[autodoc]] DonutSwinModel
|
[[autodoc]] DonutSwinModel
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## DonutSwinForImageClassification
|
||||||
|
|
||||||
|
[[autodoc]] transformers.DonutSwinForImageClassification
|
||||||
|
- forward
|
||||||
@@ -2304,6 +2304,7 @@ else:
|
|||||||
)
|
)
|
||||||
_import_structure["models.donut"].extend(
|
_import_structure["models.donut"].extend(
|
||||||
[
|
[
|
||||||
|
"DonutSwinForImageClassification",
|
||||||
"DonutSwinModel",
|
"DonutSwinModel",
|
||||||
"DonutSwinPreTrainedModel",
|
"DonutSwinPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -7457,6 +7458,7 @@ if TYPE_CHECKING:
|
|||||||
DistilBertPreTrainedModel,
|
DistilBertPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.donut import (
|
from .models.donut import (
|
||||||
|
DonutSwinForImageClassification,
|
||||||
DonutSwinModel,
|
DonutSwinModel,
|
||||||
DonutSwinPreTrainedModel,
|
DonutSwinPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -145,6 +145,7 @@ LOSS_MAPPING = {
|
|||||||
"ForMaskedLM": ForMaskedLMLoss,
|
"ForMaskedLM": ForMaskedLMLoss,
|
||||||
"ForQuestionAnswering": ForQuestionAnsweringLoss,
|
"ForQuestionAnswering": ForQuestionAnsweringLoss,
|
||||||
"ForSequenceClassification": ForSequenceClassificationLoss,
|
"ForSequenceClassification": ForSequenceClassificationLoss,
|
||||||
|
"ForImageClassification": ForSequenceClassificationLoss,
|
||||||
"ForTokenClassification": ForTokenClassification,
|
"ForTokenClassification": ForTokenClassification,
|
||||||
"ForSegmentation": ForSegmentationLoss,
|
"ForSegmentation": ForSegmentationLoss,
|
||||||
"ForObjectDetection": ForObjectDetectionLoss,
|
"ForObjectDetection": ForObjectDetectionLoss,
|
||||||
|
|||||||
@@ -707,6 +707,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
("dinat", "DinatForImageClassification"),
|
("dinat", "DinatForImageClassification"),
|
||||||
("dinov2", "Dinov2ForImageClassification"),
|
("dinov2", "Dinov2ForImageClassification"),
|
||||||
("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
|
("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
|
||||||
|
("donut-swin", "DonutSwinForImageClassification"),
|
||||||
(
|
(
|
||||||
"efficientformer",
|
"efficientformer",
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -49,6 +49,10 @@ _CONFIG_FOR_DOC = "DonutSwinConfig"
|
|||||||
_CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base"
|
_CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base"
|
||||||
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
|
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
|
||||||
|
|
||||||
|
# Image classification docstring
|
||||||
|
_IMAGE_CLASS_CHECKPOINT = "eljandoubi/donut-base-encoder"
|
||||||
|
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
|
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
|
||||||
@@ -121,6 +125,43 @@ class DonutSwinModelOutput(ModelOutput):
|
|||||||
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->DonutSwin
|
||||||
|
class DonutSwinImageClassifierOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
DonutSwin outputs for image classification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||||
|
Classification (or regression if config.num_labels==1) loss.
|
||||||
|
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
|
||||||
|
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||||
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
|
||||||
|
shape `(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||||
|
Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
|
||||||
|
shape `(batch_size, hidden_size, height, width)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
|
||||||
|
include the spatial dimensions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
loss: Optional[torch.FloatTensor] = None
|
||||||
|
logits: torch.FloatTensor = None
|
||||||
|
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||||
|
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||||
|
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.swin.modeling_swin.window_partition
|
# Copied from transformers.models.swin.modeling_swin.window_partition
|
||||||
def window_partition(input_feature, window_size):
|
def window_partition(input_feature, window_size):
|
||||||
"""
|
"""
|
||||||
@@ -845,7 +886,7 @@ class DonutSwinEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin
|
# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin,swin->donut
|
||||||
class DonutSwinPreTrainedModel(PreTrainedModel):
|
class DonutSwinPreTrainedModel(PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
@@ -853,7 +894,7 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
config_class = DonutSwinConfig
|
config_class = DonutSwinConfig
|
||||||
base_model_prefix = "swin"
|
base_model_prefix = "donut"
|
||||||
main_input_name = "pixel_values"
|
main_input_name = "pixel_values"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["DonutSwinStage"]
|
_no_split_modules = ["DonutSwinStage"]
|
||||||
@@ -1015,4 +1056,90 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel"]
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
DonutSwin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
|
||||||
|
the [CLS] token) e.g. for ImageNet.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
Note that it's possible to fine-tune DonutSwin on higher resolution images than the ones it has been trained on, by
|
||||||
|
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
|
||||||
|
position embeddings to the higher resolution.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
""",
|
||||||
|
SWIN_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
# Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with Swin->DonutSwin,swin->donut
|
||||||
|
class DonutSwinForImageClassification(DonutSwinPreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.num_labels = config.num_labels
|
||||||
|
self.donut = DonutSwinModel(config)
|
||||||
|
|
||||||
|
# Classifier head
|
||||||
|
self.classifier = (
|
||||||
|
nn.Linear(self.donut.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
|
||||||
|
@add_code_sample_docstrings(
|
||||||
|
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
||||||
|
output_type=DonutSwinImageClassifierOutput,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
||||||
|
)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, DonutSwinImageClassifierOutput]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
|
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
||||||
|
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||||
|
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||||
|
"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
outputs = self.donut(
|
||||||
|
pixel_values,
|
||||||
|
head_mask=head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
|
logits = self.classifier(pooled_output)
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[2:]
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return DonutSwinImageClassifierOutput(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
reshaped_hidden_states=outputs.reshaped_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel", "DonutSwinForImageClassification"]
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_outputs import BackboneOutput
|
from ...modeling_outputs import BackboneOutput
|
||||||
@@ -1285,26 +1284,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.config.problem_type is None:
|
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
|
||||||
if self.num_labels == 1:
|
|
||||||
self.config.problem_type = "regression"
|
|
||||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
|
||||||
self.config.problem_type = "single_label_classification"
|
|
||||||
else:
|
|
||||||
self.config.problem_type = "multi_label_classification"
|
|
||||||
|
|
||||||
if self.config.problem_type == "regression":
|
|
||||||
loss_fct = MSELoss()
|
|
||||||
if self.num_labels == 1:
|
|
||||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits, labels)
|
|
||||||
elif self.config.problem_type == "single_label_classification":
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
elif self.config.problem_type == "multi_label_classification":
|
|
||||||
loss_fct = BCEWithLogitsLoss()
|
|
||||||
loss = loss_fct(logits, labels)
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_outputs import BackboneOutput
|
from ...modeling_outputs import BackboneOutput
|
||||||
@@ -1339,26 +1338,7 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
if self.config.problem_type is None:
|
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
|
||||||
if self.num_labels == 1:
|
|
||||||
self.config.problem_type = "regression"
|
|
||||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
|
||||||
self.config.problem_type = "single_label_classification"
|
|
||||||
else:
|
|
||||||
self.config.problem_type = "multi_label_classification"
|
|
||||||
|
|
||||||
if self.config.problem_type == "regression":
|
|
||||||
loss_fct = MSELoss()
|
|
||||||
if self.num_labels == 1:
|
|
||||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
|
||||||
else:
|
|
||||||
loss = loss_fct(logits, labels)
|
|
||||||
elif self.config.problem_type == "single_label_classification":
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
||||||
elif self.config.problem_type == "multi_label_classification":
|
|
||||||
loss_fct = BCEWithLogitsLoss()
|
|
||||||
loss = loss_fct(logits, labels)
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[2:]
|
||||||
|
|||||||
@@ -3829,6 +3829,13 @@ class DistilBertPreTrainedModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DonutSwinForImageClassification(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class DonutSwinModel(metaclass=DummyObject):
|
class DonutSwinModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import DonutSwinModel
|
from transformers import DonutSwinForImageClassification, DonutSwinModel
|
||||||
|
|
||||||
|
|
||||||
class DonutSwinModelTester:
|
class DonutSwinModelTester:
|
||||||
@@ -129,6 +129,24 @@ class DonutSwinModelTester:
|
|||||||
|
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
|
||||||
|
|
||||||
|
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||||
|
config.num_labels = self.type_sequence_label_size
|
||||||
|
model = DonutSwinForImageClassification(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(pixel_values, labels=labels)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
|
# test greyscale images
|
||||||
|
config.num_channels = 1
|
||||||
|
model = DonutSwinForImageClassification(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(
|
(
|
||||||
@@ -142,8 +160,12 @@ class DonutSwinModelTester:
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
|
all_model_classes = (DonutSwinModel, DonutSwinForImageClassification) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = {"image-feature-extraction": DonutSwinModel} if is_torch_available() else {}
|
pipeline_model_mapping = (
|
||||||
|
{"image-feature-extraction": DonutSwinModel, "image-classification": DonutSwinForImageClassification}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
fx_compatible = True
|
fx_compatible = True
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
@@ -167,6 +189,10 @@ class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_image_classification(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||||
|
|
||||||
@unittest.skip(reason="DonutSwin does not use inputs_embeds")
|
@unittest.skip(reason="DonutSwin does not use inputs_embeds")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user