From d4b3e56d6443aff5148419854f9d4cd45d2db915 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Mon, 31 Jan 2022 16:32:14 +0100 Subject: [PATCH] [Hotfix] Fix Swin model outputs (#15414) * Fix Swin model outputs * Rename pooler --- src/transformers/models/swin/modeling_swin.py | 53 ++++++++++++------- tests/test_modeling_swin.py | 8 +-- 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index d3d8afef07..ec80d83511 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -21,11 +21,11 @@ import math import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings -from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging from .configuration_swin import SwinConfig @@ -143,8 +143,8 @@ class SwinPatchEmbeddings(nn.Module): self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values): - pixel_values = self.projection(pixel_values).flatten(2).transpose(1, 2) - return pixel_values + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings class SwinPatchMerging(nn.Module): @@ -659,7 +659,7 @@ SWIN_INPUTS_DOCSTRING = r""" SWIN_START_DOCSTRING, ) class SwinModel(SwinPreTrainedModel): - def __init__(self, config): + def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config self.num_layers = len(config.depths) @@ -669,7 +669,7 @@ class SwinModel(SwinPreTrainedModel): self.encoder = SwinEncoder(config, self.embeddings.patch_grid) self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps) - self.pool = nn.AdaptiveAvgPool1d(1) + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None # Initialize weights and apply final processing self.post_init() @@ -686,7 +686,7 @@ class SwinModel(SwinPreTrainedModel): self.encoder.layer[layer].attention.prune_heads(heads) @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) def forward( self, pixel_values=None, @@ -744,14 +744,18 @@ class SwinModel(SwinPreTrainedModel): sequence_output = encoder_outputs[0] sequence_output = self.layernorm(sequence_output) - sequence_output = self.pool(sequence_output.transpose(1, 2)) - sequence_output = torch.flatten(sequence_output, 1) + + pooled_output = None + if self.pooler is not None: + pooled_output = self.pooler(sequence_output.transpose(1, 2)) + pooled_output = torch.flatten(pooled_output, 1) if not return_dict: - return (sequence_output,) + encoder_outputs[1:] + return (sequence_output, pooled_output) + encoder_outputs[1:] - return BaseModelOutput( + return BaseModelOutputWithPooling( last_hidden_state=sequence_output, + pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) @@ -829,22 +833,35 @@ class SwinForImageClassification(SwinPreTrainedModel): return_dict=return_dict, ) - sequence_output = outputs[0] + pooled_output = outputs[1] - logits = self.classifier(sequence_output) + logits = self.classifier(pooled_output) loss = None if labels is not None: - if self.num_labels == 1: - # We are doing regression + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: - output = (logits,) + outputs[1:] + output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( diff --git a/tests/test_modeling_swin.py b/tests/test_modeling_swin.py index ce8db379b1..29eddb7f7c 100644 --- a/tests/test_modeling_swin.py +++ b/tests/test_modeling_swin.py @@ -137,9 +137,11 @@ class SwinModelTester: model.eval() result = model(pixel_values) - num_features = int(config.embed_dim * 2 ** (len(config.depths) - 1)) + # since the model we're testing only consists of a single layer, expected_seq_len = number of patches + expected_seq_len = (config.image_size // config.patch_size) ** 2 + expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1)) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_features)) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim)) def create_and_check_for_image_classification(self, config, pixel_values, labels): config.num_labels = self.type_sequence_label_size @@ -392,6 +394,6 @@ class SwinModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 1000)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([-0.2952, -0.4777, 0.2025]).to(torch_device) + expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))