accelerate support for OwlViT (#20411)
* `accelerate` support for `OwlViT` - added `accelerate` support - added slow `fp16` tests * apply suggestions
This commit is contained in:
@@ -434,6 +434,9 @@ class OwlViTAttention(nn.Module):
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
|
||||
# For int8 compatibility, sometimes the `attn_probs` are in `fp32`
|
||||
attn_probs = attn_probs.to(value_states.dtype)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
@@ -528,6 +531,7 @@ class OwlViTPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "owlvit"
|
||||
supports_gradient_checkpointing = True
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
_no_split_modules = ["OwlViTEncoderLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
@@ -836,7 +840,8 @@ class OwlViTTextTransformer(nn.Module):
|
||||
# take features from the end of tokens embedding (end of token is the highest number in each sequence)
|
||||
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
|
||||
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
||||
input_ids.to(torch.int).argmax(dim=-1).to(last_hidden_state.device),
|
||||
]
|
||||
|
||||
if not return_dict:
|
||||
@@ -939,8 +944,13 @@ class OwlViTVisionTransformer(nn.Module):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# Cast the input to the expected `dtype`
|
||||
expected_input_dtype = self.embeddings.patch_embedding.weight.dtype
|
||||
pixel_values = pixel_values.to(expected_input_dtype)
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.pre_layernorm(hidden_states)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
@@ -1193,8 +1203,9 @@ class OwlViTModel(OwlViTPreTrainedModel):
|
||||
image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
|
||||
text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
# cosine similarity as logits and set it on the correct device
|
||||
logit_scale = self.logit_scale.exp().to(image_embeds.device)
|
||||
|
||||
logits_per_text = torch.matmul(text_embeds_norm, image_embeds.t()) * logit_scale
|
||||
logits_per_image = logits_per_text.t()
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import numpy as np
|
||||
|
||||
import requests
|
||||
from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -778,3 +778,28 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||
[[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.target_pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_inference_one_shot_object_detection_fp16(self):
|
||||
model_name = "google/owlvit-base-patch32"
|
||||
model = OwlViTForObjectDetection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)
|
||||
|
||||
processor = OwlViTProcessor.from_pretrained(model_name)
|
||||
|
||||
image = prepare_img()
|
||||
query_image = prepare_img()
|
||||
inputs = processor(
|
||||
images=image,
|
||||
query_images=query_image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.image_guided_detection(**inputs)
|
||||
|
||||
# No need to check the logits, we just check inference runs fine.
|
||||
num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
|
||||
self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
|
||||
Reference in New Issue
Block a user