accelerate support for OwlViT (#20411)
* `accelerate` support for `OwlViT` - added `accelerate` support - added slow `fp16` tests * apply suggestions
This commit is contained in:
@@ -434,6 +434,9 @@ class OwlViTAttention(nn.Module):
|
|||||||
|
|
||||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
# For int8 compatibility, sometimes the `attn_probs` are in `fp32`
|
||||||
|
attn_probs = attn_probs.to(value_states.dtype)
|
||||||
|
|
||||||
attn_output = torch.bmm(attn_probs, value_states)
|
attn_output = torch.bmm(attn_probs, value_states)
|
||||||
|
|
||||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||||
@@ -528,6 +531,7 @@ class OwlViTPreTrainedModel(PreTrainedModel):
|
|||||||
base_model_prefix = "owlvit"
|
base_model_prefix = "owlvit"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
_no_split_modules = ["OwlViTEncoderLayer"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
"""Initialize the weights"""
|
"""Initialize the weights"""
|
||||||
@@ -836,7 +840,8 @@ class OwlViTTextTransformer(nn.Module):
|
|||||||
# take features from the end of tokens embedding (end of token is the highest number in each sequence)
|
# take features from the end of tokens embedding (end of token is the highest number in each sequence)
|
||||||
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||||
pooled_output = last_hidden_state[
|
pooled_output = last_hidden_state[
|
||||||
torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
|
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
||||||
|
input_ids.to(torch.int).argmax(dim=-1).to(last_hidden_state.device),
|
||||||
]
|
]
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
@@ -939,8 +944,13 @@ class OwlViTVisionTransformer(nn.Module):
|
|||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# Cast the input to the expected `dtype`
|
||||||
|
expected_input_dtype = self.embeddings.patch_embedding.weight.dtype
|
||||||
|
pixel_values = pixel_values.to(expected_input_dtype)
|
||||||
|
|
||||||
hidden_states = self.embeddings(pixel_values)
|
hidden_states = self.embeddings(pixel_values)
|
||||||
hidden_states = self.pre_layernorm(hidden_states)
|
hidden_states = self.pre_layernorm(hidden_states)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
inputs_embeds=hidden_states,
|
inputs_embeds=hidden_states,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
@@ -1193,8 +1203,9 @@ class OwlViTModel(OwlViTPreTrainedModel):
|
|||||||
image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
|
image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
|
||||||
text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
|
text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
|
||||||
|
|
||||||
# cosine similarity as logits
|
# cosine similarity as logits and set it on the correct device
|
||||||
logit_scale = self.logit_scale.exp()
|
logit_scale = self.logit_scale.exp().to(image_embeds.device)
|
||||||
|
|
||||||
logits_per_text = torch.matmul(text_embeds_norm, image_embeds.t()) * logit_scale
|
logits_per_text = torch.matmul(text_embeds_norm, image_embeds.t()) * logit_scale
|
||||||
logits_per_image = logits_per_text.t()
|
logits_per_image = logits_per_text.t()
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import numpy as np
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
|
from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
|
||||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device
|
||||||
from transformers.utils import is_torch_available, is_vision_available
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -778,3 +778,28 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
|||||||
[[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
|
[[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
|
||||||
).to(torch_device)
|
).to(torch_device)
|
||||||
self.assertTrue(torch.allclose(outputs.target_pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
self.assertTrue(torch.allclose(outputs.target_pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_inference_one_shot_object_detection_fp16(self):
|
||||||
|
model_name = "google/owlvit-base-patch32"
|
||||||
|
model = OwlViTForObjectDetection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)
|
||||||
|
|
||||||
|
processor = OwlViTProcessor.from_pretrained(model_name)
|
||||||
|
|
||||||
|
image = prepare_img()
|
||||||
|
query_image = prepare_img()
|
||||||
|
inputs = processor(
|
||||||
|
images=image,
|
||||||
|
query_images=query_image,
|
||||||
|
max_length=16,
|
||||||
|
padding="max_length",
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model.image_guided_detection(**inputs)
|
||||||
|
|
||||||
|
# No need to check the logits, we just check inference runs fine.
|
||||||
|
num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
|
||||||
|
self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||||
|
|||||||
Reference in New Issue
Block a user