Attn implementation for composite models (#32238)

* first try

* codestyle

* idefics2 is happy

* [run-slow] llava, llava_next, video_llava, vipllava, llava_next_video, idefics, idefics2, kosmos2, fuyu, blip, blip_2, instructblip, instructblipvideo, paligemma

* fix-copies

* [run-slow] llava, llava_next, video_llava, vipllava, llava_next_video, idefics, idefics2, kosmos2, fuyu, blip, blip_2, instructblip, instructblipvideo

* blip-2 needs to init vision from config

* when was this removed O_o

* minor fix

* tests

* this way?

* tests

* model-agnostic code

* codestyle

* add tests for idefics

* modify general test for VLMs

* no generation test for vlm yet!

* no generation test here also

* warn in VIT-SDPA if output attn

* add more tests

* user can pass dict as attn impl

* repo consistency

* update

* musicgen

* no prints

* forgot speech enc-dec and clip

* how many composite models we have?

* musicgen melody is same as musicgen

* +siglip

* fix tests + add some more

* remove idefics custom overridden code

* make idefics2 automappable

* nits

* skip tests

* doctests

* Update src/transformers/models/idefics2/configuration_idefics2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/clip/test_modeling_clip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/idefics2/test_modeling_idefics2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/idefics2/test_modeling_idefics2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/configuration_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* major update, no need for automap

* clean up

* add FA2 test

* more tests

* style

* skip tests

* why did these start failing now?

* no attributes for FA2 needed

* one tiny test

* address comment about FA2 false warning

* style

* add new models and resolve conflicts

* fix copies

* let it be this way for now, come back tomorrow to review

* some more fixes

* update

* more updates

* update

* fix copies

* style and tests

* another big update

* fix tests

* fix tests

* update

* another update

* fix tests

* fix copies

* fix tests

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-10-22 06:54:44 +02:00
committed by GitHub
parent 32590b5ecb
commit 21d5025826
64 changed files with 1925 additions and 713 deletions

View File

@@ -207,6 +207,7 @@ class ModelTesterMixin:
test_model_parallel = False
is_encoder_decoder = False
has_attentions = True
_is_composite = False
model_split_percents = [0.5, 0.7, 0.9]
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -3006,6 +3007,7 @@ class ModelTesterMixin:
*get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
]:
continue
model = model_class(config)
model.to(torch_device)
model.eval()
@@ -3950,6 +3952,147 @@ class ModelTesterMixin:
self.assertTrue(torch.allclose(out, out_fa))
def test_attn_implementation_composite_models(self):
    """
    Tests if composite models can receive a dict object as attn_implementation, where each key should be
    one of the sub-configs from the model's config.

    Every sub-config found on the config is requested with "eager". The test then checks that each
    sub-config of the instantiated model reports "eager" and that no submodule class name contains an
    SDPA/FA2 marker.

    Raises:
        ValueError: if a submodule whose class name looks like an SDPA/FA2 attention layer is found.
    """
    if not self.has_attentions:
        self.skipTest(reason="Model architecture does not support attentions")

    # The composite flag does not depend on the model class, so decide once before preparing any config.
    if not self._is_composite:
        self.skipTest("Model is not a composite model.")

    non_eager_markers = ("SdpaAttention", "SdpaSelfAttention", "FlashAttention")

    for model_class in self.all_model_classes:
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        # set eager as it will be the one supported in all models
        # we just need to test if passing 'attn_implementation' as a dict fails or not
        # NOTE: iterating the config is expected to yield its attribute names
        config._attn_implementation = {
            key: "eager" for key in config if isinstance(getattr(config, key), PretrainedConfig)
        }
        model = model_class(config)

        for key in model.config:
            sub_config = getattr(model.config, key)
            if isinstance(sub_config, PretrainedConfig):
                self.assertEqual(sub_config._attn_implementation, "eager")

        for _, submodule in model.named_modules():
            class_name = submodule.__class__.__name__
            if any(marker in class_name for marker in non_eager_markers):
                raise ValueError("The eager model should not have SDPA/FA2 attention layers")
