diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 12e4224070..7916dd9a06 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -905,6 +905,8 @@
- sections:
- local: model_doc/timesformer
title: TimeSformer
+ - local: model_doc/vjepa2
+ title: V-JEPA 2
- local: model_doc/videomae
title: VideoMAE
- local: model_doc/vivit
diff --git a/docs/source/en/model_doc/vjepa2.md b/docs/source/en/model_doc/vjepa2.md
new file mode 100644
index 0000000000..5ad02ae274
--- /dev/null
+++ b/docs/source/en/model_doc/vjepa2.md
@@ -0,0 +1,82 @@
+
+
+
+
+
+# V-JEPA 2
+
+V-JEPA 2 is a self-supervised approach to training video encoders developed by FAIR, Meta. Using internet-scale video data, V-JEPA 2 attains state-of-the-art performance on motion understanding and human action anticipation tasks. V-JEPA 2-AC is a latent action-conditioned world model post-trained from V-JEPA 2 (using a small amount of robot trajectory interaction data) that solves robot manipulation tasks without environment-specific data collection or task-specific training or calibration.
+
+
+

+
+
+You can find all original V-JEPA2 checkpoints under the [V-JEPA 2](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6) collection.
+
+
+This model was contributed by [koustuvs](https://huggingface.co/koustuvs), [yonigozlan](https://huggingface.co/yonigozlan) and [qubvel](https://huggingface.co/qubvel-hf). The original code can be found [here](https://github.com/facebookresearch/vjepa2).
+
+## Usage example
+
+The snippet below shows how to load the V-JEPA 2 model using the `AutoModel` class.
+
+```py
+import torch
+from torchcodec.decoders import VideoDecoder
+import numpy as np
+
+processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
+model = AutoModel.from_pretrained(
+ "facebook/vjepa2-vitl-fpc64-256",
+ torch_dtype=torch.float16,
+ device_map="auto",
+ attn_implementation="sdpa"
+)
+
+video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/archery/-Qz25rXdMjE_000014_000024.mp4"
+
+vr = VideoDecoder(video_url)
+frame_idx = np.arange(0, 64) # choosing some frames. here, you can define more complex sampling strategy
+video = vr.get_frames_at(indices=frame_idx).data # T x C x H x W
+video = processor(video, return_tensors="pt").to(model.device)
+outputs = model(**video)
+
+# V-JEPA 2 encoder outputs, same as calling `model.get_vision_features()`
+encoder_outputs = outputs.last_hidden_state
+
+# V-JEPA 2 predictor outputs
+predictor_outputs = outputs.predictor_output.last_hidden_state
+```
+
+## VJEPA2Config
+
+[[autodoc]] VJEPA2Config
+
+## VJEPA2Model
+
+[[autodoc]] VJEPA2Model
+ - forward
+
+## VJEPA2VideoProcessor
+
+[[autodoc]] VJEPA2VideoProcessor
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index ea10dc8666..dea3c98c38 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -323,6 +323,7 @@ if TYPE_CHECKING:
from .vitpose_backbone import *
from .vits import *
from .vivit import *
+ from .vjepa2 import *
from .wav2vec2 import *
from .wav2vec2_bert import *
from .wav2vec2_conformer import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index da55432dfd..fe8a889b5d 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -365,6 +365,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("vitpose_backbone", "VitPoseBackboneConfig"),
("vits", "VitsConfig"),
("vivit", "VivitConfig"),
+ ("vjepa2", "VJEPA2Config"),
("wav2vec2", "Wav2Vec2Config"),
("wav2vec2-bert", "Wav2Vec2BertConfig"),
("wav2vec2-conformer", "Wav2Vec2ConformerConfig"),
@@ -750,6 +751,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("vitpose_backbone", "ViTPoseBackbone"),
("vits", "VITS"),
("vivit", "ViViT"),
+ ("vjepa2", "VJEPA2Model"),
("wav2vec2", "Wav2Vec2"),
("wav2vec2-bert", "Wav2Vec2-BERT"),
("wav2vec2-conformer", "Wav2Vec2-Conformer"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 0998af5592..bcd56483c1 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -336,6 +336,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("vitdet", "VitDetModel"),
("vits", "VitsModel"),
("vivit", "VivitModel"),
+ ("vjepa2", "VJEPA2Model"),
("wav2vec2", "Wav2Vec2Model"),
("wav2vec2-bert", "Wav2Vec2BertModel"),
("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
diff --git a/src/transformers/models/auto/video_processing_auto.py b/src/transformers/models/auto/video_processing_auto.py
index 1688b7fbeb..0f48dcdac7 100644
--- a/src/transformers/models/auto/video_processing_auto.py
+++ b/src/transformers/models/auto/video_processing_auto.py
@@ -56,6 +56,7 @@ else:
("qwen2_vl", "Qwen2VLVideoProcessor"),
("smolvlm", "SmolVLMVideoProcessor"),
("video_llava", "VideoLlavaVideoProcessor"),
+ ("vjepa2", "VJEPA2VideoProcessor"),
]
)
diff --git a/src/transformers/models/vjepa2/__init__.py b/src/transformers/models/vjepa2/__init__.py
new file mode 100644
index 0000000000..f184d058a8
--- /dev/null
+++ b/src/transformers/models/vjepa2/__init__.py
@@ -0,0 +1,29 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_vjepa2 import *
+ from .modeling_vjepa2 import *
+ from .video_processing_vjepa2 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/vjepa2/configuration_vjepa2.py b/src/transformers/models/vjepa2/configuration_vjepa2.py
new file mode 100644
index 0000000000..4571b88602
--- /dev/null
+++ b/src/transformers/models/vjepa2/configuration_vjepa2.py
@@ -0,0 +1,146 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""VJEPA 2 model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+
+
+class VJEPA2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`VJEPA2Model`]. It is used to instantiate an
+ VJEPA2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the VJEPA2
+ [facebook/vjepa2-vitl-fpc64-256](https://huggingface.co/facebook/vjepa2-vitl-fpc64-256) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ crop_size (`int`, *optional*, defaults to 256):
+ Input resolution of the model
+ frames_per_clip (`int`, *optional*, defaults to 64):
+ The number of frames the model has been pretrained with. Does not impact inference.
+ tubelet_size (`int`, *optional*, defaults to 2):
+ The number of temporal frames used for a single rastor, check paper for more information.
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the encoder layers
+ in_chans (`int`, *optional*, defaults to 3):
+ The number of input channels
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Encoder
+ num_hidden_layers (`int`, *optional*, defaults to 24):
+ The number of hidden layers
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
+ Ratio of the hidden size of the MLPs used in Encoder relative to the `hidden_size`.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for attentions.
+ The dropout probability for all fully connected layers.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ pred_hidden_size (`int`, *optional*, defaults to 384):
+ Dimensionality of the predictor layers
+ pred_num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Predictor
+ pred_num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Predictor
+ pred_num_mask_tokens (`int`, *optional*, defaults to 10):
+ Define the number of mask tokens to use in the Predictor
+ pred_zero_init_mask_tokens (`bool`, *optional*, defaults to `True`):
+ Initialize the mask tokens in the predictor with 0.
+ pred_mlp_ratio (`float`, *optional*, defaults to 4.0):
+ Ratio of the hidden size of the MLPs used in Predictor relative to the `pred_hidden_size`.
+
+ Example:
+
+ ```python
+ >>> from transformers import VJEPA2Config, VJEPA2Model
+
+ >>> # Initializing a VJEPA2 vjepa2-vitl-fpc64-256 style configuration
+ >>> configuration = VJEPA2Config()
+
+ >>> # Initializing a model (with random weights) from the vjepa2-vitl-fpc64-256 style configuration
+ >>> model = VJEPA2Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "vjepa2"
+
+ def __init__(
+ self,
+ patch_size=16,
+ crop_size=256,
+ frames_per_clip=64,
+ tubelet_size=2,
+ hidden_size=1024,
+ in_chans=3,
+ num_attention_heads=16,
+ num_hidden_layers=24,
+ drop_path_rate=0.0,
+ mlp_ratio=4.0,
+ layer_norm_eps=1e-6,
+ qkv_bias=True,
+ attention_probs_dropout_prob=0.0,
+ hidden_act="gelu",
+ initializer_range=0.02,
+ # predictor params
+ pred_hidden_size=384,
+ pred_num_attention_heads=12,
+ pred_num_hidden_layers=12,
+ pred_num_mask_tokens=10,
+ pred_zero_init_mask_tokens=True,
+ pred_mlp_ratio=4.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.crop_size = crop_size
+ self.frames_per_clip = frames_per_clip
+ self.patch_size = patch_size
+ self.tubelet_size = tubelet_size
+ self.hidden_size = hidden_size
+ self.in_chans = in_chans
+ self.num_attention_heads = num_attention_heads
+ self.num_hidden_layers = num_hidden_layers
+ self.drop_path_rate = drop_path_rate
+ self.mlp_ratio = mlp_ratio
+ self.layer_norm_eps = layer_norm_eps
+ self.qkv_bias = qkv_bias
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.image_size = crop_size
+ # predictor params
+ self.pred_hidden_size = pred_hidden_size
+ self.pred_num_attention_heads = pred_num_attention_heads
+ self.pred_num_hidden_layers = pred_num_hidden_layers
+ self.pred_num_mask_tokens = pred_num_mask_tokens
+ self.pred_zero_init_mask_tokens = pred_zero_init_mask_tokens
+ self.pred_mlp_ratio = pred_mlp_ratio
+
+
+__all__ = ["VJEPA2Config"]
diff --git a/src/transformers/models/vjepa2/convert_vjepa2_to_hf.py b/src/transformers/models/vjepa2/convert_vjepa2_to_hf.py
new file mode 100644
index 0000000000..527dbc35d9
--- /dev/null
+++ b/src/transformers/models/vjepa2/convert_vjepa2_to_hf.py
@@ -0,0 +1,346 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import tempfile
+from pathlib import Path
+
+import numpy as np
+import requests
+import torch
+from huggingface_hub import HfApi
+from PIL import Image
+
+from transformers import VJEPA2Config, VJEPA2Model, VJEPA2VideoProcessor
+from transformers.models.vjepa2.modeling_vjepa2 import apply_masks
+
+
+HUB_REPO = "https://github.com/facebookresearch/vjepa2"
+HUB_SOURCE = "github"
+
+HUB_MODELS = {
+ "vit_large": "facebook/vjepa2-vitl-fpc64-256",
+ "vit_huge": "facebook/vjepa2-vith-fpc64-256",
+ "vit_giant": "facebook/vjepa2-vitg-fpc64-256",
+ "vit_giant_384": "facebook/vjepa2-vitg-fpc64-384",
+}
+
+S3_MODELS = {
+ "vit_large": "https://dl.fbaipublicfiles.com/vjepa2/vitl.pt",
+ "vit_huge": "https://dl.fbaipublicfiles.com/vjepa2/vith.pt",
+ "vit_giant": "https://dl.fbaipublicfiles.com/vjepa2/vitg.pt",
+ "vit_giant_384": "https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt",
+}
+
+TOKEN = os.environ.get("HF_TOKEN", None)
+
+
+def get_vjepa2_config(model_name):
+ # size of the architecture
+ if model_name == "vit_large":
+ return VJEPA2Config(
+ crop_size=256,
+ frames_per_clip=64,
+ hidden_size=1024,
+ num_attention_heads=16,
+ num_hidden_layers=24,
+ mlp_ratio=4,
+ pred_hidden_size=384,
+ pred_num_attention_heads=12,
+ pred_num_hidden_layers=12,
+ pred_num_mask_tokens=10,
+ )
+ elif model_name == "vit_huge":
+ return VJEPA2Config(
+ crop_size=256,
+ frames_per_clip=64,
+ hidden_size=1280,
+ num_attention_heads=16,
+ num_hidden_layers=32,
+ mlp_ratio=4,
+ pred_hidden_size=384,
+ pred_num_attention_heads=12,
+ pred_num_hidden_layers=12,
+ pred_num_mask_tokens=10,
+ )
+ elif model_name == "vit_giant":
+ return VJEPA2Config(
+ crop_size=256,
+ frames_per_clip=64,
+ hidden_size=1408,
+ num_attention_heads=22,
+ num_hidden_layers=40,
+ mlp_ratio=48 / 11,
+ pred_hidden_size=384,
+ pred_num_attention_heads=12,
+ pred_num_hidden_layers=12,
+ pred_num_mask_tokens=10,
+ )
+ elif model_name == "vit_giant_384":
+ return VJEPA2Config(
+ crop_size=384,
+ frames_per_clip=64,
+ hidden_size=1408,
+ num_attention_heads=22,
+ num_hidden_layers=40,
+ mlp_ratio=48 / 11,
+ pred_hidden_size=384,
+ pred_num_attention_heads=12,
+ pred_num_hidden_layers=12,
+ pred_num_mask_tokens=10,
+ )
+ else:
+ raise ValueError("Model not supported")
+
+
+def convert_encoder_keys(model_state_dict, og_encoder_state_dict, config):
+ emb_dim = config.hidden_size
+ for key, val in og_encoder_state_dict.copy().items():
+ val = og_encoder_state_dict.pop(key)
+ key = key.replace("module.backbone.", "")
+ if key.startswith("blocks."):
+ key = key.replace("blocks.", "encoder.layer.")
+ if "attn." in key:
+ key = key.replace("attn.", "attention.")
+ if key == "pos_embed":
+ key = "encoder.embeddings.position_embeddings"
+ if "patch_embed." in key:
+ key = key.replace("patch_embed.", "encoder.embeddings.patch_embeddings.")
+ if key.startswith("norm."):
+ key = key.replace("norm.", "encoder.layernorm.")
+ if "qkv." in key:
+ prefix, suffix = key.split("qkv")
+ if "bias" in suffix:
+ q_e, k_e, v_e = (
+ val[0:emb_dim],
+ val[emb_dim : emb_dim * 2],
+ val[emb_dim * 2 :],
+ )
+ else:
+ q_e, k_e, v_e = (
+ val[0:emb_dim, :],
+ val[emb_dim : emb_dim * 2, :],
+ val[emb_dim * 2 :, :],
+ )
+ og_encoder_state_dict[prefix + "query" + suffix] = q_e
+ og_encoder_state_dict[prefix + "key" + suffix] = k_e
+ og_encoder_state_dict[prefix + "value" + suffix] = v_e
+ else:
+ og_encoder_state_dict[key] = val
+ return og_encoder_state_dict
+
+
+def convert_predictor_keys(model_state_dict, og_predictor_state_dict, config):
+ emb_dim = config.pred_hidden_size
+ if "predictor_pos_embed" in og_predictor_state_dict:
+ del og_predictor_state_dict["predictor_pos_embed"]
+ # update predictor weights
+ mask_tokens = {}
+ mask_token_keys_to_delete = []
+ for key, val in og_predictor_state_dict.copy().items():
+ val = og_predictor_state_dict.pop(key)
+ key = key.replace("module.backbone.", "")
+ if key.startswith("predictor_blocks."):
+ key = key.replace("predictor_blocks.", "predictor.layer.")
+ if "attn." in key:
+ key = key.replace("attn.", "attention.")
+ if key == "predictor_pos_embed":
+ key = "predictor.embeddings.position_embeddings"
+ if "predictor_embed." in key:
+ key = key.replace("predictor_embed.", "predictor.embeddings.predictor_embeddings.")
+ if "mask_tokens." in key:
+ mask_tokens[key.split("mask_tokens.")[-1]] = val
+ mask_token_keys_to_delete.append(key)
+ # key = key.replace("mask_tokens.", "predictor.embeddings.mask_tokens.")
+ if key.startswith("predictor_norm."):
+ key = key.replace("predictor_norm.", "predictor.layernorm.")
+ if key.startswith("predictor_proj."):
+ key = key.replace("predictor_proj.", "predictor.proj.")
+ if "qkv." in key:
+ prefix, suffix = key.split("qkv")
+ if "bias" in suffix:
+ q_e, k_e, v_e = (
+ val[0:emb_dim],
+ val[emb_dim : emb_dim * 2],
+ val[emb_dim * 2 :],
+ )
+ else:
+ q_e, k_e, v_e = (
+ val[0:emb_dim, :],
+ val[emb_dim : emb_dim * 2, :],
+ val[emb_dim * 2 :, :],
+ )
+ og_predictor_state_dict[prefix + "query" + suffix] = q_e
+ og_predictor_state_dict[prefix + "key" + suffix] = k_e
+ og_predictor_state_dict[prefix + "value" + suffix] = v_e
+ else:
+ og_predictor_state_dict[key] = val
+ mask_tokens = torch.stack([mask_tokens[f"{i}"] for i in range(len(mask_tokens))], dim=0)
+ for k in mask_token_keys_to_delete:
+ del og_predictor_state_dict[k]
+ og_predictor_state_dict["predictor.embeddings.mask_tokens"] = mask_tokens
+ return og_predictor_state_dict
+
+
+def prepare_img():
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+ return image
+
+
+def upload_original_ckpts(model_name):
+ hf_repo = HUB_MODELS[model_name]
+ original_ckpt = S3_MODELS[model_name]
+ print(f"Uploading original checkpoint for vjepa2 {model_name} to {hf_repo}/original/")
+ with tempfile.NamedTemporaryFile() as fn:
+ local_path = fn.name
+ torch.hub.download_url_to_file(original_ckpt, local_path)
+ api = HfApi()
+ api.upload_file(
+ repo_id=hf_repo,
+ path_or_fileobj=local_path,
+ path_in_repo="original/model.pth",
+ repo_type="model",
+ token=TOKEN,
+ )
+ print("Uploading complete")
+
+
+@torch.no_grad()
+def convert_and_test_vjepa2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
+ """
+ Copy/paste/tweak model's weights to our VJEPA2 structure.
+ """
+ config = get_vjepa2_config(model_name)
+
+ # load original model from torch hub
+ original_encoder, original_predictor = torch.hub.load(HUB_REPO, "vjepa2_" + model_name, source=HUB_SOURCE)
+ original_encoder.eval()
+ original_predictor.eval()
+ original_preprocessor = torch.hub.load(
+ HUB_REPO, "vjepa2_preprocessor", source=HUB_SOURCE, crop_size=config.crop_size
+ )
+
+ # load state_dict of original model, remove and rename some keys
+ encoder_state_dict = original_encoder.state_dict()
+ decoder_state_dict = original_predictor.state_dict()
+
+ model = VJEPA2Model(config).eval()
+ state_dict = model.state_dict()
+
+ og_encoder_sd = convert_encoder_keys(state_dict, encoder_state_dict, config)
+ og_predictor_sd = convert_predictor_keys(state_dict, decoder_state_dict, config)
+
+ og_state_dict = og_encoder_sd
+ og_state_dict.update(og_predictor_sd)
+ model.load_state_dict(og_state_dict)
+
+ # load image
+ image = prepare_img()
+ image = torch.Tensor(np.array(image)).unsqueeze(0).permute(0, 3, 1, 2)
+ print("Input shape: ", image.shape)
+
+ crop_size = config.crop_size
+ processor = VJEPA2VideoProcessor(crop_size=crop_size)
+ pr_out = processor(image, return_tensors="pt")
+ pixel_values_videos = pr_out.pixel_values_videos
+ # run original preprocessor
+ original_pixel_values = original_preprocessor(image)
+ assert original_pixel_values[0].permute(1, 0, 2, 3).shape == pixel_values_videos[0].shape
+ assert torch.allclose(original_pixel_values[0].permute(1, 0, 2, 3), pixel_values_videos[0], atol=1e-3)
+
+ with torch.no_grad():
+ # reshape and move to gpu
+ if pixel_values_videos.size(1) == 1:
+ pixel_values_videos = pixel_values_videos.repeat(1, config.frames_per_clip, 1, 1, 1)
+ # pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4) # B x C x T x H x W
+ pixel_values_videos = pixel_values_videos.to(device="cuda", dtype=torch.float32)
+ original_encoder = original_encoder.to(device="cuda", dtype=torch.float32)
+ original_predictor = original_predictor.to(device="cuda", dtype=torch.float32)
+ model = model.to(device="cuda", dtype=torch.float32)
+ # forward
+ original_encoder_outputs = original_encoder(pixel_values_videos.permute(0, 2, 1, 3, 4))
+ B, N, _ = original_encoder_outputs.shape
+ # test full mask
+ context_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))]
+ predictor_mask = context_mask
+ original_predictor_outputs = original_predictor(original_encoder_outputs, context_mask, predictor_mask)
+ outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask)
+ assert torch.allclose(outputs.last_hidden_state, original_encoder_outputs, atol=1e-3)
+ predictor_outputs = outputs.predictor_output
+ assert torch.allclose(predictor_outputs.last_hidden_state, original_predictor_outputs, atol=1e-3)
+ # test partial mask
+ window_size = 256
+ mask = torch.arange(N, device=pixel_values_videos.device).unsqueeze(0)
+ context_mask = [mask[:, :window_size].repeat((B, 1))]
+ predictor_mask = [mask[:, window_size : window_size * 2].repeat((B, 1))]
+ original_predictor_outputs = original_predictor(
+ apply_masks(original_encoder_outputs, context_mask),
+ context_mask,
+ predictor_mask,
+ )
+ outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask)
+ assert torch.allclose(outputs.last_hidden_state, original_encoder_outputs, atol=1e-3)
+ predictor_outputs = outputs.predictor_output
+ assert torch.allclose(predictor_outputs.last_hidden_state, original_predictor_outputs, atol=1e-3)
+
+ print("Looks ok!")
+
+ if pytorch_dump_folder_path is not None:
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
+ print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
+ model.save_pretrained(pytorch_dump_folder_path)
+ print(f"Saving image processor to {pytorch_dump_folder_path}")
+ processor.save_pretrained(pytorch_dump_folder_path)
+
+ if push_to_hub:
+ name = HUB_MODELS[model_name]
+ model.push_to_hub(name, private=True)
+ processor.push_to_hub(name, private=True)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--model_name",
+ default="vit_large",
+ type=str,
+ choices=[
+ "vit_large",
+ "vit_huge",
+ "vit_giant",
+ "vit_giant_384",
+ ],
+ help="Name of the model you'd like to convert.",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path",
+ default=None,
+ type=str,
+ help="Path to the output PyTorch model directory.",
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ help="Whether or not to push the converted model to the 🤗 hub.",
+ )
+ parser.add_argument("--upload_original", action="store_true", help="upload the original checkpoint")
+
+ args = parser.parse_args()
+ convert_and_test_vjepa2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
+ if args.upload_original:
+ upload_original_ckpts(args.model_name)
diff --git a/src/transformers/models/vjepa2/modeling_vjepa2.py b/src/transformers/models/vjepa2/modeling_vjepa2.py
new file mode 100644
index 0000000000..7a3a95b129
--- /dev/null
+++ b/src/transformers/models/vjepa2/modeling_vjepa2.py
@@ -0,0 +1,903 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, dataclass
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
+from .configuration_vjepa2 import VJEPA2Config
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class VJEPA2WithMaskedInputPredictorOutput(ModelOutput):
+ """
+ VJEPA Predictor outputs that also contains the masked encoder outputs
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ masked_hidden_state (`torch.FloatTensor`), *optional*, returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ target_hidden_state (`torch.FloatTensor`), *optional*):
+ Returned when `target_mask` is provided which is applied on VJEPA2Encoder outputs.
+ """
+
+ last_hidden_state: torch.FloatTensor
+ masked_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ target_hidden_state: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class VJEPA2WithMaskedInputModelOutput(ModelOutput):
+ """
+ VJEPA outputs that also contains the masked encoder outputs
+ Optionally contains the predictor outputs
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ masked_hidden_state (`torch.FloatTensor`), *optional*):
+ Returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ predictor_output (`VJEPA2WithMaskedInputPredictorOutput`, *optional*):
+ Returns the output from the Predictor module
+ """
+
+ last_hidden_state: torch.FloatTensor
+ masked_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ predictor_output: Optional[VJEPA2WithMaskedInputPredictorOutput] = None
+
+ def to_tuple(self):
+ output = list(super().to_tuple())
+ if isinstance(output[-1], VJEPA2WithMaskedInputPredictorOutput):
+ output[-1] = output[-1].to_tuple()
+ return tuple(output)
+
+
+class VJEPA2PatchEmbeddings3D(nn.Module):
+ """
+ Image to Patch Embedding
+ """
+
+ def __init__(
+ self,
+ config: VJEPA2Config,
+ hidden_size: int = 1024,
+ ):
+ super().__init__()
+ self.patch_size = config.patch_size
+ self.tubelet_size = config.tubelet_size
+ self.hidden_size = hidden_size
+
+ self.proj = nn.Conv3d(
+ in_channels=config.in_chans,
+ out_channels=hidden_size,
+ kernel_size=(config.tubelet_size, config.patch_size, config.patch_size),
+ stride=(config.tubelet_size, config.patch_size, config.patch_size),
+ )
+
+ @staticmethod
+ def num_patches(config):
+ return (
+ (config.frames_per_clip // config.tubelet_size)
+ * (config.crop_size // config.patch_size)
+ * (config.crop_size // config.patch_size)
+ )
+
+ def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
+ x = self.proj(pixel_values_videos).flatten(2).transpose(1, 2)
+ return x
+
+
+class VJEPA2Embeddings(nn.Module):
+ """
+ Construct mask token, position and patch embeddings.
+ """
+
+ def __init__(self, config: VJEPA2Config, hidden_size: int = 1024):
+ super().__init__()
+
+ self.config = config
+ self.hidden_size = hidden_size
+ self.patch_embeddings = VJEPA2PatchEmbeddings3D(config, hidden_size=hidden_size)
+
+ self.num_patches = self.patch_embeddings.num_patches
+ self.patch_size = config.patch_size
+
+ def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
+ num_frames = pixel_values_videos.shape[1]
+
+ # Swap `frames` and `channels` dims, the result is:
+ # (batch_size, channels, num_frames, height, width)
+ pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4)
+
+ # For some cases, if the input vision (image/video) consists of num_frames < tubelet_size,
+ # then embedding lookup fails. In these cases, we duplicate the frames.
+ if num_frames < self.config.tubelet_size:
+ pixel_values_videos = pixel_values_videos.repeat(1, 1, self.config.tubelet_size, 1, 1)
+
+ target_dtype = self.patch_embeddings.proj.weight.dtype
+ pixel_values_videos = pixel_values_videos.to(dtype=target_dtype)
+ embeddings = self.patch_embeddings(pixel_values_videos)
+
+ return embeddings
+
+
+# Adapted from transformers.models.vit.modeling_vit.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+
+ # Normalize the attention scores to probabilities.
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ # Mask heads if we want to
+ if attention_mask is not None:
+ attn_weights = attn_weights * attention_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def rotate_queries_or_keys(x, pos):
+ B, num_heads, N, D = x.size()
+
+ # similar to inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ # they are computing this every time. instead HF style is to compute the inv_freq once and store it
+ # -- compute angle for each position
+ omega = torch.arange(D // 2, dtype=x.dtype, device=x.device)
+ omega /= D / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+ freq = torch.einsum("..., f -> ... f", pos, omega) # (..., N, D/2), outer product
+
+ # -- build rotation matrix and apply
+ emb_sin = freq.sin() # (..., N, D/2)
+ emb_cos = freq.cos() # (..., N, D/2)
+
+ emb_sin = emb_sin.squeeze(-1).repeat(1, 1, 1, 2)
+ emb_cos = emb_cos.squeeze(-1).repeat(1, 1, 1, 2)
+
+ # --
+ y = x.unflatten(-1, (-1, 2))
+ y1, y2 = y.unbind(dim=-1)
+
+ y = torch.stack((-y2, y1), dim=-1)
+ y = y.flatten(-2)
+ return (x * emb_cos) + (y * emb_sin)
+
+
+class VJEPA2RopeAttention(nn.Module):
+ def __init__(
+ self,
+ config: VJEPA2Config,
+ hidden_size: int = 1024,
+ num_attention_heads: int = 16,
+ ):
+ super().__init__()
+ self.config = config
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ if hidden_size % num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size {(hidden_size,)} is not a multiple of the number of attention "
+ f"heads {num_attention_heads}."
+ )
+
+ self.attention_head_size = int(hidden_size / num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.proj = nn.Linear(hidden_size, hidden_size)
+ self.dropout_prob = config.attention_probs_dropout_prob
+ self.dropout = nn.Dropout(self.dropout_prob)
+
+ self.grid_size = self.config.crop_size // self.config.patch_size
+ self.grid_depth = self.config.frames_per_clip // self.config.tubelet_size
+
+ self.d_dim = int(2 * ((self.attention_head_size // 3) // 2))
+ self.h_dim = int(2 * ((self.attention_head_size // 3) // 2))
+ self.w_dim = int(2 * ((self.attention_head_size // 3) // 2))
+
+ self.scaling = self.attention_head_size**-0.5
+ self.is_causal = False
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (
+ self.num_attention_heads,
+ self.attention_head_size,
+ )
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def _get_frame_pos(self, ids):
+ tokens_per_frame = int(self.grid_size * self.grid_size)
+ return ids // tokens_per_frame
+
+ def _get_height_pos(self, ids):
+ # Remove frame component from ids
+ tokens_per_frame = int(self.grid_size * self.grid_size)
+ frame_ids = self._get_frame_pos(ids)
+ ids = ids - tokens_per_frame * frame_ids
+ # --
+ tokens_per_row = self.grid_size
+ return ids // tokens_per_row
+
+ def get_position_ids(self, x, masks=None):
+ device = x.device
+ token_size = x.size(1)
+
+ # Note: when masks is none, we use a 1d id instead of Bxnum_attention_heads mask,
+ # as 1d vector is broadcasted to the correct shapes.
+ if masks is not None:
+ ids = masks.unsqueeze(1).repeat(1, self.num_attention_heads, 1)
+ else:
+ ids = torch.arange(token_size, device=device)
+ # change to allow for extrapolation
+ tokens_per_frame = int(self.grid_size * self.grid_size)
+ frame_ids = self._get_frame_pos(ids)
+ # --
+ tokens_per_row = self.grid_size
+ height_ids = self._get_height_pos(ids)
+ # --
+ # Remove frame component from ids (1st term) and height component (2nd term)
+ width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
+ return frame_ids, height_ids, width_ids
+
+ def apply_rotary_embeddings(self, qk, pos_ids):
+ d_mask, h_mask, w_mask = pos_ids
+ s = 0
+ qkd = rotate_queries_or_keys(qk[..., s : s + self.d_dim], pos=d_mask)
+ s += self.d_dim
+ qkh = rotate_queries_or_keys(qk[..., s : s + self.h_dim], pos=h_mask)
+ s += self.h_dim
+ qkw = rotate_queries_or_keys(qk[..., s : s + self.w_dim], pos=w_mask)
+ s += self.w_dim
+ # Combine rotated dimension
+ if s < self.attention_head_size:
+ qkr = qk[..., s:]
+ qk = torch.cat([qkd, qkh, qkw, qkr], dim=-1)
+ else:
+ qk = torch.cat([qkd, qkh, qkw], dim=-1)
+ return qk
+
+ def forward(
+ self,
+ hidden_states,
+ position_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ head_mask: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ pos_ids = self.get_position_ids(hidden_states, masks=position_mask)
+ key_layer = self.apply_rotary_embeddings(key_layer, pos_ids)
+ query_layer = self.apply_rotary_embeddings(query_layer, pos_ids)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ if self.config._attn_implementation == "sdpa" and output_attentions:
+ logger.warning_once(
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ else:
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ context_layer, attention_probs = attention_interface(
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ head_mask,
+ is_causal=self.is_causal,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.dropout_prob,
+ )
+
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = self.proj(context_layer.reshape(new_context_layer_shape))
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+# Adapted from transformers.models.beit.modeling_dinov2.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
+
+# Adapted from transformers.models.beit.modeling_beit.BeitDropPath
+class VJEPA2DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None):
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+class VJEPA2MLP(nn.Module):
+ def __init__(self, config: VJEPA2Config, hidden_size: int = 1024, mlp_ratio: float = 4.0):
+ super().__init__()
+ in_features = out_features = hidden_size
+ hidden_features = int(hidden_size * mlp_ratio)
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
+ self.activation = ACT2FN[config.hidden_act]
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.fc1(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ hidden_state = self.fc2(hidden_state)
+ return hidden_state
+
+
+class VJEPA2Layer(GradientCheckpointingLayer):
+ """This corresponds to the Block class in the original implementation."""
+
+ def __init__(
+ self,
+ config: VJEPA2Config,
+ drop_path_rate: float = 0.0,
+ hidden_size: int = 1024,
+ num_attention_heads: int = 16,
+ mlp_ratio: float = 4.0,
+ ):
+ super().__init__()
+ self.config = config
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.mlp_ratio = mlp_ratio
+
+ self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+ self.attention = VJEPA2RopeAttention(config, hidden_size, num_attention_heads)
+ self.drop_path = VJEPA2DropPath(drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+ self.norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+ self.mlp = VJEPA2MLP(config, hidden_size=hidden_size, mlp_ratio=mlp_ratio)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, ...]:
+ # Self-Attention
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ self_attention_outputs = self.attention(
+ hidden_states,
+ position_mask=position_mask, # position mask for context/target selection
+ head_mask=head_mask, # head mask is applied at F.scaled_dot_product_attention
+ output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+ hidden_states = self.drop_path(attention_output) + residual
+
+ # MLP
+ residual = hidden_states
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.drop_path(hidden_states) + residual
+
+ # Add self attentions if we output attention weights
+ outputs = self_attention_outputs[1:]
+ outputs = (hidden_states,) + outputs
+
+ return outputs
+
+
+class VJEPA2Encoder(nn.Module):
+ def __init__(self, config: VJEPA2Config):
+ super().__init__()
+ self.config = config
+
+ self.embeddings = VJEPA2Embeddings(config, hidden_size=config.hidden_size)
+ drop_path_rates = [
+ (config.drop_path_rate * i / (config.num_hidden_layers - 1) if config.num_hidden_layers > 1 else 0.0)
+ for i in range(config.num_hidden_layers)
+ ]
+ self.layer = nn.ModuleList(
+ [
+ VJEPA2Layer(
+ config,
+ drop_path_rate=drop_path_rates[i],
+ hidden_size=config.hidden_size,
+ num_attention_heads=config.num_attention_heads,
+ mlp_ratio=config.mlp_ratio,
+ )
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.gradient_checkpointing = False
+
+ @can_return_tuple
+ def forward(
+ self,
+ pixel_values_videos: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ ) -> BaseModelOutput:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ hidden_states = self.embeddings(pixel_values_videos)
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ layer_outputs = layer_module(hidden_states, None, layer_head_mask, output_attentions)
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layernorm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+def apply_masks(x, masks) -> torch.Tensor:
+ """
+ :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
+ :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep
+ """
+ all_x = []
+ for m in masks:
+ mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
+ all_x += [torch.gather(x, dim=1, index=mask_keep)]
+
+ return torch.cat(all_x, dim=0)
+
+
+class VJEPA2PredictorEmbeddings(nn.Module):
+ """
+ Construct mask token, position and patch embeddings.
+ """
+
+ def __init__(self, config: VJEPA2Config):
+ super().__init__()
+
+ self.config = config
+ self.predictor_embeddings = nn.Linear(config.hidden_size, config.pred_hidden_size)
+ self.num_mask_tokens = 0
+ self.zero_init_mask_tokens = config.pred_zero_init_mask_tokens
+ self.num_mask_tokens = config.pred_num_mask_tokens
+ self.mask_tokens = nn.Parameter(torch.zeros(self.num_mask_tokens, 1, 1, config.pred_hidden_size))
+
+ self.patch_size = config.patch_size
+ self.config = config
+
+ @staticmethod
+ def num_patches(config):
+ if config.frames_per_clip > 1:
+ return (
+ (config.frames_per_clip // config.tubelet_size)
+ * (config.crop_size // config.patch_size)
+ * (config.crop_size // config.patch_size)
+ )
+ else:
+ return (config.crop_size // config.patch_size) * (config.crop_size // config.patch_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ context_mask: List[torch.Tensor],
+ target_mask: List[torch.Tensor],
+ mask_index: int = 1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ hidden_states : encoder outputs (context)
+ context_mask: tokens of the context (outputs from the encoder)
+ target_mask: tokens to predict
+ mask_index: index of the target mask to choose (useful for multiclip?)
+ """
+
+ B = hidden_states.size(0)
+ context = self.predictor_embeddings(hidden_states)
+
+ # Make target tokens
+ mask_index = mask_index % self.num_mask_tokens
+ target = self.mask_tokens[mask_index]
+
+ # Note: this is problematic if the config isn't initialized with the right frames_per_clip value,
+ # e.g. for scenarios if we want to run predictor for more tokens than in the config.
+ # target = target.repeat(B, self.num_patches(self.config), 1)
+ # Remedy: use the provided target mask to get the max patch num
+ max_patch_num = target_mask[0].max() + 1 # one extra to include the last patch
+ target = target.repeat(B, max_patch_num, 1)
+ target = apply_masks(target, target_mask)
+
+ # Concatenate context & target tokens
+ context = context.repeat(len(context_mask), 1, 1)
+ embeddings = torch.cat([context, target], dim=1)
+
+ # Positions of context & target tokens
+ cm = torch.cat(context_mask, dim=0)
+ tm = torch.cat(target_mask, dim=0)
+ masks = torch.cat([cm, tm], dim=1)
+
+ return embeddings, masks
+
+
+class VJEPA2Predictor(nn.Module):
+ def __init__(self, config: VJEPA2Config):
+ super().__init__()
+ self.config = config
+ self.gradient_checkpointing = False
+ self.embeddings = VJEPA2PredictorEmbeddings(config)
+ drop_path_rates = [
+ (
+ config.drop_path_rate * i / (config.pred_num_hidden_layers - 1)
+ if config.pred_num_hidden_layers > 1
+ else 0.0
+ )
+ for i in range(config.pred_num_hidden_layers)
+ ]
+ self.layer = nn.ModuleList(
+ [
+ VJEPA2Layer(
+ config,
+ drop_path_rate=drop_path_rates[i],
+ hidden_size=config.pred_hidden_size,
+ num_attention_heads=config.pred_num_attention_heads,
+ mlp_ratio=config.pred_mlp_ratio,
+ )
+ for i in range(config.pred_num_hidden_layers)
+ ]
+ )
+ self.layernorm = nn.LayerNorm(config.pred_hidden_size, eps=config.layer_norm_eps)
+ self.proj = nn.Linear(config.pred_hidden_size, config.hidden_size, bias=True)
+
+ def sort_tokens(self, hidden_states, position_masks, argsort, head_mask=None):
+ position_masks = torch.gather(position_masks, dim=1, index=argsort)
+ hidden_states = torch.gather(
+ hidden_states,
+ dim=1,
+ index=argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)),
+ )
+ if head_mask is not None and head_mask[0] is not None:
+ head_mask = head_mask.permute(1, 0, 2, 3, 4)
+ argsort_4d = (
+ argsort.unsqueeze(1)
+ .unsqueeze(1)
+ .expand(-1, head_mask.size(1), head_mask.size(2), -1)
+ .unsqueeze(-1)
+ .expand(-1, -1, -1, -1, head_mask.size(-1))
+ )
+ head_mask = torch.gather(head_mask, dim=3, index=argsort_4d)
+ argsort_5d = (
+ argsort.unsqueeze(1)
+ .unsqueeze(1)
+ .unsqueeze(1)
+ .expand(-1, head_mask.size(1), head_mask.size(2), head_mask.size(3), -1)
+ )
+ head_mask = torch.gather(head_mask, dim=4, index=argsort_5d)
+ head_mask = head_mask.permute(1, 0, 2, 3, 4)
+ return hidden_states, position_masks, head_mask
+
+ def unsort_tokens(self, hidden_states, argsort):
+ reverse_argsort = torch.argsort(argsort, dim=1)
+ hidden_states = torch.gather(
+ hidden_states,
+ dim=1,
+ index=reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)),
+ )
+ return hidden_states
+
+ @can_return_tuple
+ def forward(
+ self,
+ encoder_hidden_states: torch.Tensor,
+ context_mask: List[torch.Tensor],
+ target_mask: List[torch.Tensor],
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ ) -> BaseModelOutput:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ # mask out the encoder hidden states
+ # this is implemented here as in VJEPA training a separate encoder is used for target
+ encoder_hidden_states = apply_masks(encoder_hidden_states, context_mask)
+ _, N_ctxt, D = encoder_hidden_states.shape
+ hidden_states, position_masks = self.embeddings(encoder_hidden_states, context_mask, target_mask)
+
+ # Put tokens in sorted order
+ argsort = torch.argsort(position_masks, dim=1) # [B, N]
+ hidden_states, position_masks, head_mask = self.sort_tokens(hidden_states, position_masks, argsort, head_mask)
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ layer_outputs = layer_module(hidden_states, position_masks, layer_head_mask, output_attentions)
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states = self.layernorm(hidden_states)
+ # unsort and extract the predicted tokens
+ hidden_states = self.unsort_tokens(hidden_states, argsort)
+ hidden_states = hidden_states[:, N_ctxt:]
+ # projection
+ hidden_states = self.proj(hidden_states)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+@auto_docstring
+class VJEPA2PreTrainedModel(PreTrainedModel):
+ config_class = VJEPA2Config
+ base_model_prefix = "vjepa2"
+ main_input_name = "pixel_values_videos"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["VJEPA2Layer"]
+ _supports_sdpa = True
+ _supports_flash_attn_2 = True
+
+ def _init_weights(
+ self,
+ module: Union[
+ nn.Linear,
+ nn.Conv2d,
+ nn.LayerNorm,
+ VJEPA2Embeddings,
+ VJEPA2PredictorEmbeddings,
+ ],
+ ):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, VJEPA2PredictorEmbeddings):
+ if not module.zero_init_mask_tokens:
+ module.mask_token = nn.init.trunc_normal_(
+ module.mask_token.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.mask_token.dtype)
+ else:
+ module.mask_tokens.data.zero_()
+
+
+def _convert_head_mask_to_5d(head_mask, num_hidden_layers):
+ """
+ Inputs:
+ - head_mask: bsz x seq_length x seq_length | None
+ Returns
+ - [num_hidden_layers x batch x num_heads x seq_length x seq_length] | [num_hidden_layers]
+ """
+ if head_mask is not None:
+ head_mask = head_mask.unsqueeze(1).unsqueeze(0)
+ head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
+ else:
+ head_mask = [None] * num_hidden_layers
+ return head_mask
+
+
+@auto_docstring
+class VJEPA2Model(VJEPA2PreTrainedModel):
+ def __init__(self, config: VJEPA2Config):
+ super().__init__(config)
+ self.config = config
+
+ self.encoder = VJEPA2Encoder(config)
+ self.predictor = VJEPA2Predictor(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> VJEPA2PatchEmbeddings3D:
+ return self.encoder.embeddings.patch_embeddings
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values_videos: torch.Tensor,
+ context_head_mask: Optional[torch.Tensor] = None,
+ context_mask: Optional[List[torch.Tensor]] = None,
+ target_head_mask: Optional[torch.Tensor] = None,
+ target_mask: Optional[List[torch.Tensor]] = None,
+ skip_predictor: bool = False,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> VJEPA2WithMaskedInputModelOutput:
+ r"""
+ pixel_values_videos (`torch.Tensor` with shape `[batch size x num_frames x num_channels x height x width]`, required):
+ The input video pixels which is processed by VJEPA2VideoProcessor.
+ context_head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
+ The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard) for the context.
+ target_head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
+ The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard) for the target.
+ context_mask (`torch.Tensor` with shape `[batch_size, patch_size, 1]`, *optional*):
+ The mask position ids indicating which encoder output patches are going to be exposed to the predictor.
+ By default, this mask is created as torch.arange(N).unsqueeze(0).repeat(B,1), indicating full context
+ available to the predictor.
+ target_mask (`torch.Tensor` with shape `[batch_size, patch_size, 1]`, *optional*):
+ The mask position ids indicating which encoder output patches are going to be used as a prediction target
+ for the predictor. By default, this mask is created as torch.arange(N).unsqueeze(0).repeat(B,1), indicating
+ that the predictor should predict all encoder patches.
+ skip_predictor (bool):
+ flag to skip the predictor forward, useful if you just need the encoder outputs
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if pixel_values_videos is None:
+ raise ValueError("You have to specify pixel_values_videos")
+
+ # Prepare head mask if needed
+ context_head_mask = _convert_head_mask_to_5d(context_head_mask, self.config.num_hidden_layers)
+ target_head_mask = _convert_head_mask_to_5d(target_head_mask, self.config.pred_num_hidden_layers)
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ pixel_values_videos=pixel_values_videos,
+ head_mask=context_head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ sequence_output = encoder_outputs.last_hidden_state
+
+ if context_mask is None and target_mask is None:
+ B = pixel_values_videos.size(0)
+ N = sequence_output.size(1) # ensure we are using dynamic patch size
+ context_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))]
+ target_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))]
+
+ if not skip_predictor:
+ predictor_outputs: BaseModelOutput = self.predictor(
+ encoder_hidden_states=sequence_output,
+ context_mask=context_mask,
+ target_mask=target_mask,
+ head_mask=target_head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ predictor_output = VJEPA2WithMaskedInputPredictorOutput(
+ last_hidden_state=predictor_outputs.last_hidden_state,
+ target_hidden_state=apply_masks(sequence_output, target_mask),
+ hidden_states=predictor_outputs.hidden_states,
+ attentions=predictor_outputs.attentions,
+ )
+ else:
+ predictor_output = None
+
+ encoder_output = VJEPA2WithMaskedInputModelOutput(
+ last_hidden_state=sequence_output,
+ masked_hidden_state=apply_masks(sequence_output, context_mask),
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ predictor_output=predictor_output,
+ )
+
+ return encoder_output
+
+ def get_vision_features(self, pixel_values_videos) -> torch.Tensor:
+ encoder_output = self.forward(pixel_values_videos)
+ return encoder_output.last_hidden_state
+
+
+__all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel"]
diff --git a/src/transformers/models/vjepa2/video_processing_vjepa2.py b/src/transformers/models/vjepa2/video_processing_vjepa2.py
new file mode 100644
index 0000000000..2df100f7eb
--- /dev/null
+++ b/src/transformers/models/vjepa2/video_processing_vjepa2.py
@@ -0,0 +1,59 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Video processor class for VJEPA2."""
+
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+)
+from ...processing_utils import Unpack, VideosKwargs
+from ...utils import is_vision_available
+from ...utils.import_utils import requires
+from ...video_processing_utils import BaseVideoProcessor
+
+
+if is_vision_available():
+ from ...image_utils import PILImageResampling
+
+
+class VJEPA2VideoProcessorInitKwargs(VideosKwargs): ...
+
+
+@requires(backends=("torchvision",))
+class VJEPA2VideoProcessor(BaseVideoProcessor):
+ resample = PILImageResampling.BILINEAR
+ image_mean = IMAGENET_DEFAULT_MEAN
+ image_std = IMAGENET_DEFAULT_STD
+ size = {"shortest_edge": int(256 * 256 / 224)}
+ crop_size = 256
+ do_resize = True
+ do_rescale = True
+ do_center_crop = True
+ do_normalize = True
+ valid_kwargs = VJEPA2VideoProcessorInitKwargs
+ model_input_names = ["pixel_values_videos"]
+
+ def __init__(self, **kwargs: Unpack[VJEPA2VideoProcessorInitKwargs]):
+ crop_size = kwargs.get("crop_size", 256)
+ if not isinstance(crop_size, int):
+ if not isinstance(crop_size, dict) or "height" not in crop_size:
+ raise ValueError("crop_size must be an integer or a dictionary with a 'height' key")
+ crop_size = crop_size["height"]
+ resize_size = int(crop_size * 256 / 224)
+ kwargs["size"] = {"shortest_edge": resize_size}
+ super().__init__(**kwargs)
+
+
+__all__ = ["VJEPA2VideoProcessor"]
diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index 28386b9f44..884dab4720 100755
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -166,6 +166,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"t5",
"trocr",
"vit",
+ "vjepa2",
"xglm",
"wav2vec2",
# "xlnet",
diff --git a/tests/models/vjepa2/__init__.py b/tests/models/vjepa2/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/models/vjepa2/test_modeling_vjepa2.py b/tests/models/vjepa2/test_modeling_vjepa2.py
new file mode 100644
index 0000000000..8a4b55ad6a
--- /dev/null
+++ b/tests/models/vjepa2/test_modeling_vjepa2.py
@@ -0,0 +1,345 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch V-JEPA2 model."""
+
+import unittest
+
+import numpy as np
+
+from transformers import VJEPA2Config
+from transformers.testing_utils import (
+ is_flaky,
+ require_torch,
+ require_vision,
+ slow,
+ torch_device,
+)
+from transformers.utils import cached_property, is_torch_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
+from ...test_video_processing_common import (
+ prepare_video_inputs,
+)
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import VJEPA2Model
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import AutoVideoProcessor
+
+VJEPA_HF_MODEL = "facebook/vjepa2-vitl-fpc64-256"
+
+
+class VJEPA2ModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=2,
+ image_size=16,
+ patch_size=16,
+ num_channels=3,
+ hidden_size=32,
+ num_hidden_layers=4,
+ num_attention_heads=2,
+ num_frames=2,
+ mlp_ratio=1,
+ pred_hidden_size=32,
+ pred_num_attention_heads=2,
+ pred_num_hidden_layers=2,
+ pred_num_mask_tokens=10,
+ is_training=False,
+ attn_implementation="sdpa",
+ mask_ratio=0.5,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_frames = num_frames
+ self.mlp_ratio = mlp_ratio
+ self.pred_hidden_size = pred_hidden_size
+ self.pred_num_attention_heads = pred_num_attention_heads
+ self.pred_num_hidden_layers = pred_num_hidden_layers
+ self.pred_num_mask_tokens = pred_num_mask_tokens
+ self.attn_implementation = attn_implementation
+ self.is_training = is_training
+ self.mask_ratio = mask_ratio
+
+ num_patches = ((image_size // patch_size) ** 2) * (num_frames // 2)
+ self.seq_length = num_patches
+ self.num_masks = int(self.mask_ratio * self.seq_length)
+ self.mask_length = num_patches
+
+ def prepare_config_and_inputs(self):
+ pixel_values_videos = floats_tensor(
+ [
+ self.batch_size,
+ self.num_frames,
+ self.num_channels,
+ self.image_size,
+ self.image_size,
+ ]
+ )
+
+ config = self.get_config()
+
+ return config, pixel_values_videos
+
+ def get_config(self):
+ return VJEPA2Config(
+ crop_size=self.image_size,
+ frames_per_clip=self.num_frames,
+ hidden_size=self.hidden_size,
+ num_attention_heads=self.num_attention_heads,
+ num_hidden_layers=self.num_hidden_layers,
+ mlp_ratio=self.mlp_ratio,
+ pred_hidden_size=self.pred_hidden_size,
+ pred_num_attention_heads=self.pred_num_attention_heads,
+ pred_num_hidden_layers=self.pred_num_hidden_layers,
+ pred_num_mask_tokens=self.pred_num_mask_tokens,
+ )
+
+ def create_and_check_model(self, config, pixel_values_videos):
+ model = VJEPA2Model(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values_videos)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, self.seq_length, self.hidden_size),
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ pixel_values_videos,
+ ) = config_and_inputs
+ inputs_dict = {"pixel_values_videos": pixel_values_videos}
+ return config, inputs_dict
+
+
+@require_torch
+class VJEPA2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as VJEPA2 does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ test_torch_exportable = True
+
+ all_model_classes = (VJEPA2Model,) if is_torch_available() else ()
+
+ fx_compatible = True
+
+ pipeline_model_mapping = {}
+
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = VJEPA2ModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=VJEPA2Config, has_text_modality=False, hidden_size=37)
+
+ @is_flaky(max_attempts=3, description="`torch.nn.init.trunc_normal_` is flaky.")
+ def test_initialization(self):
+ super().test_initialization()
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(reason="VJEPA2 does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(
+ reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+ )
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ @unittest.skip(
+ reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+ )
+ def test_training_gradient_checkpointing_use_reentrant(self):
+ pass
+
+ @unittest.skip(
+ reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+ )
+ def test_training_gradient_checkpointing_use_reentrant_false(self):
+ pass
+
+ def test_model_get_set_embeddings(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ @unittest.skip(reason="VJEPA2 does not support feedforward chunking yet")
+ def test_feed_forward_chunking(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+def prepare_random_video(image_size=256):
+ videos = prepare_video_inputs(
+ batch_size=1,
+ num_frames=16,
+ num_channels=3,
+ min_resolution=image_size,
+ max_resolution=image_size,
+ equal_resolution=True,
+ return_tensors="torch",
+ )
+ return videos
+
+
+@require_torch
+@require_vision
+class VJEPA2ModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_video_processor(self):
+ return AutoVideoProcessor.from_pretrained(VJEPA_HF_MODEL) if is_vision_available() else None
+
+ @slow
+ def test_inference_image(self):
+ model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
+
+ video_processor = self.default_video_processor
+ image = prepare_img()
+ inputs = video_processor(torch.Tensor(np.array(image)), return_tensors="pt").to(torch_device)
+ pixel_values_videos = inputs.pixel_values_videos
+ pixel_values_videos = pixel_values_videos.repeat(1, model.config.frames_per_clip, 1, 1, 1)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(pixel_values_videos)
+
+ # verify the last hidden states
+ expected_shape = torch.Size((1, 8192, 1024))
+ self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
+ expected_slice = torch.tensor(
+ [[-0.0061, -1.8365, 2.7343], [-2.5938, -2.7181, -0.1663], [-1.7993, -2.2430, -1.1388]],
+ device=torch_device,
+ )
+ torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-3, atol=1e-3)
+
+ @slow
+ def test_inference_video(self):
+ model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
+
+ video_processor = self.default_video_processor
+ video = prepare_random_video()
+ inputs = video_processor(video, return_tensors="pt").to(torch_device)
+ pixel_values_videos = inputs.pixel_values_videos
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(pixel_values_videos)
+
+ # verify the last hidden states
+ expected_shape = torch.Size((1, 2048, 1024))
+ self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
+ @slow
+ def test_predictor_outputs(self):
+ model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
+
+ video_processor = self.default_video_processor
+ video = prepare_random_video()
+ inputs = video_processor(video, return_tensors="pt").to(torch_device)
+ pixel_values_videos = inputs.pixel_values_videos
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(pixel_values_videos)
+
+ # verify the last hidden states
+ expected_shape = torch.Size((1, 2048, 1024))
+ self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)
+
+ @slow
+ def test_predictor_full_mask(self):
+ model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
+
+ video_processor = self.default_video_processor
+ video = prepare_random_video()
+ inputs = video_processor(video, return_tensors="pt").to(torch_device)
+ pixel_values_videos = inputs.pixel_values_videos
+
+ # forward pass
+ with torch.no_grad():
+ context_mask = [torch.arange(2048, device=pixel_values_videos.device).unsqueeze(0)]
+ predictor_mask = context_mask
+ outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask)
+
+ # verify the last hidden states
+ expected_shape = torch.Size((1, 2048, 1024))
+ self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)
+
+ @slow
+ def test_predictor_partial_mask(self):
+ model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device)
+
+ video_processor = self.default_video_processor
+ video = prepare_random_video()
+ inputs = video_processor(video, return_tensors="pt").to(torch_device)
+ pixel_values_videos = inputs.pixel_values_videos
+
+ num_patches = 2048
+ num_masks = 100
+ # forward pass
+ with torch.no_grad():
+ pos_ids = torch.arange(num_patches, device=pixel_values_videos.device)
+ context_mask = [pos_ids[0 : num_patches - num_masks].unsqueeze(0)]
+ predictor_mask = [pos_ids[num_patches - num_masks :].unsqueeze(0)]
+ outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask)
+
+ # verify the last hidden states
+ expected_shape = torch.Size((1, num_masks, 1024))
+ self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index bc62d56894..4aab7a69eb 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -1301,6 +1301,7 @@ class ModelTesterMixin:
"input_values",
"inputs_embeds",
"pixel_values",
+ "pixel_values_videos",
"token_type_ids",
"visual_feats",
"visual_pos",