[Hotfix] Fix Swin model outputs (#15414)
* Fix Swin model outputs * Rename pooler
This commit is contained in:
@@ -21,11 +21,11 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||||
from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput
|
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
|
||||||
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_swin import SwinConfig
|
from .configuration_swin import SwinConfig
|
||||||
@@ -143,8 +143,8 @@ class SwinPatchEmbeddings(nn.Module):
|
|||||||
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
def forward(self, pixel_values):
|
||||||
pixel_values = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||||
return pixel_values
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class SwinPatchMerging(nn.Module):
|
class SwinPatchMerging(nn.Module):
|
||||||
@@ -659,7 +659,7 @@ SWIN_INPUTS_DOCSTRING = r"""
|
|||||||
SWIN_START_DOCSTRING,
|
SWIN_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class SwinModel(SwinPreTrainedModel):
|
class SwinModel(SwinPreTrainedModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config, add_pooling_layer=True):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.num_layers = len(config.depths)
|
self.num_layers = len(config.depths)
|
||||||
@@ -669,7 +669,7 @@ class SwinModel(SwinPreTrainedModel):
|
|||||||
self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
|
self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
|
||||||
|
|
||||||
self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
|
self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
|
||||||
self.pool = nn.AdaptiveAvgPool1d(1)
|
self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
@@ -686,7 +686,7 @@ class SwinModel(SwinPreTrainedModel):
|
|||||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values=None,
|
pixel_values=None,
|
||||||
@@ -744,14 +744,18 @@ class SwinModel(SwinPreTrainedModel):
|
|||||||
|
|
||||||
sequence_output = encoder_outputs[0]
|
sequence_output = encoder_outputs[0]
|
||||||
sequence_output = self.layernorm(sequence_output)
|
sequence_output = self.layernorm(sequence_output)
|
||||||
sequence_output = self.pool(sequence_output.transpose(1, 2))
|
|
||||||
sequence_output = torch.flatten(sequence_output, 1)
|
pooled_output = None
|
||||||
|
if self.pooler is not None:
|
||||||
|
pooled_output = self.pooler(sequence_output.transpose(1, 2))
|
||||||
|
pooled_output = torch.flatten(pooled_output, 1)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (sequence_output,) + encoder_outputs[1:]
|
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||||
|
|
||||||
return BaseModelOutput(
|
return BaseModelOutputWithPooling(
|
||||||
last_hidden_state=sequence_output,
|
last_hidden_state=sequence_output,
|
||||||
|
pooler_output=pooled_output,
|
||||||
hidden_states=encoder_outputs.hidden_states,
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
attentions=encoder_outputs.attentions,
|
attentions=encoder_outputs.attentions,
|
||||||
)
|
)
|
||||||
@@ -829,22 +833,35 @@ class SwinForImageClassification(SwinPreTrainedModel):
|
|||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
logits = self.classifier(sequence_output)
|
logits = self.classifier(pooled_output)
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
if self.config.problem_type is None:
|
||||||
if self.num_labels == 1:
|
if self.num_labels == 1:
|
||||||
# We are doing regression
|
self.config.problem_type = "regression"
|
||||||
loss_fct = MSELoss()
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
self.config.problem_type = "single_label_classification"
|
||||||
else:
|
else:
|
||||||
|
self.config.problem_type = "multi_label_classification"
|
||||||
|
|
||||||
|
if self.config.problem_type == "regression":
|
||||||
|
loss_fct = MSELoss()
|
||||||
|
if self.num_labels == 1:
|
||||||
|
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||||
|
else:
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
elif self.config.problem_type == "single_label_classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
|
elif self.config.problem_type == "multi_label_classification":
|
||||||
|
loss_fct = BCEWithLogitsLoss()
|
||||||
|
loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[2:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return SequenceClassifierOutput(
|
return SequenceClassifierOutput(
|
||||||
|
|||||||
@@ -137,9 +137,11 @@ class SwinModelTester:
|
|||||||
model.eval()
|
model.eval()
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
|
|
||||||
num_features = int(config.embed_dim * 2 ** (len(config.depths) - 1))
|
# since the model we're testing only consists of a single layer, expected_seq_len = number of patches
|
||||||
|
expected_seq_len = (config.image_size // config.patch_size) ** 2
|
||||||
|
expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
|
||||||
|
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_features))
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
|
||||||
|
|
||||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||||
config.num_labels = self.type_sequence_label_size
|
config.num_labels = self.type_sequence_label_size
|
||||||
@@ -392,6 +394,6 @@ class SwinModelIntegrationTest(unittest.TestCase):
|
|||||||
expected_shape = torch.Size((1, 1000))
|
expected_shape = torch.Size((1, 1000))
|
||||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|
||||||
expected_slice = torch.tensor([-0.2952, -0.4777, 0.2025]).to(torch_device)
|
expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||||
|
|||||||
Reference in New Issue
Block a user