@require_torch_sdpa
def test_sdpa_can_dispatch_non_composite_models(self):
    """
    Tests if non-composite models dispatch correctly on SDPA/eager when requested so when loading the model.
    This tests only by looking at layer names, as usually SDPA layers are called "SDPAAttention".

    The model is saved once and reloaded twice: without an explicit attention implementation (expected to
    report "sdpa") and with `attn_implementation="eager"` (expected to report "eager").

    Raises:
        ValueError: if the eager model contains an SDPA-named layer, or if the SDPA model contains none
            (models whose `model_type` is "falcon" are exempt from the second check).
    """

    def _is_sdpa_layer(module):
        # Detection is purely name based: the class name has to contain one of the SDPA markers.
        layer_class = module.__class__.__name__
        return "SdpaAttention" in layer_class or "SdpaSelfAttention" in layer_class

    if not self.has_attentions:
        self.skipTest(reason="Model architecture does not support attentions")

    if not self.all_model_classes[0]._supports_sdpa or self._is_composite:
        self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")

    for model_class in self.all_model_classes:
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

            # Default load: no attention implementation is requested explicitly.
            model_sdpa = model_class.from_pretrained(tmpdirname)
            model_sdpa = model_sdpa.eval().to(torch_device)
            self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

            # Explicit eager load.
            model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
            model_eager = model_eager.eval().to(torch_device)
            self.assertTrue(model_eager.config._attn_implementation == "eager")

            for _, eager_module in model_eager.named_modules():
                if _is_sdpa_layer(eager_module):
                    raise ValueError("The eager model should not have SDPA attention layers")

            has_sdpa = any(_is_sdpa_layer(sdpa_module) for _, sdpa_module in model_sdpa.named_modules())
            if not has_sdpa and model_sdpa.config.model_type != "falcon":
                raise ValueError("The SDPA model should have SDPA attention layers")
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
    """
    Tests if composite models dispatch correctly on SDPA/eager when requested so when loading the model.
    This tests only by looking at layer names, as usually SDPA layers are called "SDPAAttention".
    In contrast to the non-composite test, this one inspects the `config._attn_implementation` of each
    sub-model after loading, because the requested attn implementation is replicated on each sub-config.
    See https://github.com/huggingface/transformers/pull/32238 for more info

    The test tries to cover most general cases of composite models, VLMs with vision and text configs. Any model
    that has a different set of sub-configs has to overwrite this test.

    Raises:
        ValueError: if the eager model contains an SDPA-named layer, or if at least one sub-model is expected
            to use SDPA but the default-loaded model contains no SDPA-named layer.
    """

    def _has_sdpa_layers(module):
        # Detection is purely name based: a submodule class name has to contain one of the SDPA markers.
        return any(
            "SdpaAttention" in submodule.__class__.__name__ or "SdpaSelfAttention" in submodule.__class__.__name__
            for _, submodule in module.named_modules()
        )

    if not self.has_attentions:
        self.skipTest(reason="Model architecture does not support attentions")

    if not self._is_composite:
        # The skip reason is compositeness, not SDPA support.
        self.skipTest(f"{self.all_model_classes[0].__name__} is not a composite model")

    # Tuples (not sets) so that the attribute picked is the same on every run when several names exist.
    vision_model_names = ("visual", "image_tower", "vision_tower", "vision_model")
    language_model_names = ("language_model", "model", "text_model")

    for model_class in self.all_model_classes:
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            model_sdpa = model_class.from_pretrained(tmpdirname)
            model_sdpa = model_sdpa.eval().to(torch_device)

            # Indexing with [0] fails on purpose for models exposing none of the known attribute names:
            # such models are expected to override this test.
            vision_model_name = [name for name in vision_model_names if hasattr(model_sdpa, name)][0]
            language_model_name = [name for name in language_model_names if hasattr(model_sdpa, name)][0]

            # Inspect the *reloaded* model: it is the one whose dispatching is under test.
            vision_model_sdpa = getattr(model_sdpa, vision_model_name)
            language_model_sdpa = getattr(model_sdpa, language_model_name)
            text_attn = "sdpa" if language_model_sdpa._supports_sdpa else "eager"
            vision_attn = "sdpa" if vision_model_sdpa._supports_sdpa else "eager"

            # No implementation was requested explicitly, so each sub-model is expected to end up on SDPA
            # if it supports it and on eager otherwise (presence of SDPA layers is checked below).
            self.assertEqual(language_model_sdpa.config._attn_implementation, text_attn)
            self.assertEqual(vision_model_sdpa.config._attn_implementation, vision_attn)

            model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
            model_eager = model_eager.eval().to(torch_device)
            self.assertEqual(getattr(model_eager, language_model_name).config._attn_implementation, "eager")
            self.assertEqual(getattr(model_eager, vision_model_name).config._attn_implementation, "eager")

            if _has_sdpa_layers(model_eager):
                raise ValueError("The eager model should not have SDPA attention layers")

            expects_sdpa = "sdpa" in (text_attn, vision_attn)
            if expects_sdpa and not _has_sdpa_layers(model_sdpa):
                raise ValueError("The SDPA model should have SDPA attention layers")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@@ -4012,7 +4155,6 @@ class ModelTesterMixin:
