From a1d4563f7a2d78a5b29e4da46c76c90c4afe5331 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 25 Nov 2022 11:20:44 +0100 Subject: [PATCH] `accelerate` support for `OwlViT` (#20411) * `accelerate` support for `OwlViT` - added `accelerate` support - added slow `fp16` tests * apply suggestions --- .../models/owlvit/modeling_owlvit.py | 17 +++++++++--- tests/models/owlvit/test_modeling_owlvit.py | 27 ++++++++++++++++++- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index 2928a2f7e2..fd0f30b79d 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -434,6 +434,9 @@ class OwlViTAttention(nn.Module): attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + # For int8 compatibility, sometimes the `attn_probs` are in `fp32` + attn_probs = attn_probs.to(value_states.dtype) + attn_output = torch.bmm(attn_probs, value_states) if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): @@ -528,6 +531,7 @@ class OwlViTPreTrainedModel(PreTrainedModel): base_model_prefix = "owlvit" supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] + _no_split_modules = ["OwlViTEncoderLayer"] def _init_weights(self, module): """Initialize the weights""" @@ -836,7 +840,8 @@ class OwlViTTextTransformer(nn.Module): # take features from the end of tokens embedding (end of token is the highest number in each sequence) # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 pooled_output = last_hidden_state[ - torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1) + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(torch.int).argmax(dim=-1).to(last_hidden_state.device), ] if not return_dict: @@ -939,8 +944,13 @@ class OwlViTVisionTransformer(nn.Module): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # Cast the input to the expected `dtype` + expected_input_dtype = self.embeddings.patch_embedding.weight.dtype + pixel_values = pixel_values.to(expected_input_dtype) + hidden_states = self.embeddings(pixel_values) hidden_states = self.pre_layernorm(hidden_states) + encoder_outputs = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, @@ -1193,8 +1203,9 @@ class OwlViTModel(OwlViTPreTrainedModel): image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True) text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True) - # cosine similarity as logits - logit_scale = self.logit_scale.exp() + # cosine similarity as logits and set it on the correct device + logit_scale = self.logit_scale.exp().to(image_embeds.device) + logits_per_text = torch.matmul(text_embeds_norm, image_embeds.t()) * logit_scale logits_per_image = logits_per_text.t() diff --git a/tests/models/owlvit/test_modeling_owlvit.py b/tests/models/owlvit/test_modeling_owlvit.py index f492d85e67..9575339801 100644 --- a/tests/models/owlvit/test_modeling_owlvit.py +++ b/tests/models/owlvit/test_modeling_owlvit.py @@ -24,7 +24,7 @@ import numpy as np import requests from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device from transformers.utils import is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -778,3 +778,28 @@ class OwlViTModelIntegrationTest(unittest.TestCase): [[0.0691, 0.0445, 0.1373], [0.1592, 0.0456, 0.3192], [0.1632, 0.0423, 0.2478]] ).to(torch_device) self.assertTrue(torch.allclose(outputs.target_pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4)) + + @slow + @require_torch_gpu + def test_inference_one_shot_object_detection_fp16(self): + model_name = "google/owlvit-base-patch32" + model = OwlViTForObjectDetection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device) + + processor = OwlViTProcessor.from_pretrained(model_name) + + image = prepare_img() + query_image = prepare_img() + inputs = processor( + images=image, + query_images=query_image, + max_length=16, + padding="max_length", + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model.image_guided_detection(**inputs) + + # No need to check the logits, we just check inference runs fine. + num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2) + self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))