#26566 swin2 sr allow in out channels (#26568)

* feat: close #26566, changed model & config files to accept arbitary in and out channels

* updated docstrings

* fix: linter error

* fix: update Copy docstrings

* fix: linter update

* fix: rename num_channels_in to num_channels to prevent breaking changes

* fix: make num_channels_out None per default

* Update src/transformers/models/swin2sr/configuration_swin2sr.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix: update tests to include num_channels_out

* fix:linter

* fix: remove normalization with precomputed rgb values when #input_channels!=#output_channels

---------

Co-authored-by: marvingabler <marvingabler@outlook.de>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Marvin Gabler
2023-10-05 15:20:38 +02:00
committed by GitHub
parent e6d250e4cd
commit 0a3b9d02fe
3 changed files with 17 additions and 7 deletions

View File

@@ -46,6 +46,7 @@ class Swin2SRModelTester:
image_size=32,
patch_size=1,
num_channels=3,
num_channels_out=1,
embed_dim=16,
depths=[1, 2, 1],
num_heads=[2, 2, 4],
@@ -70,6 +71,7 @@ class Swin2SRModelTester:
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_channels_out = num_channels_out
self.embed_dim = embed_dim
self.depths = depths
self.num_heads = num_heads
@@ -110,6 +112,7 @@ class Swin2SRModelTester:
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
num_channels_out=self.num_channels_out,
embed_dim=self.embed_dim,
depths=self.depths,
num_heads=self.num_heads,
@@ -145,7 +148,8 @@ class Swin2SRModelTester:
expected_image_size = self.image_size * self.upscale
self.parent.assertEqual(
result.reconstruction.shape, (self.batch_size, self.num_channels, expected_image_size, expected_image_size)
result.reconstruction.shape,
(self.batch_size, self.num_channels_out, expected_image_size, expected_image_size),
)
def prepare_config_and_inputs_for_common(self):