# This means that the class needs to be instantiated much later, after `use_mask` is set, which means a significant refactor of the code.
# However masking there is not done at any layers that matters (i.e self-attention), therefore we can safely deactivate it.
deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -4020,8 +4162,6 @@ class ModelTesterMixin:
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
@@ -4029,22 +4169,6 @@ class ModelTesterMixin:
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
@@ -4279,7 +4403,7 @@ class ModelTesterMixin:
self.skipTest(
"PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input"
)
if config.model_type in ["idefics"]:
if config.model_type in ["idefics", "idefics2", "idefics3"]:
self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
model = model_class(config)
@@ -4382,8 +4506,6 @@ class ModelTesterMixin:
low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
@@ -4391,22 +4513,6 @@ class ModelTesterMixin:
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
@@ -4429,6 +4535,8 @@ class ModelTesterMixin:
self.skipTest(f"No generative model classes for {self.__class__.__name__}")
for model_class in self.all_generative_model_classes:
if model_class._supports_sdpa:
self.skipTest(reason="Model architecture does not support attentions")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.model_type not in WINDOW_ATTENTION_MODELS:
@@ -4531,6 +4639,62 @@ class ModelTesterMixin:
use_cache=True,
)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
def test_flash_attn_2_can_dispatch_composite_models(self):
    """
    Tests if composite models can dispatch on FA2 if the sub-models support FA2.
    The test is needed as composite models are handled differently and cannot be checked with the
    non-composite tests. If any of the sub-models does not support FA2, loading with
    `attn_implementation="flash_attention_2"` is expected to raise a `ValueError`. Otherwise every sub-config
    must report "flash_attention_2" and at least one FA2-named layer must be present, where "sub-models" are
    specific backbone models (LM/vision/audio/etc)

    Raises:
        ValueError: if all sub-models support FA2 but the loaded model contains no FA2-named layer.
    """
    if not self.has_attentions:
        self.skipTest(reason="Model architecture does not support attentions")

    if not is_torch_fp16_available_on_device(torch_device):
        self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")

    # The composite flag does not depend on the model class: check it before instantiating any model.
    if not self._is_composite:
        self.skipTest("This model is not a composte model!")

    torch_dtype = torch.float16
    for model_class in self.all_model_classes:
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)

            # name == "" is the top-level model itself; only nested PreTrainedModel instances are sub-models.
            supports_fa2_all_modules = all(
                module._supports_flash_attn_2
                for name, module in model.named_modules()
                if isinstance(module, PreTrainedModel) and name != ""
            )

            if not supports_fa2_all_modules:
                # At least one sub-model cannot use FA2: the load itself has to fail.
                with self.assertRaises(ValueError):
                    model_class.from_pretrained(
                        tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2"
                    )
                continue

            model_fa2 = model_class.from_pretrained(
                tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2"
            )
            for key in model_fa2.config:
                sub_config = getattr(model_fa2.config, key)
                if isinstance(sub_config, PretrainedConfig):
                    self.assertEqual(sub_config._attn_implementation, "flash_attention_2")

            # Detection is purely name based: some submodule class name has to contain "FlashAttention".
            has_fa2 = any(
                "FlashAttention" in submodule.__class__.__name__ for _, submodule in model_fa2.named_modules()
            )
            if not has_fa2:
                raise ValueError("The FA2 model should have FA2 layers")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -4679,7 +4843,7 @@ class ModelTesterMixin:
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.pad_token_id
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
model = (
model_class.from_pretrained(