@@ -90,11 +90,6 @@ class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
|
|||||||
|
|
||||||
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
|
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
|
||||||
|
|
||||||
if self.use_rel_pos:
|
|
||||||
attn_weights = self.add_decomposed_rel_pos(
|
|
||||||
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
|
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
|
||||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||||
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
|
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
|
||||||
@@ -114,13 +109,14 @@ Load the model with [`~PreTrainedModel.from_pretrained`].
|
|||||||
|
|
||||||
```py
|
```py
|
||||||
from transformers import SamModel
|
from transformers import SamModel
|
||||||
from transformers.models.sam import modeling_sam
|
|
||||||
|
|
||||||
# replace the attention class in the modeling_sam module
|
|
||||||
modeling_sam.SamVisionAttention = SamVisionAttentionSplit
|
|
||||||
|
|
||||||
# load the pretrained SAM model
|
# load the pretrained SAM model
|
||||||
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||||
|
|
||||||
|
# replace the attention class in the vision_encoder module
|
||||||
|
for layer in model.vision_encoder.layers:
|
||||||
|
if hasattr(layer, "attn"):
|
||||||
|
layer.attn = SamVisionAttentionSplit(model.config.vision_config, model.config.vision_config.window_size)
|
||||||
```
|
```
|
||||||
|
|
||||||
## LoRA
|
## LoRA
|
||||||
@@ -138,7 +134,7 @@ config = LoraConfig(
|
|||||||
# apply LoRA to q and v
|
# apply LoRA to q and v
|
||||||
target_modules=["q", "v"],
|
target_modules=["q", "v"],
|
||||||
lora_dropout=0.1,
|
lora_dropout=0.1,
|
||||||
task_type="mask-generation"
|
task_type="FEATURE_EXTRACTION"
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -152,5 +148,5 @@ Call [print_trainable_parameters](https://huggingface.co/docs/peft/package_refer
|
|||||||
|
|
||||||
```py
|
```py
|
||||||
model.print_trainable_parameters()
|
model.print_trainable_parameters()
|
||||||
"trainable params: 608,256 || all params: 94,343,728 || trainable%: 0.6447"
|
"trainable params: 589,824 || all params: 94,274,096 || trainable%: 0.6256"
|
||||||
```
|
```
|
||||||
Reference in New Issue
Block a user