Attn implementation for composite models (#32238)
* first try * codestyle * idefics2 is happy * [run-slow] llava, llava_next, video_llava, vipllava, llava_next_video, idefics, idefics2, kosmos2, fuyu, blip, blip_2, instructblip, instructblipvideo, paligemma * fix-copies * [run-slow] llava, llava_next, video_llava, vipllava, llava_next_video, idefics, idefics2, kosmos2, fuyu, blip, blip_2, instructblip, instructblipvideo * blip-2 needs to init vision from config * when was this removed O_o * minor fix * tests * this way? * tests * model-agnostic code * codestyle * add tests for idefics * modify general test for VLMs * no generation test for vlm yet! * no generation test here also * warn in VIT-SDPA if output attn * add more tests * user can pass dict as attn impl * repo consistency * update * musicgen * no prints * forgot speech enc-dec and clip * how many composite models do we have? * musicgen melody is same as musicgen * +siglip * fix tests + add some more * remove idefics custom overridden code * make idefics2 automappable * nits * skip tests * doctests * Update src/transformers/models/idefics2/configuration_idefics2.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/clip/test_modeling_clip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/idefics2/test_modeling_idefics2.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/idefics2/test_modeling_idefics2.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/configuration_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * major update, no need for automap * clean up * add FA2 test * more tests * style * skip tests * why did these start failing now?
* no attributes for FA2 needed * one tiny test * address comment about FA2 false warning * style * add new models and resolve conflicts * fix copies * let it be this way for now, come back tomorrow to review * some more fixes * update * more updates * update * fix copies * style and tests * another big update * fix tests * fix tests * update * another update * fix tests * fix copies * fix tests --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
32590b5ecb
commit
21d5025826
@@ -27,17 +27,24 @@ from transformers.testing_utils import (
|
||||
require_nltk,
|
||||
require_sentencepiece,
|
||||
require_torch,
|
||||
require_torch_sdpa,
|
||||
require_vision,
|
||||
slow,
|
||||
to_2tuple,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
from transformers.utils import (
|
||||
cached_property,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
|
||||
from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from ..bart.test_modeling_bart import BartModelTester
|
||||
from ..bert.test_modeling_bert import BertModelTester
|
||||
from ..deit.test_modeling_deit import DeiTModelTester
|
||||
from ..donut.test_modeling_donut_swin import DonutSwinModelTester
|
||||
from ..gpt2.test_modeling_gpt2 import GPT2ModelTester
|
||||
from ..layoutlmv3.test_modeling_layoutlmv3 import LayoutLMv3ModelTester
|
||||
from ..swin.test_modeling_swin import SwinModelTester
|
||||
from ..trocr.test_modeling_trocr import TrOCRStandaloneDecoderModelTester
|
||||
@@ -53,6 +60,8 @@ if is_torch_available():
|
||||
BartForCausalLM,
|
||||
BertLMHeadModel,
|
||||
DeiTModel,
|
||||
DonutSwinModel,
|
||||
GPT2LMHeadModel,
|
||||
LayoutLMv3Model,
|
||||
SwinModel,
|
||||
TrOCRForCausalLM,
|
||||
@@ -72,6 +81,8 @@ if is_vision_available():
|
||||
|
||||
@require_torch
|
||||
class EncoderDecoderMixin:
|
||||
supports_sdpa = False
|
||||
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
pass
|
||||
|
||||
@@ -374,6 +385,69 @@ class EncoderDecoderMixin:
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
    # Checks how the attention implementation is dispatched to the encoder and
    # decoder of a composite `VisionEncoderDecoderModel` after a save/load
    # round-trip:
    #   * default load: the composite config reports "sdpa" and each sub-model
    #     reports "sdpa" or "eager" depending on its own `_supports_sdpa` flag;
    #   * explicit `attn_implementation="sdpa"`: succeeds only when both
    #     sub-models support SDPA, otherwise a ValueError is expected;
    #   * explicit `attn_implementation="eager"`: every config reports "eager"
    #     and no SDPA attention layer is instantiated.
    # Skipped when the concrete test class leaves `supports_sdpa` falsy.
    if not self.supports_sdpa:
        self.skipTest("SDPA is not supported")

    def has_sdpa_layers(module):
        # True as soon as one sub-module's class name marks it as an SDPA attention layer.
        return any(
            "SdpaAttention" in submodule.__class__.__name__
            or "SdpaSelfAttention" in submodule.__class__.__name__
            for _, submodule in module.named_modules()
        )

    inputs_dict = self.prepare_config_and_inputs()
    encoder_config, decoder_config = inputs_dict["config"], inputs_dict["decoder_config"]
    config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
        encoder_config=encoder_config, decoder_config=decoder_config
    )
    model = VisionEncoderDecoderModel(config=config)

    with tempfile.TemporaryDirectory() as tmpdirname:
        model.save_pretrained(tmpdirname)
        model_sdpa = VisionEncoderDecoderModel.from_pretrained(tmpdirname)
        model_sdpa = model_sdpa.eval().to(torch_device)

        # see https://github.com/huggingface/transformers/pull/32238
        # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
        encoder_attn = "sdpa" if model.encoder._supports_sdpa else "eager"
        decoder_attn = "sdpa" if model.decoder._supports_sdpa else "eager"
        # assertEqual (rather than assertTrue(a == b)) reports both values on failure.
        self.assertEqual(model_sdpa.config._attn_implementation, "sdpa")
        self.assertEqual(model_sdpa.encoder.config._attn_implementation, encoder_attn)
        self.assertEqual(model_sdpa.decoder.config._attn_implementation, decoder_attn)

        # Also test that nothing breaks if we request SDPA explicitly, when both sub-parts support it.
        # If the model supports SDPA (i.e. all of its sub-models support it) we dispatch safely.
        # Otherwise an error must be raised, because some of the sub-models do not support SDPA.
        if encoder_attn == "sdpa" and decoder_attn == "sdpa":
            model_sdpa_explicit = VisionEncoderDecoderModel.from_pretrained(tmpdirname, attn_implementation="sdpa")
            model_sdpa_explicit = model_sdpa_explicit.eval().to(torch_device)

            self.assertEqual(model_sdpa_explicit.config._attn_implementation, "sdpa")
        else:
            # The result is never used: only the raised error matters here.
            with self.assertRaises(ValueError):
                VisionEncoderDecoderModel.from_pretrained(tmpdirname, attn_implementation="sdpa")

        model_eager = VisionEncoderDecoderModel.from_pretrained(
            tmpdirname,
            attn_implementation="eager",
        )
        model_eager = model_eager.eval().to(torch_device)

        self.assertEqual(model_eager.config._attn_implementation, "eager")
        self.assertEqual(model_eager.encoder.config._attn_implementation, "eager")
        self.assertEqual(model_eager.decoder.config._attn_implementation, "eager")

        # Reported as test failures (not errors) so the runner classifies them correctly.
        self.assertFalse(has_sdpa_layers(model_eager), "The eager model should not have SDPA attention layers")
        self.assertTrue(has_sdpa_layers(model_sdpa), "The SDPA model should have SDPA attention layers")
|
||||
|
||||
|
||||
@require_torch
|
||||
class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
@@ -497,6 +571,8 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
supports_sdpa = True # one submodel support SDPA
|
||||
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
"hf-internal-testing/tiny-random-vit", "hf-internal-testing/tiny-bert"
|
||||
@@ -649,6 +725,8 @@ class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
|
||||
supports_sdpa = True # one submodel support SDPA
|
||||
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = ViTModel(config).eval()
|
||||
decoder_model = TrOCRForCausalLM(decoder_config).eval()
|
||||
@@ -804,6 +882,240 @@ class LayoutLMv32TrOCR(EncoderDecoderMixin, unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
class VIT2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
    """Runs the shared `EncoderDecoderMixin` checks on a ViT encoder paired with a GPT2 decoder."""

    supports_sdpa = True  # both submodels support SDPA

    def get_encoder_decoder_model(self, config, decoder_config):
        """Return an `(encoder, decoder)` pair built from the given configs, both in eval mode."""
        vit_encoder = ViTModel(config).eval()
        gpt2_decoder = GPT2LMHeadModel(decoder_config).eval()
        return vit_encoder, gpt2_decoder

    def prepare_config_and_inputs(self):
        """Collect encoder/decoder configs and inputs from the ViT and GPT2 model testers."""
        encoder_tester = ViTModelTester(self, batch_size=13)
        decoder_tester = GPT2ModelTester(self, batch_size=13, hidden_size=32, max_position_embeddings=512)

        config, pixel_values, _labels = encoder_tester.prepare_config_and_inputs()
        (
            decoder_config,
            decoder_input_ids,
            decoder_attention_mask,
            decoder_head_mask,
            _decoder_token_type_ids,
            _mc_token_ids,
            _sequence_labels,
            _token_labels,
            _choice_labels,
        ) = decoder_tester.prepare_config_and_inputs()

        # make sure that cross attention layers are added
        decoder_config.add_cross_attention = True
        # disable cache for now
        decoder_config.use_cache = False

        # The decoder input ids double as the labels.
        return dict(
            config=config,
            pixel_values=pixel_values,
            decoder_config=decoder_config,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_head_mask=decoder_head_mask,
            labels=decoder_input_ids,
        )

    def check_encoder_decoder_model_output_attentions(
        self,
        config,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        pixel_values,
        labels=None,
        **kwargs,
    ):
        """Run a forward pass with `output_attentions=True` and check the counts and
        trailing shapes of the encoder, decoder and cross attentions."""
        # make the decoder inputs a different shape from the encoder inputs to harden the test
        decoder_input_ids = decoder_input_ids[:, :-1]
        decoder_attention_mask = decoder_attention_mask[:, :-1]

        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
        enc_dec_model.to(torch_device)
        outputs = enc_dec_model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            output_attentions=True,
            **kwargs,
        )

        self.assertEqual(len(outputs["encoder_attentions"]), config.num_hidden_layers)

        # Expected encoder sequence length: squared patches-per-side, plus one.
        patches_per_side = encoder_model.config.image_size // encoder_model.config.patch_size
        seq_len = patches_per_side**2 + 1

        decoder_attentions = outputs["decoder_attentions"]
        if hasattr(decoder_config, "num_decoder_layers"):
            num_decoder_layers = decoder_config.num_decoder_layers
        else:
            num_decoder_layers = decoder_config.num_hidden_layers
        self.assertEqual(len(decoder_attentions), num_decoder_layers)

        target_len = decoder_input_ids.shape[-1]
        self.assertEqual(
            decoder_attentions[0].shape[-3:],
            (decoder_config.num_attention_heads, target_len, target_len),
        )

        cross_attentions = outputs["cross_attentions"]
        self.assertEqual(len(cross_attentions), num_decoder_layers)
        self.assertEqual(
            cross_attentions[0].shape[-3:],
            (decoder_config.num_attention_heads, target_len, seq_len),
        )

    def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_values=None, **kwargs):
        """Generate with every visible EOS token id cleared and check the output shape."""
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)

        # Generate until max length
        model_config = enc_dec_model.config
        if hasattr(model_config, "eos_token_id"):
            model_config.eos_token_id = None
        if hasattr(model_config, "decoder") and hasattr(model_config.decoder, "eos_token_id"):
            model_config.decoder.eos_token_id = None
        if hasattr(enc_dec_model.generation_config, "eos_token_id"):
            enc_dec_model.generation_config.eos_token_id = None
        enc_dec_model.to(torch_device)

        generated_output = enc_dec_model.generate(
            pixel_values=pixel_values,
            decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
            **kwargs,
        )
        expected_shape = (pixel_values.shape[0], decoder_config.max_length)
        self.assertEqual(generated_output.shape, expected_shape)

    @unittest.skip(reason="VIT2GPT2 also has an integration test for testinf save-load")
    def test_real_model_save_load_from_pretrained(self):
        pass
