Fix code snippet for Grounding DINO (#32229)
Fix code snippet for grounding-dino
This commit is contained in:
committed by
GitHub
parent
3a83ec48a6
commit
9d6c0641c4
@@ -41,33 +41,40 @@ The original code can be found [here](https://github.com/IDEA-Research/Grounding
|
|||||||
Here's how to use the model for zero-shot object detection:
|
Here's how to use the model for zero-shot object detection:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import requests
|
>>> import requests
|
||||||
|
|
||||||
import torch
|
>>> import torch
|
||||||
from PIL import Image
|
>>> from PIL import Image
|
||||||
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection,
|
>>> from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
||||||
|
|
||||||
model_id = "IDEA-Research/grounding-dino-tiny"
|
>>> model_id = "IDEA-Research/grounding-dino-tiny"
|
||||||
|
>>> device = "cuda"
|
||||||
|
|
||||||
processor = AutoProcessor.from_pretrained(model_id)
|
>>> processor = AutoProcessor.from_pretrained(model_id)
|
||||||
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
||||||
|
|
||||||
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
>>> image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
image = Image.open(requests.get(image_url, stream=True).raw)
|
>>> image = Image.open(requests.get(image_url, stream=True).raw)
|
||||||
# Check for cats and remote controls
|
>>> # Check for cats and remote controls
|
||||||
text = "a cat. a remote control."
|
>>> text = "a cat. a remote control."
|
||||||
|
|
||||||
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
|
>>> inputs = processor(images=image, text=text, return_tensors="pt").to(device)
|
||||||
with torch.no_grad():
|
>>> with torch.no_grad():
|
||||||
outputs = model(**inputs)
|
... outputs = model(**inputs)
|
||||||
|
|
||||||
results = processor.post_process_grounded_object_detection(
|
>>> results = processor.post_process_grounded_object_detection(
|
||||||
outputs,
|
... outputs,
|
||||||
inputs.input_ids,
|
... inputs.input_ids,
|
||||||
box_threshold=0.4,
|
... box_threshold=0.4,
|
||||||
text_threshold=0.3,
|
... text_threshold=0.3,
|
||||||
target_sizes=[image.size[::-1]]
|
... target_sizes=[image.size[::-1]]
|
||||||
)
|
... )
|
||||||
|
>>> print(results)
|
||||||
|
[{'boxes': tensor([[344.6959, 23.1090, 637.1833, 374.2751],
|
||||||
|
[ 12.2666, 51.9145, 316.8582, 472.4392],
|
||||||
|
[ 38.5742, 70.0015, 176.7838, 118.1806]], device='cuda:0'),
|
||||||
|
'labels': ['a cat', 'a cat', 'a remote control'],
|
||||||
|
'scores': tensor([0.4785, 0.4381, 0.4776], device='cuda:0')}]
|
||||||
```
|
```
|
||||||
|
|
||||||
## Grounded SAM
|
## Grounded SAM
|
||||||
|
|||||||
Reference in New Issue
Block a user