From 0759f2510ccdbf79d0c77c85a572516eaf7406a1 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Thu, 26 Aug 2021 17:25:20 +0200 Subject: [PATCH] Add DINO conversion script (#13265) * First commit * Add interpolation of patch embeddings * Comment out code * Fix bug * Fix another bug * Fix bug * Fix another bug * Remove print statements * Update conversion script * Use the official vit implementation * Add support for converting dino_vits8 * Add DINO to docs of ViT * Remove assertion * Add interpolation of position encodings * Fix bug * Add align_corners * Add interpolate_pos_encoding option to forward pass of ViTModel * Improve interpolate_pos_encoding method * Add docstring --- docs/source/model_doc/vit.rst | 17 ++ src/transformers/models/deit/modeling_deit.py | 1 - .../models/vit/convert_dino_to_pytorch.py | 219 ++++++++++++++++++ src/transformers/models/vit/modeling_vit.py | 67 +++++- 4 files changed, 292 insertions(+), 12 deletions(-) create mode 100644 src/transformers/models/vit/convert_dino_to_pytorch.py diff --git a/docs/source/model_doc/vit.rst b/docs/source/model_doc/vit.rst index 3c396c54eb..f06da21275 100644 --- a/docs/source/model_doc/vit.rst +++ b/docs/source/model_doc/vit.rst @@ -66,6 +66,23 @@ Tips: language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant improvement of 2% to training from scratch, but still 4% behind supervised pre-training. +Following the original Vision Transformer, some follow-up works have been made: + +- DeiT (Data-efficient Image Transformers) by Facebook AI. DeiT models are distilled vision transformers. Refer to + :doc:`DeiT's documentation page `. The authors of DeiT also released more efficiently trained ViT models, which + you can directly plug into :class:`~transformers.ViTModel` or :class:`~transformers.ViTForImageClassification`. There + are 4 variants available (in 3 different sizes): `facebook/deit-tiny-patch16-224`, `facebook/deit-small-patch16-224`, + `facebook/deit-base-patch16-224` and `facebook/deit-base-patch16-384`. Note that one should use + :class:`~transformers.DeiTFeatureExtractor` in order to prepare images for the model. + +- BEiT (BERT pre-training of Image Transformers) by Microsoft Research. BEiT models outperform supervised pre-trained + vision transformers using a self-supervised method inspired by BERT (masked image modeling) and based on a VQ-VAE. + Refer to :doc:`BEiT's documentation page `. + +- DINO (a method for self-supervised training of Vision Transformers) by Facebook AI. Vision Transformers trained using + the DINO method show very interesting properties not seen with convolutional models. They are capable of segmenting + objects, without having ever been trained to do so. DINO checkpoints can be found on the `hub + `__. This model was contributed by `nielsr `__. The original code (written in JAX) can be found `here `__. diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index a415596471..b848376817 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -93,7 +93,6 @@ class DeiTEmbeddings(nn.Module): return embeddings -# Copied from transformers.models.vit.modeling_vit.PatchEmbeddings class PatchEmbeddings(nn.Module): """ Image to Patch Embedding. diff --git a/src/transformers/models/vit/convert_dino_to_pytorch.py b/src/transformers/models/vit/convert_dino_to_pytorch.py new file mode 100644 index 0000000000..e69ab77976 --- /dev/null +++ b/src/transformers/models/vit/convert_dino_to_pytorch.py @@ -0,0 +1,219 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ViT checkpoints trained with the DINO method.""" + + +import argparse +import json +from pathlib import Path + +import torch +from PIL import Image + +import requests +from huggingface_hub import cached_download, hf_hub_url +from transformers import ViTConfig, ViTFeatureExtractor, ViTForImageClassification, ViTModel +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config, base_model=False): + rename_keys = [] + for i in range(config.num_hidden_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight")) + rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias")) + rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight")) + rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias")) + rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight")) + rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias")) + + # projection layer + position embeddings + rename_keys.extend( + [ + ("cls_token", "vit.embeddings.cls_token"), + ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"), + ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"), + ("pos_embed", "vit.embeddings.position_embeddings"), + ] + ) + + if base_model: + # layernorm + pooler + rename_keys.extend( + [ + ("norm.weight", "layernorm.weight"), + ("norm.bias", "layernorm.bias"), + ] + ) + + # if just the base model, we should remove "vit" from all keys that start with "vit" + rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys] + else: + # layernorm + classification head + rename_keys.extend( + [ + ("norm.weight", "vit.layernorm.weight"), + ("norm.bias", "vit.layernorm.bias"), + ("head.weight", "classifier.weight"), + ("head.bias", "classifier.bias"), + ] + ) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config, base_model=False): + for i in range(config.num_hidden_layers): + if base_model: + prefix = "" + else: + prefix = "vit." + # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + : config.hidden_size, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] + + +def remove_classification_head_(state_dict): + ignore_keys = ["head.weight", "head.bias"] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_vit_checkpoint(model_name, pytorch_dump_folder_path, base_model=True): + """ + Copy/paste/tweak model's weights to our ViT structure. + """ + + # define default ViT configuration + config = ViTConfig() + # patch_size + if model_name[-1] == "8": + config.patch_size = 8 + # set labels if required + if not base_model: + config.num_labels = 1000 + repo_id = "datasets/huggingface/label-files" + filename = "imagenet-1k-id2label.json" + id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + # size of the architecture + if model_name in ["dino_vits8", "dino_vits16"]: + config.hidden_size = 384 + config.intermediate_size = 1536 + config.num_hidden_layers = 12 + config.num_attention_heads = 6 + + # load original model from torch hub + original_model = torch.hub.load("facebookresearch/dino:main", model_name) + original_model.eval() + + # load state_dict of original model, remove and rename some keys + state_dict = original_model.state_dict() + if base_model: + remove_classification_head_(state_dict) + rename_keys = create_rename_keys(config, base_model=base_model) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config, base_model) + + # load HuggingFace model + if base_model: + model = ViTModel(config, add_pooling_layer=False).eval() + else: + model = ViTForImageClassification(config).eval() + model.load_state_dict(state_dict) + + # Check outputs on an image, prepared by ViTFeatureExtractor + feature_extractor = ViTFeatureExtractor() + encoding = feature_extractor(images=prepare_img(), return_tensors="pt") + pixel_values = encoding["pixel_values"] + outputs = model(pixel_values) + + if base_model: + final_hidden_state_cls_token = original_model(pixel_values) + assert torch.allclose(final_hidden_state_cls_token, outputs.last_hidden_state[:, 0, :], atol=1e-1) + else: + logits = original_model(pixel_values) + assert logits.shape == outputs.logits.shape + assert torch.allclose(logits, outputs.logits, atol=1e-3) + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving feature extractor to {pytorch_dump_folder_path}") + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="dino_vitb16", + type=str, + help="Name of the model trained with DINO you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--base_model", + action="store_true", + help="Whether to only convert the base model (no projection head weights).", + ) + + parser.set_defaults(base_model=True) + args = parser.parse_args() + convert_vit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.base_model) diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 308755234e..82b3c59210 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -74,15 +74,55 @@ class ViTEmbeddings(nn.Module): num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config - def forward(self, pixel_values): - batch_size = pixel_values.shape[0] - embeddings = self.patch_embeddings(pixel_values) + def interpolate_pos_encoding(self, embeddings, height, width): + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + npatch = embeddings.shape[1] - 1 + N = self.position_embeddings.shape[1] - 1 + if npatch == N and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)), + mode="bicubic", + align_corners=False, + ) + assert int(h0) == patch_pos_embed.shape[-1] and int(w0) == patch_pos_embed.shape[-2] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values, interpolate_pos_encoding=False): + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + # add the [CLS] token to the embedded patch tokens cls_tokens = self.cls_token.expand(batch_size, -1, -1) embeddings = torch.cat((cls_tokens, embeddings), dim=1) - embeddings = embeddings + self.position_embeddings + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings @@ -105,13 +145,13 @@ class PatchEmbeddings(nn.Module): self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) - def forward(self, pixel_values): + def forward(self, pixel_values, interpolate_pos_encoding=False): batch_size, num_channels, height, width = pixel_values.shape - # FIXME look at relaxing size constraints - if height != self.image_size[0] or width != self.image_size[1]: - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." - ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) x = self.projection(pixel_values).flatten(2).transpose(1, 2) return x @@ -419,6 +459,8 @@ VIT_INPUTS_DOCSTRING = r""" output_hidden_states (:obj:`bool`, `optional`): Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for more detail. + interpolate_pos_encoding (:obj:`bool`, `optional`): + Whether to interpolate the pre-trained position encodings. return_dict (:obj:`bool`, `optional`): Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. """ @@ -460,6 +502,7 @@ class ViTModel(ViTPreTrainedModel): head_mask=None, output_attentions=None, output_hidden_states=None, + interpolate_pos_encoding=None, return_dict=None, ): r""" @@ -497,7 +540,7 @@ class ViTModel(ViTPreTrainedModel): # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings(pixel_values) + embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) encoder_outputs = self.encoder( embedding_output, @@ -564,6 +607,7 @@ class ViTForImageClassification(ViTPreTrainedModel): labels=None, output_attentions=None, output_hidden_states=None, + interpolate_pos_encoding=None, return_dict=None, ): r""" @@ -600,6 +644,7 @@ class ViTForImageClassification(ViTPreTrainedModel): head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, )