From 7e86cb6c6f126b5d283d61b48e1879023c11086a Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 25 Jun 2024 09:49:55 +0500 Subject: [PATCH] Siglip: add `_no_split_modules` (#31566) * device-map siglip * move split modules to SiglipPreTrainedModel --- src/transformers/models/siglip/modeling_siglip.py | 15 +++++++++++---- tests/models/siglip/test_modeling_siglip.py | 12 ++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index d605f49261..4c534bbce6 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -496,6 +496,13 @@ class SiglipPreTrainedModel(PreTrainedModel): config_class = SiglipConfig base_model_prefix = "siglip" supports_gradient_checkpointing = True + _no_split_modules = [ + "SiglipTextEmbeddings", + "SiglipEncoderLayer", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + "SiglipMultiheadAttentionPoolingHead", + ] def _init_weights(self, module): """Initialize the weights""" @@ -816,8 +823,6 @@ class SiglipTextTransformer(nn.Module): class SiglipTextModel(SiglipPreTrainedModel): config_class = SiglipTextConfig - _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"] - def __init__(self, config: SiglipTextConfig): super().__init__(config) self.text_model = SiglipTextTransformer(config) @@ -959,7 +964,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): class SiglipVisionModel(SiglipPreTrainedModel): config_class = SiglipVisionConfig main_input_name = "pixel_values" - _no_split_modules = ["SiglipVisionEmbeddings", "SiglipEncoderLayer", "SiglipMultiheadAttentionPoolingHead"] def __init__(self, config: SiglipVisionConfig): super().__init__(config) @@ -1222,7 +1226,10 @@ class SiglipModel(SiglipPreTrainedModel): text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) # cosine similarity as logits - logits_per_text = 
torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias + logits_per_text = ( + torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * self.logit_scale.exp() + + self.logit_bias + ) logits_per_image = logits_per_text.t() loss = None diff --git a/tests/models/siglip/test_modeling_siglip.py b/tests/models/siglip/test_modeling_siglip.py index 12ac11251d..af5d0bf2bc 100644 --- a/tests/models/siglip/test_modeling_siglip.py +++ b/tests/models/siglip/test_modeling_siglip.py @@ -443,6 +443,12 @@ class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): test_pruning = False test_resize_embeddings = False test_attention_outputs = False + # MP works but offload doesn't work when the MultiheadAttention is offloaded + # TODO: One potential solution would be to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"] + # in the dispatch_model function + test_cpu_offload = False + test_disk_offload_safetensors = False + test_disk_offload_bin = False # Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.setUp with CLIP->Siglip def setUp(self): @@ -618,6 +624,12 @@ class SiglipForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixi test_pruning = False test_resize_embeddings = False test_attention_outputs = False + # MP works but offload doesn't work when the MultiheadAttention is offloaded + # TODO: One potential solution would be to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"] + # in the dispatch_model function + test_cpu_offload = False + test_disk_offload_safetensors = False + test_disk_offload_bin = False def setUp(self): self.model_tester = SiglipForImageClassificationModelTester(self)