[DINOv2] Convert more checkpoints (#26177)
* Convert checkpoints * Update doc test * Address comment
This commit is contained in:
@@ -508,7 +508,7 @@
|
|||||||
- local: model_doc/dinat
|
- local: model_doc/dinat
|
||||||
title: DiNAT
|
title: DiNAT
|
||||||
- local: model_doc/dinov2
|
- local: model_doc/dinov2
|
||||||
title: DINO V2
|
title: DINOV2
|
||||||
- local: model_doc/dit
|
- local: model_doc/dit
|
||||||
title: DiT
|
title: DiT
|
||||||
- local: model_doc/dpt
|
- local: model_doc/dpt
|
||||||
|
|||||||
@@ -19,14 +19,17 @@ URL: https://github.com/facebookresearch/dinov2/tree/main
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
from transformers import BitImageProcessor, Dinov2Config, Dinov2Model
|
from transformers import BitImageProcessor, Dinov2Config, Dinov2ForImageClassification, Dinov2Model
|
||||||
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling
|
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
@@ -35,7 +38,7 @@ logging.set_verbosity_info()
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_dinov2_config(model_name):
|
def get_dinov2_config(model_name, image_classifier=False):
|
||||||
config = Dinov2Config(image_size=518, patch_size=14)
|
config = Dinov2Config(image_size=518, patch_size=14)
|
||||||
|
|
||||||
# size of the architecture
|
# size of the architecture
|
||||||
@@ -56,6 +59,13 @@ def get_dinov2_config(model_name):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Model not supported")
|
raise ValueError("Model not supported")
|
||||||
|
|
||||||
|
if image_classifier:
|
||||||
|
repo_id = "huggingface/label-files"
|
||||||
|
filename = "imagenet-1k-id2label.json"
|
||||||
|
config.num_labels = 1000
|
||||||
|
config.id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||||
|
config.id2label = {int(k): v for k, v in config.id2label.items()}
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
@@ -140,10 +150,11 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# define default Dinov2 configuration
|
# define default Dinov2 configuration
|
||||||
config = get_dinov2_config(model_name)
|
image_classifier = "1layer" in model_name
|
||||||
|
config = get_dinov2_config(model_name, image_classifier=image_classifier)
|
||||||
|
|
||||||
# load original model from torch hub
|
# load original model from torch hub
|
||||||
original_model = torch.hub.load("facebookresearch/dinov2", model_name)
|
original_model = torch.hub.load("facebookresearch/dinov2", model_name.replace("_1layer", ""))
|
||||||
original_model.eval()
|
original_model.eval()
|
||||||
|
|
||||||
# load state_dict of original model, remove and rename some keys
|
# load state_dict of original model, remove and rename some keys
|
||||||
@@ -162,8 +173,22 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
|
|||||||
state_dict[key] = val
|
state_dict[key] = val
|
||||||
|
|
||||||
# load HuggingFace model
|
# load HuggingFace model
|
||||||
model = Dinov2Model(config, add_pooling_layer=False).eval()
|
if image_classifier:
|
||||||
model.load_state_dict(state_dict)
|
model = Dinov2ForImageClassification(config).eval()
|
||||||
|
model.dinov2.load_state_dict(state_dict)
|
||||||
|
model_name_to_classifier_dict_url = {
|
||||||
|
"dinov2_vits14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth",
|
||||||
|
"dinov2_vitb14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth",
|
||||||
|
"dinov2_vitl14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth",
|
||||||
|
"dinov2_vitg14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth",
|
||||||
|
}
|
||||||
|
url = model_name_to_classifier_dict_url[model_name]
|
||||||
|
classifier_state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
|
||||||
|
model.classifier.weight = nn.Parameter(classifier_state_dict["weight"])
|
||||||
|
model.classifier.bias = nn.Parameter(classifier_state_dict["bias"])
|
||||||
|
else:
|
||||||
|
model = Dinov2Model(config).eval()
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
# load image
|
# load image
|
||||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
@@ -195,12 +220,17 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
|
|||||||
assert torch.allclose(original_pixel_values, pixel_values)
|
assert torch.allclose(original_pixel_values, pixel_values)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(pixel_values)
|
outputs = model(pixel_values, output_hidden_states=True)
|
||||||
original_outputs = original_model(pixel_values)
|
original_outputs = original_model(pixel_values)
|
||||||
|
|
||||||
# assert values
|
# assert values
|
||||||
assert outputs.last_hidden_state[:, 0].shape == original_outputs.shape
|
if image_classifier:
|
||||||
assert torch.allclose(outputs.last_hidden_state[:, 0], original_outputs, atol=1e-3)
|
print("Predicted class:")
|
||||||
|
class_idx = outputs.logits.argmax(-1).item()
|
||||||
|
print(model.config.id2label[class_idx])
|
||||||
|
else:
|
||||||
|
assert outputs.last_hidden_state[:, 0].shape == original_outputs.shape
|
||||||
|
assert torch.allclose(outputs.last_hidden_state[:, 0], original_outputs, atol=1e-3)
|
||||||
print("Looks ok!")
|
print("Looks ok!")
|
||||||
|
|
||||||
if pytorch_dump_folder_path is not None:
|
if pytorch_dump_folder_path is not None:
|
||||||
@@ -216,6 +246,10 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
|
|||||||
"dinov2_vitb14": "dinov2-base",
|
"dinov2_vitb14": "dinov2-base",
|
||||||
"dinov2_vitl14": "dinov2-large",
|
"dinov2_vitl14": "dinov2-large",
|
||||||
"dinov2_vitg14": "dinov2-giant",
|
"dinov2_vitg14": "dinov2-giant",
|
||||||
|
"dinov2_vits14_1layer": "dinov2-small-imagenet1k-1-layer",
|
||||||
|
"dinov2_vitb14_1layer": "dinov2-base-imagenet1k-1-layer",
|
||||||
|
"dinov2_vitl14_1layer": "dinov2-large-imagenet1k-1-layer",
|
||||||
|
"dinov2_vitg14_1layer": "dinov2-giant-imagenet1k-1-layer",
|
||||||
}
|
}
|
||||||
|
|
||||||
name = model_name_to_hf_name[model_name]
|
name = model_name_to_hf_name[model_name]
|
||||||
@@ -230,7 +264,16 @@ if __name__ == "__main__":
|
|||||||
"--model_name",
|
"--model_name",
|
||||||
default="dinov2_vitb14",
|
default="dinov2_vitb14",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", "dinov2_vitg14"],
|
choices=[
|
||||||
|
"dinov2_vits14",
|
||||||
|
"dinov2_vitb14",
|
||||||
|
"dinov2_vitl14",
|
||||||
|
"dinov2_vitg14",
|
||||||
|
"dinov2_vits14_1layer",
|
||||||
|
"dinov2_vitb14_1layer",
|
||||||
|
"dinov2_vitl14_1layer",
|
||||||
|
"dinov2_vitg14_1layer",
|
||||||
|
],
|
||||||
help="Name of the model you'd like to convert.",
|
help="Name of the model you'd like to convert.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -54,7 +54,8 @@ _CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
|
|||||||
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
|
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
|
||||||
|
|
||||||
# Image classification docstring
|
# Image classification docstring
|
||||||
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
|
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-small-imagenet1k-1-layer"
|
||||||
|
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
|
||||||
|
|
||||||
|
|
||||||
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
@@ -693,6 +694,7 @@ class Dinov2ForImageClassification(Dinov2PreTrainedModel):
|
|||||||
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
||||||
output_type=ImageClassifierOutput,
|
output_type=ImageClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user