From 9ecee14378271e96b6686c6e25996db2abf9994a Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Tue, 20 May 2025 01:37:54 +0800 Subject: [PATCH] [doc] fix bugs in `how_to_hack_models.md` (#38198) fix several bugs --- docs/source/en/how_to_hack_models.md | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/docs/source/en/how_to_hack_models.md b/docs/source/en/how_to_hack_models.md index cc229f6b01..0a3c38a3e1 100644 --- a/docs/source/en/how_to_hack_models.md +++ b/docs/source/en/how_to_hack_models.md @@ -90,11 +90,6 @@ class SamVisionAttentionSplit(SamVisionAttention, nn.Module): attn_weights = (query * self.scale) @ key.transpose(-2, -1) - if self.use_rel_pos: - attn_weights = self.add_decomposed_rel_pos( - attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) - ) - attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) @@ -114,13 +109,14 @@ Load the model with [`~PreTrainedModel.from_pretrained`]. ```py from transformers import SamModel -from transformers.models.sam import modeling_sam - -# replace the attention class in the modeling_sam module -modeling_sam.SamVisionAttention = SamVisionAttentionSplit # load the pretrained SAM model model = SamModel.from_pretrained("facebook/sam-vit-base") + +# replace the attention class in the vision_encoder module +for layer in model.vision_encoder.layers: + if hasattr(layer, "attn"): + layer.attn = SamVisionAttentionSplit(model.config.vision_config, model.config.vision_config.window_size) ``` ## LoRA @@ -138,7 +134,7 @@ config = LoraConfig( # apply LoRA to q and v target_modules=["q", "v"], lora_dropout=0.1, - task_type="mask-generation" + task_type="FEATURE_EXTRACTION" ) ``` @@ -152,5 +148,5 @@ Call [print_trainable_parameters](https://huggingface.co/docs/peft/package_refer ```py model.print_trainable_parameters() -"trainable params: 608,256 || all params: 94,343,728 || trainable%: 0.6447" +"trainable params: 589,824 || all params: 94,274,096 || trainable%: 0.6256" ``` \ No newline at end of file