Owlvit docs test (#18257)

* fix docs and add owlvit docs test

* fix minor bug in post_process, add to processor

* improve owlvit code examples

* fix hardcoded image size
This commit is contained in:
Alara Dirik
2022-07-26 10:55:14 +03:00
committed by GitHub
parent d32558cc7a
commit 002915aa2a
5 changed files with 51 additions and 28 deletions

View File

@@ -39,19 +39,26 @@ OWL-ViT is a zero-shot text-conditioned object detection model. OWL-ViT uses [CL
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw) >>> image = Image.open(requests.get(url, stream=True).raw)
>>> texts = [["a photo of a cat", "a photo of a dog"]]
>>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt") >>> inputs = processor(text=texts, images=image, return_tensors="pt")
>>> outputs = model(**inputs) >>> outputs = model(**inputs)
>>> logits = outputs["logits"] # Prediction logits of shape [batch_size, num_patches, num_max_text_queries]
>>> boxes = outputs["pred_boxes"] # Object box boundaries of shape [batch_size, num_patches, 4]
>>> batch_size = boxes.shape[0] >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
>>> for i in range(batch_size): # Loop over sets of images and text queries >>> target_sizes = torch.Tensor([image.size[::-1]])
... boxes = outputs["pred_boxes"][i] >>> # Convert outputs (bounding boxes and class logits) to COCO API
... logits = torch.max(outputs["logits"][i], dim=-1) >>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
... scores = torch.sigmoid(logits.values)
... labels = logits.indices >>> i = 0 # Retrieve predictions for the first image for the corresponding text queries
>>> text = texts[i]
>>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
>>> score_threshold = 0.1
>>> for box, score, label in zip(boxes, scores, labels):
... box = [round(i, 2) for i in box.tolist()]
... if score >= score_threshold:
... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48]
Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61]
``` ```
This model was contributed by [adirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit). This model was contributed by [adirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit).

View File

@@ -26,7 +26,6 @@ from ...utils import TensorType, is_torch_available, logging
if is_torch_available(): if is_torch_available():
import torch import torch
from torch import nn
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -109,18 +108,19 @@ class OwlViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
in the batch as predicted by the model. in the batch as predicted by the model.
""" """
out_logits, out_bbox = outputs.logits, outputs.pred_boxes logits, boxes = outputs.logits, outputs.pred_boxes
if len(out_logits) != len(target_sizes): if len(logits) != len(target_sizes):
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
if target_sizes.shape[1] != 2: if target_sizes.shape[1] != 2:
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
prob = nn.functional.softmax(out_logits, -1) probs = torch.max(logits, dim=-1)
scores, labels = prob[..., :-1].max(-1) scores = torch.sigmoid(probs.values)
labels = probs.indices
# Convert to [x0, y0, x1, y1] format # Convert to [x0, y0, x1, y1] format
boxes = center_to_corners_format(out_bbox) boxes = center_to_corners_format(boxes)
# Convert from relative [0, 1] to absolute [0, height] coordinates # Convert from relative [0, 1] to absolute [0, height] coordinates
img_h, img_w = target_sizes.unbind(1) img_h, img_w = target_sizes.unbind(1)

View File

@@ -1300,23 +1300,31 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
>>> import torch >>> import torch
>>> from transformers import OwlViTProcessor, OwlViTForObjectDetection >>> from transformers import OwlViTProcessor, OwlViTForObjectDetection
>>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
>>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
>>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw) >>> image = Image.open(requests.get(url, stream=True).raw)
>>> texts = [["a photo of a cat", "a photo of a dog"]]
>>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt") >>> inputs = processor(text=texts, images=image, return_tensors="pt")
>>> outputs = model(**inputs) >>> outputs = model(**inputs)
>>> logits = outputs["logits"] # Prediction logits of shape [batch_size, num_patches, num_max_text_queries]
>>> boxes = outputs["pred_boxes"] # Object box boundaries of shape # [batch_size, num_patches, 4]
>>> batch_size = boxes.shape[0] >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
>>> for i in range(batch_size): # Loop over sets of images and text queries >>> target_sizes = torch.Tensor([image.size[::-1]])
... boxes = outputs["pred_boxes"][i] >>> # Convert outputs (bounding boxes and class logits) to COCO API
... logits = torch.max(outputs["logits"][i], dim=-1) >>> results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
... scores = torch.sigmoid(logits.values)
... labels = logits.indices >>> i = 0 # Retrieve predictions for the first image for the corresponding text queries
>>> text = texts[i]
>>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
>>> score_threshold = 0.1
>>> for box, score, label in zip(boxes, scores, labels):
... box = [round(i, 2) for i in box.tolist()]
... if score >= score_threshold:
... print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
Detected a photo of a cat with confidence 0.243 at location [1.42, 50.69, 308.58, 370.48]
Detected a photo of a cat with confidence 0.298 at location [348.06, 20.56, 642.33, 372.61]
```""" ```"""
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

View File

@@ -139,6 +139,13 @@ class OwlViTProcessor(ProcessorMixin):
else: else:
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
def post_process(self, *args, **kwargs):
"""
This method forwards all its arguments to [`OwlViTFeatureExtractor.post_process`]. Please refer to the
docstring of this method for more information.
"""
return self.feature_extractor.post_process(*args, **kwargs)
def batch_decode(self, *args, **kwargs): def batch_decode(self, *args, **kwargs):
""" """
This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please

View File

@@ -48,6 +48,7 @@ src/transformers/models/mobilevit/modeling_mobilevit.py
src/transformers/models/opt/modeling_opt.py src/transformers/models/opt/modeling_opt.py
src/transformers/models/opt/modeling_tf_opt.py src/transformers/models/opt/modeling_tf_opt.py
src/transformers/models/opt/modeling_flax_opt.py src/transformers/models/opt/modeling_flax_opt.py
src/transformers/models/owlvit/modeling_owlvit.py
src/transformers/models/pegasus/modeling_pegasus.py src/transformers/models/pegasus/modeling_pegasus.py
src/transformers/models/plbart/modeling_plbart.py src/transformers/models/plbart/modeling_plbart.py
src/transformers/models/poolformer/modeling_poolformer.py src/transformers/models/poolformer/modeling_poolformer.py