[Hotfix] Fix Swin model outputs (#15414)

* Fix Swin model outputs

* Rename pooler
This commit is contained in:
NielsRogge
2022-01-31 16:32:14 +01:00
committed by GitHub
parent 38dfb40ae3
commit d4b3e56d64
2 changed files with 40 additions and 21 deletions

View File

@@ -21,11 +21,11 @@ import math
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging from ...utils import logging
from .configuration_swin import SwinConfig from .configuration_swin import SwinConfig
@@ -143,8 +143,8 @@ class SwinPatchEmbeddings(nn.Module):
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values): def forward(self, pixel_values):
pixel_values = self.projection(pixel_values).flatten(2).transpose(1, 2) embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
return pixel_values return embeddings
class SwinPatchMerging(nn.Module): class SwinPatchMerging(nn.Module):
@@ -659,7 +659,7 @@ SWIN_INPUTS_DOCSTRING = r"""
SWIN_START_DOCSTRING, SWIN_START_DOCSTRING,
) )
class SwinModel(SwinPreTrainedModel): class SwinModel(SwinPreTrainedModel):
def __init__(self, config): def __init__(self, config, add_pooling_layer=True):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.num_layers = len(config.depths) self.num_layers = len(config.depths)
@@ -669,7 +669,7 @@ class SwinModel(SwinPreTrainedModel):
self.encoder = SwinEncoder(config, self.embeddings.patch_grid) self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps) self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
self.pool = nn.AdaptiveAvgPool1d(1) self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
@@ -686,7 +686,7 @@ class SwinModel(SwinPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
pixel_values=None, pixel_values=None,
@@ -744,14 +744,18 @@ class SwinModel(SwinPreTrainedModel):
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output) sequence_output = self.layernorm(sequence_output)
sequence_output = self.pool(sequence_output.transpose(1, 2))
sequence_output = torch.flatten(sequence_output, 1) pooled_output = None
if self.pooler is not None:
pooled_output = self.pooler(sequence_output.transpose(1, 2))
pooled_output = torch.flatten(pooled_output, 1)
if not return_dict: if not return_dict:
return (sequence_output,) + encoder_outputs[1:] return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutput( return BaseModelOutputWithPooling(
last_hidden_state=sequence_output, last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states, hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions, attentions=encoder_outputs.attentions,
) )
@@ -829,22 +833,35 @@ class SwinForImageClassification(SwinPreTrainedModel):
return_dict=return_dict, return_dict=return_dict,
) )
sequence_output = outputs[0] pooled_output = outputs[1]
logits = self.classifier(sequence_output) logits = self.classifier(pooled_output)
loss = None loss = None
if labels is not None: if labels is not None:
if self.num_labels == 1: if self.config.problem_type is None:
# We are doing regression if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss() loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1)) if self.num_labels == 1:
else: loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput( return SequenceClassifierOutput(

View File

@@ -137,9 +137,11 @@ class SwinModelTester:
model.eval() model.eval()
result = model(pixel_values) result = model(pixel_values)
num_features = int(config.embed_dim * 2 ** (len(config.depths) - 1)) # since the model we're testing only consists of a single layer, expected_seq_len = number of patches
expected_seq_len = (config.image_size // config.patch_size) ** 2
expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_features)) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_for_image_classification(self, config, pixel_values, labels): def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size config.num_labels = self.type_sequence_label_size
@@ -392,6 +394,6 @@ class SwinModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1000)) expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape) self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([-0.2952, -0.4777, 0.2025]).to(torch_device) expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))