Improve Swin for VisionEncoderDecoder (#16070)

* Add Swin2Bart test

* Fix swin tests

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-03-15 09:59:48 +01:00
committed by GitHub
parent 0a057201a9
commit a7aca42fc4
3 changed files with 94 additions and 6 deletions

View File

@@ -22,8 +22,10 @@ from datasets import load_dataset
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from ..bart.test_modeling_bart import BartModelTester
from ..bert.test_modeling_bert import BertModelTester
from ..deit.test_modeling_deit import DeiTModelTester
from ..swin.test_modeling_swin import SwinModelTester
from ..test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from ..trocr.test_modeling_trocr import TrOCRStandaloneDecoderModelTester
from ..vit.test_modeling_vit import ViTModelTester
@@ -35,8 +37,10 @@ if is_torch_available():
from transformers import (
AutoTokenizer,
BartForCausalLM,
BertLMHeadModel,
DeiTModel,
SwinModel,
TrOCRForCausalLM,
VisionEncoderDecoderConfig,
VisionEncoderDecoderModel,
@@ -514,6 +518,90 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
}
@require_torch
class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = SwinModel(config).eval()
decoder_model = BartForCausalLM(decoder_config).eval()
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = SwinModelTester(self, batch_size=13, embed_dim=32)
model_tester_decoder = BartModelTester(self, batch_size=13, hidden_size=32, max_position_embeddings=512)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
config, pixel_values, _ = encoder_config_and_inputs
decoder_config, decoder_inputs_dict = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"pixel_values": pixel_values,
"decoder_config": decoder_config,
**decoder_inputs_dict,
}
def check_encoder_decoder_model_output_attentions(
self,
config,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
labels=None,
pixel_values=None,
**kwargs
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
# in Swin, the seq_len equals:
seq_len = encoder_model.config.window_size**2
self.assertEqual(encoder_attentions[0].shape[-3:], (config.num_attention_heads[0], seq_len, seq_len))
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
num_decoder_layers = (
decoder_config.num_decoder_layers
if hasattr(decoder_config, "num_decoder_layers")
else decoder_config.num_hidden_layers
)
self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertEqual(
decoder_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
)
cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)
encoder_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
cross_attention_input_seq_len = decoder_input_ids.shape[-1]
self.assertEqual(
cross_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, cross_attention_input_seq_len, encoder_seq_len),
)
# there are no published pretrained BART-causal checkpoints for now
def test_real_model_save_load_from_pretrained(self):
pass
@require_torch
class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):