Fix code snippet for Grounding DINO (#32229)
Fix code snippet for grounding-dino
This commit is contained in:
committed by
GitHub
parent
3a83ec48a6
commit
9d6c0641c4
@@ -41,33 +41,40 @@ The original code can be found [here](https://github.com/IDEA-Research/Grounding
|
||||
Here's how to use the model for zero-shot object detection:
|
||||
|
||||
```python
|
||||
import requests
|
||||
>>> import requests
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection,
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
||||
|
||||
model_id = "IDEA-Research/grounding-dino-tiny"
|
||||
>>> model_id = "IDEA-Research/grounding-dino-tiny"
|
||||
>>> device = "cuda"
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
||||
>>> processor = AutoProcessor.from_pretrained(model_id)
|
||||
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
|
||||
|
||||
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(image_url, stream=True).raw)
|
||||
# Check for cats and remote controls
|
||||
text = "a cat. a remote control."
|
||||
>>> image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(image_url, stream=True).raw)
|
||||
>>> # Check for cats and remote controls
|
||||
>>> text = "a cat. a remote control."
|
||||
|
||||
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
>>> inputs = processor(images=image, text=text, return_tensors="pt").to(device)
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs,
|
||||
inputs.input_ids,
|
||||
box_threshold=0.4,
|
||||
text_threshold=0.3,
|
||||
target_sizes=[image.size[::-1]]
|
||||
)
|
||||
>>> results = processor.post_process_grounded_object_detection(
|
||||
... outputs,
|
||||
... inputs.input_ids,
|
||||
... box_threshold=0.4,
|
||||
... text_threshold=0.3,
|
||||
... target_sizes=[image.size[::-1]]
|
||||
... )
|
||||
>>> print(results)
|
||||
[{'boxes': tensor([[344.6959, 23.1090, 637.1833, 374.2751],
|
||||
[ 12.2666, 51.9145, 316.8582, 472.4392],
|
||||
[ 38.5742, 70.0015, 176.7838, 118.1806]], device='cuda:0'),
|
||||
'labels': ['a cat', 'a cat', 'a remote control'],
|
||||
'scores': tensor([0.4785, 0.4381, 0.4776], device='cuda:0')}]
|
||||
```
|
||||
|
||||
## Grounded SAM
|
||||
|
||||
Reference in New Issue
Block a user