|
||||
|
||||
|
||||
@require_torch
class Donut2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
    """Runs the shared `EncoderDecoderMixin` checks on a DonutSwin encoder paired with a GPT2 decoder."""

    supports_sdpa = True  # one submodel (GPT2) support SDPA

    def get_encoder_decoder_model(self, config, decoder_config):
        """Return an `(encoder, decoder)` pair built from the given configs, both in eval mode."""
        swin_encoder = DonutSwinModel(config).eval()
        gpt2_decoder = GPT2LMHeadModel(decoder_config).eval()
        return swin_encoder, gpt2_decoder

    def prepare_config_and_inputs(self):
        """Collect encoder/decoder configs and inputs from the DonutSwin and GPT2 model testers."""
        encoder_tester = DonutSwinModelTester(self, batch_size=13)
        decoder_tester = GPT2ModelTester(self, batch_size=13, hidden_size=32, max_position_embeddings=512)

        config, pixel_values, _labels = encoder_tester.prepare_config_and_inputs()
        (
            decoder_config,
            decoder_input_ids,
            decoder_attention_mask,
            decoder_head_mask,
            _decoder_token_type_ids,
            _mc_token_ids,
            _sequence_labels,
            _token_labels,
            _choice_labels,
        ) = decoder_tester.prepare_config_and_inputs()

        # make sure that cross attention layers are added
        decoder_config.add_cross_attention = True
        # disable cache for now
        decoder_config.use_cache = False

        # The decoder input ids double as the labels.
        return dict(
            config=config,
            pixel_values=pixel_values,
            decoder_config=decoder_config,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_head_mask=decoder_head_mask,
            labels=decoder_input_ids,
        )

    def check_encoder_decoder_model_output_attentions(
        self,
        config,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        pixel_values,
        labels=None,
        **kwargs,
    ):
        """Run a forward pass with `output_attentions=True` and check the counts and
        trailing shapes of the encoder, decoder and cross attentions."""
        # make the decoder inputs a different shape from the encoder inputs to harden the test
        decoder_input_ids = decoder_input_ids[:, :-1]
        decoder_attention_mask = decoder_attention_mask[:, :-1]

        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
        enc_dec_model.to(torch_device)
        outputs = enc_dec_model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            output_attentions=True,
            **kwargs,
        )

        self.assertEqual(len(outputs["encoder_attentions"]), config.num_hidden_layers)

        # NOTE(review): the expected encoder sequence length is taken directly as
        # image_size // patch_size (no squaring) - presumably this matches the token
        # count the DonutSwin tester config produces; confirm against that tester.
        seq_len = encoder_model.config.image_size // encoder_model.config.patch_size

        decoder_attentions = outputs["decoder_attentions"]
        if hasattr(decoder_config, "num_decoder_layers"):
            num_decoder_layers = decoder_config.num_decoder_layers
        else:
            num_decoder_layers = decoder_config.num_hidden_layers
        self.assertEqual(len(decoder_attentions), num_decoder_layers)

        target_len = decoder_input_ids.shape[-1]
        self.assertEqual(
            decoder_attentions[0].shape[-3:],
            (decoder_config.num_attention_heads, target_len, target_len),
        )

        cross_attentions = outputs["cross_attentions"]
        self.assertEqual(len(cross_attentions), num_decoder_layers)
        self.assertEqual(
            cross_attentions[0].shape[-3:],
            (decoder_config.num_attention_heads, target_len, seq_len),
        )

    def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_values=None, **kwargs):
        """Generate with every visible EOS token id cleared and check the output shape."""
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)

        # Generate until max length
        model_config = enc_dec_model.config
        if hasattr(model_config, "eos_token_id"):
            model_config.eos_token_id = None
        if hasattr(model_config, "decoder") and hasattr(model_config.decoder, "eos_token_id"):
            model_config.decoder.eos_token_id = None
        if hasattr(enc_dec_model.generation_config, "eos_token_id"):
            enc_dec_model.generation_config.eos_token_id = None
        enc_dec_model.to(torch_device)

        generated_output = enc_dec_model.generate(
            pixel_values=pixel_values,
            decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
            **kwargs,
        )
        expected_shape = (pixel_values.shape[0], decoder_config.max_length)
        self.assertEqual(generated_output.shape, expected_shape)

    @unittest.skip(reason="Donut has an Integration test for that")
    def test_real_model_save_load_from_pretrained(self):
        pass
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
class TrOCRModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user