Add DINO conversion script (#13265)
* First commit * Add interpolation of patch embeddings * Comment out code * Fix bug * Fix another bug * Fix bug * Fix another bug * Remove print statements * Update conversion script * Use the official vit implementation * Add support for converting dino_vits8 * Add DINO to docs of ViT * Remove assertion * Add interpolation of position encodings * Fix bug * Add align_corners * Add interpolate_pos_encoding option to forward pass of ViTModel * Improve interpolate_pos_encoding method * Add docstring
This commit is contained in:
@@ -66,6 +66,23 @@ Tips:
|
|||||||
language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant
|
language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant
|
||||||
improvement of 2% to training from scratch, but still 4% behind supervised pre-training.
|
improvement of 2% to training from scratch, but still 4% behind supervised pre-training.
|
||||||
|
|
||||||
|
Following the original Vision Transformer, some follow-up works have been made:
|
||||||
|
|
||||||
|
- DeiT (Data-efficient Image Transformers) by Facebook AI. DeiT models are distilled vision transformers. Refer to
|
||||||
|
:doc:`DeiT's documentation page <deit>`. The authors of DeiT also released more efficiently trained ViT models, which
|
||||||
|
you can directly plug into :class:`~transformers.ViTModel` or :class:`~transformers.ViTForImageClassification`. There
|
||||||
|
are 4 variants available (in 3 different sizes): `facebook/deit-tiny-patch16-224`, `facebook/deit-small-patch16-224`,
|
||||||
|
`facebook/deit-base-patch16-224` and `facebook/deit-base-patch16-384`. Note that one should use
|
||||||
|
:class:`~transformers.DeiTFeatureExtractor` in order to prepare images for the model.
|
||||||
|
|
||||||
|
- BEiT (BERT pre-training of Image Transformers) by Microsoft Research. BEiT models outperform supervised pre-trained
|
||||||
|
vision transformers using a self-supervised method inspired by BERT (masked image modeling) and based on a VQ-VAE.
|
||||||
|
Refer to :doc:`BEiT's documentation page <beit>`.
|
||||||
|
|
||||||
|
- DINO (a method for self-supervised training of Vision Transformers) by Facebook AI. Vision Transformers trained using
|
||||||
|
the DINO method show very interesting properties not seen with convolutional models. They are capable of segmenting
|
||||||
|
objects, without having ever been trained to do so. DINO checkpoints can be found on the `hub
|
||||||
|
<https://huggingface.co/models?other=dino>`__.
|
||||||
|
|
||||||
This model was contributed by `nielsr <https://huggingface.co/nielsr>`__. The original code (written in JAX) can be
|
This model was contributed by `nielsr <https://huggingface.co/nielsr>`__. The original code (written in JAX) can be
|
||||||
found `here <https://github.com/google-research/vision_transformer>`__.
|
found `here <https://github.com/google-research/vision_transformer>`__.
|
||||||
|
|||||||
@@ -93,7 +93,6 @@ class DeiTEmbeddings(nn.Module):
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.vit.modeling_vit.PatchEmbeddings
|
|
||||||
class PatchEmbeddings(nn.Module):
|
class PatchEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
Image to Patch Embedding.
|
Image to Patch Embedding.
|
||||||
|
|||||||
219
src/transformers/models/vit/convert_dino_to_pytorch.py
Normal file
219
src/transformers/models/vit/convert_dino_to_pytorch.py
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Convert ViT checkpoints trained with the DINO method."""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from huggingface_hub import cached_download, hf_hub_url
|
||||||
|
from transformers import ViTConfig, ViTFeatureExtractor, ViTForImageClassification, ViTModel
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||||
|
def create_rename_keys(config, base_model=False):
|
||||||
|
rename_keys = []
|
||||||
|
for i in range(config.num_hidden_layers):
|
||||||
|
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||||
|
rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
|
||||||
|
rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
|
||||||
|
rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
|
||||||
|
rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
|
||||||
|
rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
|
||||||
|
rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
|
||||||
|
rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
|
||||||
|
rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
|
||||||
|
rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
|
||||||
|
rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))
|
||||||
|
|
||||||
|
# projection layer + position embeddings
|
||||||
|
rename_keys.extend(
|
||||||
|
[
|
||||||
|
("cls_token", "vit.embeddings.cls_token"),
|
||||||
|
("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
|
||||||
|
("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
|
||||||
|
("pos_embed", "vit.embeddings.position_embeddings"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if base_model:
|
||||||
|
# layernorm + pooler
|
||||||
|
rename_keys.extend(
|
||||||
|
[
|
||||||
|
("norm.weight", "layernorm.weight"),
|
||||||
|
("norm.bias", "layernorm.bias"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# if just the base model, we should remove "vit" from all keys that start with "vit"
|
||||||
|
rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
|
||||||
|
else:
|
||||||
|
# layernorm + classification head
|
||||||
|
rename_keys.extend(
|
||||||
|
[
|
||||||
|
("norm.weight", "vit.layernorm.weight"),
|
||||||
|
("norm.bias", "vit.layernorm.bias"),
|
||||||
|
("head.weight", "classifier.weight"),
|
||||||
|
("head.bias", "classifier.bias"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return rename_keys
|
||||||
|
|
||||||
|
|
||||||
|
# we split up the matrix of each encoder layer into queries, keys and values
|
||||||
|
def read_in_q_k_v(state_dict, config, base_model=False):
|
||||||
|
for i in range(config.num_hidden_layers):
|
||||||
|
if base_model:
|
||||||
|
prefix = ""
|
||||||
|
else:
|
||||||
|
prefix = "vit."
|
||||||
|
# read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
|
||||||
|
in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
|
||||||
|
in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
|
||||||
|
# next, add query, keys and values (in that order) to the state dict
|
||||||
|
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
||||||
|
: config.hidden_size, :
|
||||||
|
]
|
||||||
|
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
|
||||||
|
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||||
|
config.hidden_size : config.hidden_size * 2, :
|
||||||
|
]
|
||||||
|
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
|
||||||
|
config.hidden_size : config.hidden_size * 2
|
||||||
|
]
|
||||||
|
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
||||||
|
-config.hidden_size :, :
|
||||||
|
]
|
||||||
|
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
|
||||||
|
|
||||||
|
|
||||||
|
def remove_classification_head_(state_dict):
|
||||||
|
ignore_keys = ["head.weight", "head.bias"]
|
||||||
|
for k in ignore_keys:
|
||||||
|
state_dict.pop(k, None)
|
||||||
|
|
||||||
|
|
||||||
|
def rename_key(dct, old, new):
|
||||||
|
val = dct.pop(old)
|
||||||
|
dct[new] = val
|
||||||
|
|
||||||
|
|
||||||
|
# We will verify our results on an image of cute cats
|
||||||
|
def prepare_img():
|
||||||
|
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
im = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
return im
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert_vit_checkpoint(model_name, pytorch_dump_folder_path, base_model=True):
|
||||||
|
"""
|
||||||
|
Copy/paste/tweak model's weights to our ViT structure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# define default ViT configuration
|
||||||
|
config = ViTConfig()
|
||||||
|
# patch_size
|
||||||
|
if model_name[-1] == "8":
|
||||||
|
config.patch_size = 8
|
||||||
|
# set labels if required
|
||||||
|
if not base_model:
|
||||||
|
config.num_labels = 1000
|
||||||
|
repo_id = "datasets/huggingface/label-files"
|
||||||
|
filename = "imagenet-1k-id2label.json"
|
||||||
|
id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
|
||||||
|
id2label = {int(k): v for k, v in id2label.items()}
|
||||||
|
config.id2label = id2label
|
||||||
|
config.label2id = {v: k for k, v in id2label.items()}
|
||||||
|
# size of the architecture
|
||||||
|
if model_name in ["dino_vits8", "dino_vits16"]:
|
||||||
|
config.hidden_size = 384
|
||||||
|
config.intermediate_size = 1536
|
||||||
|
config.num_hidden_layers = 12
|
||||||
|
config.num_attention_heads = 6
|
||||||
|
|
||||||
|
# load original model from torch hub
|
||||||
|
original_model = torch.hub.load("facebookresearch/dino:main", model_name)
|
||||||
|
original_model.eval()
|
||||||
|
|
||||||
|
# load state_dict of original model, remove and rename some keys
|
||||||
|
state_dict = original_model.state_dict()
|
||||||
|
if base_model:
|
||||||
|
remove_classification_head_(state_dict)
|
||||||
|
rename_keys = create_rename_keys(config, base_model=base_model)
|
||||||
|
for src, dest in rename_keys:
|
||||||
|
rename_key(state_dict, src, dest)
|
||||||
|
read_in_q_k_v(state_dict, config, base_model)
|
||||||
|
|
||||||
|
# load HuggingFace model
|
||||||
|
if base_model:
|
||||||
|
model = ViTModel(config, add_pooling_layer=False).eval()
|
||||||
|
else:
|
||||||
|
model = ViTForImageClassification(config).eval()
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
# Check outputs on an image, prepared by ViTFeatureExtractor
|
||||||
|
feature_extractor = ViTFeatureExtractor()
|
||||||
|
encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
|
||||||
|
pixel_values = encoding["pixel_values"]
|
||||||
|
outputs = model(pixel_values)
|
||||||
|
|
||||||
|
if base_model:
|
||||||
|
final_hidden_state_cls_token = original_model(pixel_values)
|
||||||
|
assert torch.allclose(final_hidden_state_cls_token, outputs.last_hidden_state[:, 0, :], atol=1e-1)
|
||||||
|
else:
|
||||||
|
logits = original_model(pixel_values)
|
||||||
|
assert logits.shape == outputs.logits.shape
|
||||||
|
assert torch.allclose(logits, outputs.logits, atol=1e-3)
|
||||||
|
|
||||||
|
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||||
|
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
|
||||||
|
model.save_pretrained(pytorch_dump_folder_path)
|
||||||
|
print(f"Saving feature extractor to {pytorch_dump_folder_path}")
|
||||||
|
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
# Required parameters
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name",
|
||||||
|
default="dino_vitb16",
|
||||||
|
type=str,
|
||||||
|
help="Name of the model trained with DINO you'd like to convert.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--base_model",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to only convert the base model (no projection head weights).",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.set_defaults(base_model=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
convert_vit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.base_model)
|
||||||
@@ -74,15 +74,55 @@ class ViTEmbeddings(nn.Module):
|
|||||||
num_patches = self.patch_embeddings.num_patches
|
num_patches = self.patch_embeddings.num_patches
|
||||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
def interpolate_pos_encoding(self, embeddings, height, width):
|
||||||
batch_size = pixel_values.shape[0]
|
"""
|
||||||
embeddings = self.patch_embeddings(pixel_values)
|
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||||
|
resolution images.
|
||||||
|
|
||||||
|
Source:
|
||||||
|
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||||
|
"""
|
||||||
|
|
||||||
|
npatch = embeddings.shape[1] - 1
|
||||||
|
N = self.position_embeddings.shape[1] - 1
|
||||||
|
if npatch == N and height == width:
|
||||||
|
return self.position_embeddings
|
||||||
|
class_pos_embed = self.position_embeddings[:, 0]
|
||||||
|
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||||
|
dim = embeddings.shape[-1]
|
||||||
|
h0 = height // self.config.patch_size
|
||||||
|
w0 = width // self.config.patch_size
|
||||||
|
# we add a small number to avoid floating point error in the interpolation
|
||||||
|
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||||
|
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||||
|
patch_pos_embed = nn.functional.interpolate(
|
||||||
|
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
||||||
|
scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
assert int(h0) == patch_pos_embed.shape[-1] and int(w0) == patch_pos_embed.shape[-2]
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||||
|
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||||
|
|
||||||
|
def forward(self, pixel_values, interpolate_pos_encoding=False):
|
||||||
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
|
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
|
|
||||||
|
# add the [CLS] token to the embedded patch tokens
|
||||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||||
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||||
|
|
||||||
|
# add positional encoding to each token
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||||
|
else:
|
||||||
embeddings = embeddings + self.position_embeddings
|
embeddings = embeddings + self.position_embeddings
|
||||||
|
|
||||||
embeddings = self.dropout(embeddings)
|
embeddings = self.dropout(embeddings)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
@@ -105,9 +145,9 @@ class PatchEmbeddings(nn.Module):
|
|||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
def forward(self, pixel_values, interpolate_pos_encoding=False):
|
||||||
batch_size, num_channels, height, width = pixel_values.shape
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
# FIXME look at relaxing size constraints
|
if not interpolate_pos_encoding:
|
||||||
if height != self.image_size[0] or width != self.image_size[1]:
|
if height != self.image_size[0] or width != self.image_size[1]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
@@ -419,6 +459,8 @@ VIT_INPUTS_DOCSTRING = r"""
|
|||||||
output_hidden_states (:obj:`bool`, `optional`):
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
|
interpolate_pos_encoding (:obj:`bool`, `optional`):
|
||||||
|
Whether to interpolate the pre-trained position encodings.
|
||||||
return_dict (:obj:`bool`, `optional`):
|
return_dict (:obj:`bool`, `optional`):
|
||||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
@@ -460,6 +502,7 @@ class ViTModel(ViTPreTrainedModel):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
interpolate_pos_encoding=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -497,7 +540,7 @@ class ViTModel(ViTPreTrainedModel):
|
|||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
embedding_output = self.embeddings(pixel_values)
|
embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
@@ -564,6 +607,7 @@ class ViTForImageClassification(ViTPreTrainedModel):
|
|||||||
labels=None,
|
labels=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
interpolate_pos_encoding=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -600,6 +644,7 @@ class ViTForImageClassification(ViTPreTrainedModel):
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user