Attn implementation for composite models (#32238)

* first try

* codestyle

* idefics2 is happy

* [run-slow] llava, llava_next, video_llava, vipllava, llava_next_video, idefics, idefics2, kosmos2, fuyu, blip, blip_2, instructblip, instructblipvideo, paligemma

* fix-copies

* [run-slow] llava, llava_next, video_llava, vipllava, llava_next_video, idefics, idefics2, kosmos2, fuyu, blip, blip_2, instructblip, instructblipvideo

* blip-2 needs to init vision from config

* when was this removed O_o

* minor fix

* tests

* this way?

* tests

* model-agnostic code

* codestyle

* add tests for idefics

* modify general test for VLMs

* no generation test for vlm yet!

* no generation test here also

* wanr in VIT-SDPA if output attn

* add more tests

* user can pass dict as attn impl

* repo consistency

* update

* muicgen

* no prints

* forgot speech enc-dec and clip

* how many composite models we have?

* musicgen meelody is same as mudicgen

* +siglip

* fix tests + add some more

* remove idefics custom overriden code

* make idefics2 automappable

* nits

* skip tests

* doctests

* Update src/transformers/models/idefics2/configuration_idefics2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/clip/test_modeling_clip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/idefics2/test_modeling_idefics2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/idefics2/test_modeling_idefics2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/configuration_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* major update, no need for automap

* clean up

* add FA2 test

* more tests

* style

* skip tests

* why did these started failing now?

* no attributes for FA2 needed

* one tiny test

* address comment about FA2 false warning

* style

* add new models and resolve conflicts

* fix copies

* let it be this way for now, come back tomorrow to review

* some more fixes

* update

* more updates

* update

* fix copies

* style and tests

* another big update

* fix tests

* fix tests

* update

* another update

* fix tests

* fix copies

* fix tests

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-10-22 06:54:44 +02:00
committed by GitHub
parent 32590b5ecb
commit 21d5025826
64 changed files with 1925 additions and 713 deletions

View File

@@ -27,17 +27,24 @@ from transformers.testing_utils import (
require_nltk,
require_sentencepiece,
require_torch,
require_torch_sdpa,
require_vision,
slow,
to_2tuple,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from transformers.utils import (
cached_property,
is_torch_available,
is_vision_available,
)
from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from ..bart.test_modeling_bart import BartModelTester
from ..bert.test_modeling_bert import BertModelTester
from ..deit.test_modeling_deit import DeiTModelTester
from ..donut.test_modeling_donut_swin import DonutSwinModelTester
from ..gpt2.test_modeling_gpt2 import GPT2ModelTester
from ..layoutlmv3.test_modeling_layoutlmv3 import LayoutLMv3ModelTester
from ..swin.test_modeling_swin import SwinModelTester
from ..trocr.test_modeling_trocr import TrOCRStandaloneDecoderModelTester
@@ -53,6 +60,8 @@ if is_torch_available():
BartForCausalLM,
BertLMHeadModel,
DeiTModel,
DonutSwinModel,
GPT2LMHeadModel,
LayoutLMv3Model,
SwinModel,
TrOCRForCausalLM,
@@ -72,6 +81,8 @@ if is_vision_available():
@require_torch
class EncoderDecoderMixin:
supports_sdpa = False
def get_encoder_decoder_model(self, config, decoder_config):
pass
@@ -374,6 +385,69 @@ class EncoderDecoderMixin:
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
if not self.supports_sdpa:
self.skipTest("SDPA is not supported")
inputs_dict = self.prepare_config_and_inputs()
encoder_config, decoder_config = inputs_dict["config"], inputs_dict["decoder_config"]
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
encoder_config=encoder_config, decoder_config=decoder_config
)
model = VisionEncoderDecoderModel(config=config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = VisionEncoderDecoderModel.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
# see https://github.com/huggingface/transformers/pull/32238
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
encoder_attn = "sdpa" if model.encoder._supports_sdpa else "eager"
decoder_attn = "sdpa" if model.decoder._supports_sdpa else "eager"
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_sdpa.encoder.config._attn_implementation == encoder_attn)
self.assertTrue(model_sdpa.decoder.config._attn_implementation == decoder_attn)
# Also test that nothing break if we request SDPA explicitly, when both sub-parts support it.
# If the model supports sdpa (i.e. all of sub-models supports it) we'll dispatch safely
# Otherwise we should raise error that SDPA is not supported, as some of the sub-models doesn't support
if encoder_attn == "sdpa" and decoder_attn == "sdpa":
model_sdpa_explicit = VisionEncoderDecoderModel.from_pretrained(tmpdirname, attn_implementation="sdpa")
model_sdpa_explicit = model_sdpa_explicit.eval().to(torch_device)
self.assertTrue(model_sdpa_explicit.config._attn_implementation == "sdpa")
else:
with self.assertRaises(ValueError):
model_sdpa_explicit = VisionEncoderDecoderModel.from_pretrained(
tmpdirname, attn_implementation="sdpa"
)
model_eager = VisionEncoderDecoderModel.from_pretrained(
tmpdirname,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.encoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.decoder.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
@require_torch
class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
@@ -497,6 +571,8 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
@require_torch
class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
supports_sdpa = True # one submodel support SDPA
def get_pretrained_model_and_inputs(self):
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
"hf-internal-testing/tiny-random-vit", "hf-internal-testing/tiny-bert"
@@ -649,6 +725,8 @@ class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
@require_torch
class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
supports_sdpa = True # one submodel support SDPA
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = ViTModel(config).eval()
decoder_model = TrOCRForCausalLM(decoder_config).eval()
@@ -804,6 +882,240 @@ class LayoutLMv32TrOCR(EncoderDecoderMixin, unittest.TestCase):
pass
@require_torch
class VIT2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
supports_sdpa = True # both submodels support SDPA
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = ViTModel(config).eval()
decoder_model = GPT2LMHeadModel(decoder_config).eval()
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = ViTModelTester(self, batch_size=13)
model_tester_decoder = GPT2ModelTester(self, batch_size=13, hidden_size=32, max_position_embeddings=512)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
config, pixel_values, labels = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_attention_mask,
decoder_head_mask,
decoder_token_type_ids,
mc_token_ids,
sequence_labels,
token_labels,
choice_labels,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"pixel_values": pixel_values,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"decoder_head_mask": decoder_head_mask,
"labels": decoder_input_ids,
}
def check_encoder_decoder_model_output_attentions(
self,
config,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
pixel_values,
labels=None,
**kwargs,
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
**kwargs,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
seq_len = (encoder_model.config.image_size // encoder_model.config.patch_size) ** 2 + 1
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
num_decoder_layers = (
decoder_config.num_decoder_layers
if hasattr(decoder_config, "num_decoder_layers")
else decoder_config.num_hidden_layers
)
self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertEqual(
decoder_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
)
cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)
cross_attention_input_seq_len = decoder_input_ids.shape[-1]
self.assertEqual(
cross_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, cross_attention_input_seq_len, seq_len), # 4 6 16
)
def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_values=None, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
# Generate until max length
if hasattr(enc_dec_model.config, "eos_token_id"):
enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
enc_dec_model.config.decoder.eos_token_id = None
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
enc_dec_model.generation_config.eos_token_id = None
enc_dec_model.to(torch_device)
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
@unittest.skip(reason="VIT2GPT2 also has an integration test for testinf save-load")
def test_real_model_save_load_from_pretrained(self):
pass
@require_torch
class Donut2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
supports_sdpa = True # one submodel (GPT2) support SDPA
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = DonutSwinModel(config).eval()
decoder_model = GPT2LMHeadModel(decoder_config).eval()
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = DonutSwinModelTester(self, batch_size=13)
model_tester_decoder = GPT2ModelTester(self, batch_size=13, hidden_size=32, max_position_embeddings=512)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
config, pixel_values, labels = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_attention_mask,
decoder_head_mask,
decoder_token_type_ids,
mc_token_ids,
sequence_labels,
token_labels,
choice_labels,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"pixel_values": pixel_values,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"decoder_head_mask": decoder_head_mask,
"labels": decoder_input_ids,
}
def check_encoder_decoder_model_output_attentions(
self,
config,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
pixel_values,
labels=None,
**kwargs,
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
**kwargs,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
seq_len = encoder_model.config.image_size // encoder_model.config.patch_size
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
num_decoder_layers = (
decoder_config.num_decoder_layers
if hasattr(decoder_config, "num_decoder_layers")
else decoder_config.num_hidden_layers
)
self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertEqual(
decoder_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
)
cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)
cross_attention_input_seq_len = decoder_input_ids.shape[-1]
self.assertEqual(
cross_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, cross_attention_input_seq_len, seq_len), # 4 6 16
)
def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_values=None, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
# Generate until max length
if hasattr(enc_dec_model.config, "eos_token_id"):
enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
enc_dec_model.config.decoder.eos_token_id = None
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
enc_dec_model.generation_config.eos_token_id = None
enc_dec_model.to(torch_device)
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
@unittest.skip(reason="Donut has an Integration test for that")
def test_real_model_save_load_from_pretrained(self):
pass
@require_vision
@require_torch
class TrOCRModelIntegrationTest(unittest.TestCase):