[DINOv2] Convert more checkpoints (#26177)
* Convert checkpoints
* Update doc test
* Address comment
This commit is contained in:
@@ -508,7 +508,7 @@
|
||||
- local: model_doc/dinat
|
||||
title: DiNAT
|
||||
- local: model_doc/dinov2
|
||||
title: DINO V2
|
||||
title: DINOV2
|
||||
- local: model_doc/dit
|
||||
title: DiT
|
||||
- local: model_doc/dpt
|
||||
|
||||
@@ -19,14 +19,17 @@ URL: https://github.com/facebookresearch/dinov2/tree/main
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from transformers import BitImageProcessor, Dinov2Config, Dinov2Model
|
||||
from transformers import BitImageProcessor, Dinov2Config, Dinov2ForImageClassification, Dinov2Model
|
||||
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling
|
||||
from transformers.utils import logging
|
||||
|
||||
@@ -35,7 +38,7 @@ logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_dinov2_config(model_name):
|
||||
def get_dinov2_config(model_name, image_classifier=False):
|
||||
config = Dinov2Config(image_size=518, patch_size=14)
|
||||
|
||||
# size of the architecture
|
||||
@@ -56,6 +59,13 @@ def get_dinov2_config(model_name):
|
||||
else:
|
||||
raise ValueError("Model not supported")
|
||||
|
||||
if image_classifier:
|
||||
repo_id = "huggingface/label-files"
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
config.num_labels = 1000
|
||||
config.id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
config.id2label = {int(k): v for k, v in config.id2label.items()}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@@ -140,10 +150,11 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
|
||||
"""
|
||||
|
||||
# define default Dinov2 configuration
|
||||
config = get_dinov2_config(model_name)
|
||||
image_classifier = "1layer" in model_name
|
||||
config = get_dinov2_config(model_name, image_classifier=image_classifier)
|
||||
|
||||
# load original model from torch hub
|
||||
original_model = torch.hub.load("facebookresearch/dinov2", model_name)
|
||||
original_model = torch.hub.load("facebookresearch/dinov2", model_name.replace("_1layer", ""))
|
||||
original_model.eval()
|
||||
|
||||
# load state_dict of original model, remove and rename some keys
|
||||
@@ -162,8 +173,22 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
|
||||
state_dict[key] = val
|
||||
|
||||
# load HuggingFace model
|
||||
model = Dinov2Model(config, add_pooling_layer=False).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
if image_classifier:
|
||||
model = Dinov2ForImageClassification(config).eval()
|
||||
model.dinov2.load_state_dict(state_dict)
|
||||
model_name_to_classifier_dict_url = {
|
||||
"dinov2_vits14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth",
|
||||
"dinov2_vitb14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth",
|
||||
"dinov2_vitl14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth",
|
||||
"dinov2_vitg14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth",
|
||||
}
|
||||
url = model_name_to_classifier_dict_url[model_name]
|
||||
classifier_state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
|
||||
model.classifier.weight = nn.Parameter(classifier_state_dict["weight"])
|
||||
model.classifier.bias = nn.Parameter(classifier_state_dict["bias"])
|
||||
else:
|
||||
model = Dinov2Model(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# load image
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
@@ -195,12 +220,17 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
|
||||
assert torch.allclose(original_pixel_values, pixel_values)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values)
|
||||
outputs = model(pixel_values, output_hidden_states=True)
|
||||
original_outputs = original_model(pixel_values)
|
||||
|
||||
# assert values
|
||||
assert outputs.last_hidden_state[:, 0].shape == original_outputs.shape
|
||||
assert torch.allclose(outputs.last_hidden_state[:, 0], original_outputs, atol=1e-3)
|
||||
if image_classifier:
|
||||
print("Predicted class:")
|
||||
class_idx = outputs.logits.argmax(-1).item()
|
||||
print(model.config.id2label[class_idx])
|
||||
else:
|
||||
assert outputs.last_hidden_state[:, 0].shape == original_outputs.shape
|
||||
assert torch.allclose(outputs.last_hidden_state[:, 0], original_outputs, atol=1e-3)
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
@@ -216,6 +246,10 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
|
||||
"dinov2_vitb14": "dinov2-base",
|
||||
"dinov2_vitl14": "dinov2-large",
|
||||
"dinov2_vitg14": "dinov2-giant",
|
||||
"dinov2_vits14_1layer": "dinov2-small-imagenet1k-1-layer",
|
||||
"dinov2_vitb14_1layer": "dinov2-base-imagenet1k-1-layer",
|
||||
"dinov2_vitl14_1layer": "dinov2-large-imagenet1k-1-layer",
|
||||
"dinov2_vitg14_1layer": "dinov2-giant-imagenet1k-1-layer",
|
||||
}
|
||||
|
||||
name = model_name_to_hf_name[model_name]
|
||||
@@ -230,7 +264,16 @@ if __name__ == "__main__":
|
||||
"--model_name",
|
||||
default="dinov2_vitb14",
|
||||
type=str,
|
||||
choices=["dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", "dinov2_vitg14"],
|
||||
choices=[
|
||||
"dinov2_vits14",
|
||||
"dinov2_vitb14",
|
||||
"dinov2_vitl14",
|
||||
"dinov2_vitg14",
|
||||
"dinov2_vits14_1layer",
|
||||
"dinov2_vitb14_1layer",
|
||||
"dinov2_vitl14_1layer",
|
||||
"dinov2_vitg14_1layer",
|
||||
],
|
||||
help="Name of the model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -54,7 +54,8 @@ _CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
|
||||
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
|
||||
|
||||
# Image classification docstring
|
||||
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
|
||||
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-small-imagenet1k-1-layer"
|
||||
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
|
||||
|
||||
|
||||
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
@@ -693,6 +694,7 @@ class Dinov2ForImageClassification(Dinov2PreTrainedModel):
|
||||
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
||||
output_type=ImageClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user