From 84710a4291c3ca4d4b3d65d5a011ff83af243c1d Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Wed, 11 Jun 2025 15:00:08 +0100 Subject: [PATCH] Add V-JEPA 2 (#38746) * adding model and conversion scripts * add imports to test vjepa conversion * fix imports and make conversion work * fix computation for short side * replace attention with library attention function * cleanup more attention classes * remove config overrides * add test cases, fix some of the failing ones * fix the model outputs * fix outputs of the model per review * fix too big model test case * fix styling __init__.py * fix initialization test * remove all asserts per review * update sorting unsorting logic as per feedback * remove is_video per review * remove another is_video segment * remove unwanted stuff * small fixes * add docstrings for the model * revert adding vjepa2 config here * update styling * add config docstrings (wip) * fix dpr issue * removed test failing issues * update styles * merge predictor configs into main config * remove processing code, add video processor * remove permute which is not necessary now * fix styles * updated vjepa2 to be in video_processing_auto * update comment for preprocessing * test integration test and fix the outputs * update test values, change test to look at repeated frames for a given image * add a simple video processing test * refactoring pixel_values_videos and upload ckpts to original * fix torch_fx test cases * remove unused config * add all config docstrings * add more integration tests * add basic doc * revert unwanted styling changes * working make fixup * Fix model_type in config * update attention implementation to fit new hf standards * fix the preprocessing logic, ensure it matches the original model * remove use_rope logic, cleanup * fix docstrings * Further cleanup, update doc * Fix model prefix * fix get_vision_features * VJEPA2Embeddings style refactor * nit, style comment * change modules default values * Only `str` activation in config * GradientCheckpointingLayer * fixup * fix conversion script * Remove return_dict * remove None return typehint * Refactor VJEPA2Layer, remove use_SiLU * Fix fx tests * dpr -> drop_path_rates * move *ModelOutput on top * format docs bit * update docs * update docs * update doc example * remove prune_heads from model * remove unused config params * refactor embed signature * Add vjepa to docs * Fix config docstring * update defaults * Update docs/source/en/model_doc/vjepa2.md Co-authored-by: Pedro Cuenca * Update docs/source/en/model_doc/vjepa2.md Co-authored-by: Pedro Cuenca * Fix import * Min refactoring * Update HUB_SOURCE and HUB_REPO in conversion script * Add missing headers * VJEPA -> V-JEPA in docs * Add image to doc * fix style * fix init weights * change checkpoint name in modeling tests --------- Co-authored-by: Koustuv Sinha Co-authored-by: yonigozlan Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Co-authored-by: Koustuv Sinha Co-authored-by: Pedro Cuenca --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/vjepa2.md | 82 ++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 1 + .../models/auto/video_processing_auto.py | 1 + src/transformers/models/vjepa2/__init__.py | 29 + .../models/vjepa2/configuration_vjepa2.py | 146 +++ .../models/vjepa2/convert_vjepa2_to_hf.py | 346 +++++++ .../models/vjepa2/modeling_vjepa2.py | 903 ++++++++++++++++++ .../models/vjepa2/video_processing_vjepa2.py | 59 ++ src/transformers/utils/fx.py | 1 + tests/models/vjepa2/__init__.py | 0 tests/models/vjepa2/test_modeling_vjepa2.py | 345 +++++++ tests/test_modeling_common.py | 1 + 15 files changed, 1919 insertions(+) create mode 100644 docs/source/en/model_doc/vjepa2.md create mode 100644 src/transformers/models/vjepa2/__init__.py create mode 100644 src/transformers/models/vjepa2/configuration_vjepa2.py create mode 100644 src/transformers/models/vjepa2/convert_vjepa2_to_hf.py create mode 100644 src/transformers/models/vjepa2/modeling_vjepa2.py create mode 100644 src/transformers/models/vjepa2/video_processing_vjepa2.py create mode 100644 tests/models/vjepa2/__init__.py create mode 100644 tests/models/vjepa2/test_modeling_vjepa2.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 12e4224070..7916dd9a06 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -905,6 +905,8 @@ - sections: - local: model_doc/timesformer title: TimeSformer + - local: model_doc/vjepa2 + title: V-JEPA 2 - local: model_doc/videomae title: VideoMAE - local: model_doc/vivit diff --git a/docs/source/en/model_doc/vjepa2.md b/docs/source/en/model_doc/vjepa2.md new file mode 100644 index 0000000000..5ad02ae274 --- /dev/null +++ b/docs/source/en/model_doc/vjepa2.md @@ -0,0 +1,82 @@ + + + +
+
+ PyTorch + SDPA + FlashAttention +
+
+ +# V-JEPA 2 + +V-JEPA 2 is a self-supervised approach to training video encoders developed by FAIR, Meta. Using internet-scale video data, V-JEPA 2 attains state-of-the-art performance on motion understanding and human action anticipation tasks. V-JEPA 2-AC is a latent action-conditioned world model post-trained from V-JEPA 2 (using a small amount of robot trajectory interaction data) that solves robot manipulation tasks without environment-specific data collection or task-specific training or calibration. + +
+ drawing +
+ +You can find all original V-JEPA2 checkpoints under the [V-JEPA 2](https://huggingface.co/collections/facebook/v-jepa-2-6841bad8413014e185b497a6) collection. + + +This model was contributed by [koustuvs](https://huggingface.co/koustuvs), [yonigozlan](https://huggingface.co/yonigozlan) and [qubvel](https://huggingface.co/qubvel-hf). The original code can be found [here](https://github.com/facebookresearch/vjepa2). + +## Usage example + +The snippet below shows how to load the V-JEPA 2 model using the `AutoModel` class. + +```py +import torch +from torchcodec.decoders import VideoDecoder +import numpy as np + +processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc64-256") +model = AutoModel.from_pretrained( + "facebook/vjepa2-vitl-fpc64-256", + torch_dtype=torch.float16, + device_map="auto", + attn_implementation="sdpa" +) + +video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/archery/-Qz25rXdMjE_000014_000024.mp4" + +vr = VideoDecoder(video_url) +frame_idx = np.arange(0, 64) # choosing some frames. here, you can define more complex sampling strategy +video = vr.get_frames_at(indices=frame_idx).data # T x C x H x W +video = processor(video, return_tensors="pt").to(model.device) +outputs = model(**video) + +# V-JEPA 2 encoder outputs, same as calling `model.get_vision_features()` +encoder_outputs = outputs.last_hidden_state + +# V-JEPA 2 predictor outputs +predictor_outputs = outputs.predictor_output.last_hidden_state +``` + +## VJEPA2Config + +[[autodoc]] VJEPA2Config + +## VJEPA2Model + +[[autodoc]] VJEPA2Model + - forward + +## VJEPA2VideoProcessor + +[[autodoc]] VJEPA2VideoProcessor diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index ea10dc8666..dea3c98c38 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -323,6 +323,7 @@ if TYPE_CHECKING: from .vitpose_backbone import * from .vits import * from .vivit import * + from .vjepa2 import * from .wav2vec2 import * from .wav2vec2_bert import * from .wav2vec2_conformer import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index da55432dfd..fe8a889b5d 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -365,6 +365,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str]( ("vitpose_backbone", "VitPoseBackboneConfig"), ("vits", "VitsConfig"), ("vivit", "VivitConfig"), + ("vjepa2", "VJEPA2Config"), ("wav2vec2", "Wav2Vec2Config"), ("wav2vec2-bert", "Wav2Vec2BertConfig"), ("wav2vec2-conformer", "Wav2Vec2ConformerConfig"), @@ -750,6 +751,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str]( ("vitpose_backbone", "ViTPoseBackbone"), ("vits", "VITS"), ("vivit", "ViViT"), + ("vjepa2", "VJEPA2Model"), ("wav2vec2", "Wav2Vec2"), ("wav2vec2-bert", "Wav2Vec2-BERT"), ("wav2vec2-conformer", "Wav2Vec2-Conformer"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 0998af5592..bcd56483c1 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -336,6 +336,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("vitdet", "VitDetModel"), ("vits", "VitsModel"), ("vivit", "VivitModel"), + ("vjepa2", "VJEPA2Model"), ("wav2vec2", "Wav2Vec2Model"), ("wav2vec2-bert", "Wav2Vec2BertModel"), ("wav2vec2-conformer", "Wav2Vec2ConformerModel"), diff --git a/src/transformers/models/auto/video_processing_auto.py b/src/transformers/models/auto/video_processing_auto.py index 1688b7fbeb..0f48dcdac7 100644 --- a/src/transformers/models/auto/video_processing_auto.py +++ b/src/transformers/models/auto/video_processing_auto.py @@ -56,6 +56,7 @@ else: ("qwen2_vl", "Qwen2VLVideoProcessor"), ("smolvlm", "SmolVLMVideoProcessor"), ("video_llava", "VideoLlavaVideoProcessor"), + ("vjepa2", "VJEPA2VideoProcessor"), ] ) diff --git a/src/transformers/models/vjepa2/__init__.py b/src/transformers/models/vjepa2/__init__.py new file mode 100644 index 0000000000..f184d058a8 --- /dev/null +++ b/src/transformers/models/vjepa2/__init__.py @@ -0,0 +1,29 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_vjepa2 import * + from .modeling_vjepa2 import * + from .video_processing_vjepa2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/vjepa2/configuration_vjepa2.py b/src/transformers/models/vjepa2/configuration_vjepa2.py new file mode 100644 index 0000000000..4571b88602 --- /dev/null +++ b/src/transformers/models/vjepa2/configuration_vjepa2.py @@ -0,0 +1,146 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""VJEPA 2 model configuration""" + +from ...configuration_utils import PretrainedConfig + + +class VJEPA2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`VJEPA2Model`]. It is used to instantiate an + VJEPA2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the VJEPA2 + [facebook/vjepa2-vitl-fpc64-256](https://huggingface.co/facebook/vjepa2-vitl-fpc64-256) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + crop_size (`int`, *optional*, defaults to 256): + Input resolution of the model + frames_per_clip (`int`, *optional*, defaults to 64): + The number of frames the model has been pretrained with. Does not impact inference. + tubelet_size (`int`, *optional*, defaults to 2): + The number of temporal frames used for a single rastor, check paper for more information. + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers + in_chans (`int`, *optional*, defaults to 3): + The number of input channels + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Encoder + num_hidden_layers (`int`, *optional*, defaults to 24): + The number of hidden layers + drop_path_rate (`float`, *optional*, defaults to 0.0): + Stochastic depth rate per sample (when applied in the main path of residual layers). + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of the hidden size of the MLPs used in Encoder relative to the `hidden_size`. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for attentions. + The dropout probability for all fully connected layers. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + pred_hidden_size (`int`, *optional*, defaults to 384): + Dimensionality of the predictor layers + pred_num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Predictor + pred_num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Predictor + pred_num_mask_tokens (`int`, *optional*, defaults to 10): + Define the number of mask tokens to use in the Predictor + pred_zero_init_mask_tokens (`bool`, *optional*, defaults to `True`): + Initialize the mask tokens in the predictor with 0. + pred_mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of the hidden size of the MLPs used in Predictor relative to the `pred_hidden_size`. + + Example: + + ```python + >>> from transformers import VJEPA2Config, VJEPA2Model + + >>> # Initializing a VJEPA2 vjepa2-vitl-fpc64-256 style configuration + >>> configuration = VJEPA2Config() + + >>> # Initializing a model (with random weights) from the vjepa2-vitl-fpc64-256 style configuration + >>> model = VJEPA2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "vjepa2" + + def __init__( + self, + patch_size=16, + crop_size=256, + frames_per_clip=64, + tubelet_size=2, + hidden_size=1024, + in_chans=3, + num_attention_heads=16, + num_hidden_layers=24, + drop_path_rate=0.0, + mlp_ratio=4.0, + layer_norm_eps=1e-6, + qkv_bias=True, + attention_probs_dropout_prob=0.0, + hidden_act="gelu", + initializer_range=0.02, + # predictor params + pred_hidden_size=384, + pred_num_attention_heads=12, + pred_num_hidden_layers=12, + pred_num_mask_tokens=10, + pred_zero_init_mask_tokens=True, + pred_mlp_ratio=4.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.crop_size = crop_size + self.frames_per_clip = frames_per_clip + self.patch_size = patch_size + self.tubelet_size = tubelet_size + self.hidden_size = hidden_size + self.in_chans = in_chans + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.drop_path_rate = drop_path_rate + self.mlp_ratio = mlp_ratio + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.image_size = crop_size + # predictor params + self.pred_hidden_size = pred_hidden_size + self.pred_num_attention_heads = pred_num_attention_heads + self.pred_num_hidden_layers = pred_num_hidden_layers + self.pred_num_mask_tokens = pred_num_mask_tokens + self.pred_zero_init_mask_tokens = pred_zero_init_mask_tokens + self.pred_mlp_ratio = pred_mlp_ratio + + +__all__ = ["VJEPA2Config"] diff --git a/src/transformers/models/vjepa2/convert_vjepa2_to_hf.py b/src/transformers/models/vjepa2/convert_vjepa2_to_hf.py new file mode 100644 index 0000000000..527dbc35d9 --- /dev/null +++ b/src/transformers/models/vjepa2/convert_vjepa2_to_hf.py @@ -0,0 +1,346 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import tempfile +from pathlib import Path + +import numpy as np +import requests +import torch +from huggingface_hub import HfApi +from PIL import Image + +from transformers import VJEPA2Config, VJEPA2Model, VJEPA2VideoProcessor +from transformers.models.vjepa2.modeling_vjepa2 import apply_masks + + +HUB_REPO = "https://github.com/facebookresearch/vjepa2" +HUB_SOURCE = "github" + +HUB_MODELS = { + "vit_large": "facebook/vjepa2-vitl-fpc64-256", + "vit_huge": "facebook/vjepa2-vith-fpc64-256", + "vit_giant": "facebook/vjepa2-vitg-fpc64-256", + "vit_giant_384": "facebook/vjepa2-vitg-fpc64-384", +} + +S3_MODELS = { + "vit_large": "https://dl.fbaipublicfiles.com/vjepa2/vitl.pt", + "vit_huge": "https://dl.fbaipublicfiles.com/vjepa2/vith.pt", + "vit_giant": "https://dl.fbaipublicfiles.com/vjepa2/vitg.pt", + "vit_giant_384": "https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt", +} + +TOKEN = os.environ.get("HF_TOKEN", None) + + +def get_vjepa2_config(model_name): + # size of the architecture + if model_name == "vit_large": + return VJEPA2Config( + crop_size=256, + frames_per_clip=64, + hidden_size=1024, + num_attention_heads=16, + num_hidden_layers=24, + mlp_ratio=4, + pred_hidden_size=384, + pred_num_attention_heads=12, + pred_num_hidden_layers=12, + pred_num_mask_tokens=10, + ) + elif model_name == "vit_huge": + return VJEPA2Config( + crop_size=256, + frames_per_clip=64, + hidden_size=1280, + num_attention_heads=16, + num_hidden_layers=32, + mlp_ratio=4, + pred_hidden_size=384, + pred_num_attention_heads=12, + pred_num_hidden_layers=12, + pred_num_mask_tokens=10, + ) + elif model_name == "vit_giant": + return VJEPA2Config( + crop_size=256, + frames_per_clip=64, + hidden_size=1408, + num_attention_heads=22, + num_hidden_layers=40, + mlp_ratio=48 / 11, + pred_hidden_size=384, + pred_num_attention_heads=12, + pred_num_hidden_layers=12, + pred_num_mask_tokens=10, + ) + elif model_name == "vit_giant_384": + return VJEPA2Config( + crop_size=384, + frames_per_clip=64, + hidden_size=1408, + num_attention_heads=22, + num_hidden_layers=40, + mlp_ratio=48 / 11, + pred_hidden_size=384, + pred_num_attention_heads=12, + pred_num_hidden_layers=12, + pred_num_mask_tokens=10, + ) + else: + raise ValueError("Model not supported") + + +def convert_encoder_keys(model_state_dict, og_encoder_state_dict, config): + emb_dim = config.hidden_size + for key, val in og_encoder_state_dict.copy().items(): + val = og_encoder_state_dict.pop(key) + key = key.replace("module.backbone.", "") + if key.startswith("blocks."): + key = key.replace("blocks.", "encoder.layer.") + if "attn." in key: + key = key.replace("attn.", "attention.") + if key == "pos_embed": + key = "encoder.embeddings.position_embeddings" + if "patch_embed." in key: + key = key.replace("patch_embed.", "encoder.embeddings.patch_embeddings.") + if key.startswith("norm."): + key = key.replace("norm.", "encoder.layernorm.") + if "qkv." in key: + prefix, suffix = key.split("qkv") + if "bias" in suffix: + q_e, k_e, v_e = ( + val[0:emb_dim], + val[emb_dim : emb_dim * 2], + val[emb_dim * 2 :], + ) + else: + q_e, k_e, v_e = ( + val[0:emb_dim, :], + val[emb_dim : emb_dim * 2, :], + val[emb_dim * 2 :, :], + ) + og_encoder_state_dict[prefix + "query" + suffix] = q_e + og_encoder_state_dict[prefix + "key" + suffix] = k_e + og_encoder_state_dict[prefix + "value" + suffix] = v_e + else: + og_encoder_state_dict[key] = val + return og_encoder_state_dict + + +def convert_predictor_keys(model_state_dict, og_predictor_state_dict, config): + emb_dim = config.pred_hidden_size + if "predictor_pos_embed" in og_predictor_state_dict: + del og_predictor_state_dict["predictor_pos_embed"] + # update predictor weights + mask_tokens = {} + mask_token_keys_to_delete = [] + for key, val in og_predictor_state_dict.copy().items(): + val = og_predictor_state_dict.pop(key) + key = key.replace("module.backbone.", "") + if key.startswith("predictor_blocks."): + key = key.replace("predictor_blocks.", "predictor.layer.") + if "attn." in key: + key = key.replace("attn.", "attention.") + if key == "predictor_pos_embed": + key = "predictor.embeddings.position_embeddings" + if "predictor_embed." in key: + key = key.replace("predictor_embed.", "predictor.embeddings.predictor_embeddings.") + if "mask_tokens." in key: + mask_tokens[key.split("mask_tokens.")[-1]] = val + mask_token_keys_to_delete.append(key) + # key = key.replace("mask_tokens.", "predictor.embeddings.mask_tokens.") + if key.startswith("predictor_norm."): + key = key.replace("predictor_norm.", "predictor.layernorm.") + if key.startswith("predictor_proj."): + key = key.replace("predictor_proj.", "predictor.proj.") + if "qkv." in key: + prefix, suffix = key.split("qkv") + if "bias" in suffix: + q_e, k_e, v_e = ( + val[0:emb_dim], + val[emb_dim : emb_dim * 2], + val[emb_dim * 2 :], + ) + else: + q_e, k_e, v_e = ( + val[0:emb_dim, :], + val[emb_dim : emb_dim * 2, :], + val[emb_dim * 2 :, :], + ) + og_predictor_state_dict[prefix + "query" + suffix] = q_e + og_predictor_state_dict[prefix + "key" + suffix] = k_e + og_predictor_state_dict[prefix + "value" + suffix] = v_e + else: + og_predictor_state_dict[key] = val + mask_tokens = torch.stack([mask_tokens[f"{i}"] for i in range(len(mask_tokens))], dim=0) + for k in mask_token_keys_to_delete: + del og_predictor_state_dict[k] + og_predictor_state_dict["predictor.embeddings.mask_tokens"] = mask_tokens + return og_predictor_state_dict + + +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + return image + + +def upload_original_ckpts(model_name): + hf_repo = HUB_MODELS[model_name] + original_ckpt = S3_MODELS[model_name] + print(f"Uploading original checkpoint for vjepa2 {model_name} to {hf_repo}/original/") + with tempfile.NamedTemporaryFile() as fn: + local_path = fn.name + torch.hub.download_url_to_file(original_ckpt, local_path) + api = HfApi() + api.upload_file( + repo_id=hf_repo, + path_or_fileobj=local_path, + path_in_repo="original/model.pth", + repo_type="model", + token=TOKEN, + ) + print("Uploading complete") + + +@torch.no_grad() +def convert_and_test_vjepa2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False): + """ + Copy/paste/tweak model's weights to our VJEPA2 structure. + """ + config = get_vjepa2_config(model_name) + + # load original model from torch hub + original_encoder, original_predictor = torch.hub.load(HUB_REPO, "vjepa2_" + model_name, source=HUB_SOURCE) + original_encoder.eval() + original_predictor.eval() + original_preprocessor = torch.hub.load( + HUB_REPO, "vjepa2_preprocessor", source=HUB_SOURCE, crop_size=config.crop_size + ) + + # load state_dict of original model, remove and rename some keys + encoder_state_dict = original_encoder.state_dict() + decoder_state_dict = original_predictor.state_dict() + + model = VJEPA2Model(config).eval() + state_dict = model.state_dict() + + og_encoder_sd = convert_encoder_keys(state_dict, encoder_state_dict, config) + og_predictor_sd = convert_predictor_keys(state_dict, decoder_state_dict, config) + + og_state_dict = og_encoder_sd + og_state_dict.update(og_predictor_sd) + model.load_state_dict(og_state_dict) + + # load image + image = prepare_img() + image = torch.Tensor(np.array(image)).unsqueeze(0).permute(0, 3, 1, 2) + print("Input shape: ", image.shape) + + crop_size = config.crop_size + processor = VJEPA2VideoProcessor(crop_size=crop_size) + pr_out = processor(image, return_tensors="pt") + pixel_values_videos = pr_out.pixel_values_videos + # run original preprocessor + original_pixel_values = original_preprocessor(image) + assert original_pixel_values[0].permute(1, 0, 2, 3).shape == pixel_values_videos[0].shape + assert torch.allclose(original_pixel_values[0].permute(1, 0, 2, 3), pixel_values_videos[0], atol=1e-3) + + with torch.no_grad(): + # reshape and move to gpu + if pixel_values_videos.size(1) == 1: + pixel_values_videos = pixel_values_videos.repeat(1, config.frames_per_clip, 1, 1, 1) + # pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4) # B x C x T x H x W + pixel_values_videos = pixel_values_videos.to(device="cuda", dtype=torch.float32) + original_encoder = original_encoder.to(device="cuda", dtype=torch.float32) + original_predictor = original_predictor.to(device="cuda", dtype=torch.float32) + model = model.to(device="cuda", dtype=torch.float32) + # forward + original_encoder_outputs = original_encoder(pixel_values_videos.permute(0, 2, 1, 3, 4)) + B, N, _ = original_encoder_outputs.shape + # test full mask + context_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))] + predictor_mask = context_mask + original_predictor_outputs = original_predictor(original_encoder_outputs, context_mask, predictor_mask) + outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask) + assert torch.allclose(outputs.last_hidden_state, original_encoder_outputs, atol=1e-3) + predictor_outputs = outputs.predictor_output + assert torch.allclose(predictor_outputs.last_hidden_state, original_predictor_outputs, atol=1e-3) + # test partial mask + window_size = 256 + mask = torch.arange(N, device=pixel_values_videos.device).unsqueeze(0) + context_mask = [mask[:, :window_size].repeat((B, 1))] + predictor_mask = [mask[:, window_size : window_size * 2].repeat((B, 1))] + original_predictor_outputs = original_predictor( + apply_masks(original_encoder_outputs, context_mask), + context_mask, + predictor_mask, + ) + outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask) + assert torch.allclose(outputs.last_hidden_state, original_encoder_outputs, atol=1e-3) + predictor_outputs = outputs.predictor_output + assert torch.allclose(predictor_outputs.last_hidden_state, original_predictor_outputs, atol=1e-3) + + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + name = HUB_MODELS[model_name] + model.push_to_hub(name, private=True) + processor.push_to_hub(name, private=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="vit_large", + type=str, + choices=[ + "vit_large", + "vit_huge", + "vit_giant", + "vit_giant_384", + ], + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the converted model to the 🤗 hub.", + ) + parser.add_argument("--upload_original", action="store_true", help="upload the original checkpoint") + + args = parser.parse_args() + convert_and_test_vjepa2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) + if args.upload_original: + upload_original_ckpts(args.model_name) diff --git a/src/transformers/models/vjepa2/modeling_vjepa2.py b/src/transformers/models/vjepa2/modeling_vjepa2.py new file mode 100644 index 0000000000..7a3a95b129 --- /dev/null +++ b/src/transformers/models/vjepa2/modeling_vjepa2.py @@ -0,0 +1,903 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, dataclass +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging +from .configuration_vjepa2 import VJEPA2Config + + +logger = logging.get_logger(__name__) + + +@dataclass +class VJEPA2WithMaskedInputPredictorOutput(ModelOutput): + """ + VJEPA Predictor outputs that also contains the masked encoder outputs + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + masked_hidden_state (`torch.FloatTensor`), *optional*, returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + target_hidden_state (`torch.FloatTensor`), *optional*): + Returned when `target_mask` is provided which is applied on VJEPA2Encoder outputs. + """ + + last_hidden_state: torch.FloatTensor + masked_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + target_hidden_state: Optional[torch.FloatTensor] = None + + +@dataclass +class VJEPA2WithMaskedInputModelOutput(ModelOutput): + """ + VJEPA outputs that also contains the masked encoder outputs + Optionally contains the predictor outputs + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + masked_hidden_state (`torch.FloatTensor`), *optional*): + Returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + predictor_output (`VJEPA2WithMaskedInputPredictorOutput`, *optional*): + Returns the output from the Predictor module + """ + + last_hidden_state: torch.FloatTensor + masked_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + predictor_output: Optional[VJEPA2WithMaskedInputPredictorOutput] = None + + def to_tuple(self): + output = list(super().to_tuple()) + if isinstance(output[-1], VJEPA2WithMaskedInputPredictorOutput): + output[-1] = output[-1].to_tuple() + return tuple(output) + + +class VJEPA2PatchEmbeddings3D(nn.Module): + """ + Image to Patch Embedding + """ + + def __init__( + self, + config: VJEPA2Config, + hidden_size: int = 1024, + ): + super().__init__() + self.patch_size = config.patch_size + self.tubelet_size = config.tubelet_size + self.hidden_size = hidden_size + + self.proj = nn.Conv3d( + in_channels=config.in_chans, + out_channels=hidden_size, + kernel_size=(config.tubelet_size, config.patch_size, config.patch_size), + stride=(config.tubelet_size, config.patch_size, config.patch_size), + ) + + @staticmethod + def num_patches(config): + return ( + (config.frames_per_clip // config.tubelet_size) + * (config.crop_size // config.patch_size) + * (config.crop_size // config.patch_size) + ) + + def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor: + x = self.proj(pixel_values_videos).flatten(2).transpose(1, 2) + return x + + +class VJEPA2Embeddings(nn.Module): + """ + Construct mask token, position and patch embeddings. + """ + + def __init__(self, config: VJEPA2Config, hidden_size: int = 1024): + super().__init__() + + self.config = config + self.hidden_size = hidden_size + self.patch_embeddings = VJEPA2PatchEmbeddings3D(config, hidden_size=hidden_size) + + self.num_patches = self.patch_embeddings.num_patches + self.patch_size = config.patch_size + + def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor: + num_frames = pixel_values_videos.shape[1] + + # Swap `frames` and `channels` dims, the result is: + # (batch_size, channels, num_frames, height, width) + pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4) + + # For some cases, if the input vision (image/video) consists of num_frames < tubelet_size, + # then embedding lookup fails. In these cases, we duplicate the frames. + if num_frames < self.config.tubelet_size: + pixel_values_videos = pixel_values_videos.repeat(1, 1, self.config.tubelet_size, 1, 1) + + target_dtype = self.patch_embeddings.proj.weight.dtype + pixel_values_videos = pixel_values_videos.to(dtype=target_dtype) + embeddings = self.patch_embeddings(pixel_values_videos) + + return embeddings + + +# Adapted from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def rotate_queries_or_keys(x, pos): + B, num_heads, N, D = x.size() + + # similar to inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + # they are computing this every time. instead HF style is to compute the inv_freq once and store it + # -- compute angle for each position + omega = torch.arange(D // 2, dtype=x.dtype, device=x.device) + omega /= D / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + freq = torch.einsum("..., f -> ... f", pos, omega) # (..., N, D/2), outer product + + # -- build rotation matrix and apply + emb_sin = freq.sin() # (..., N, D/2) + emb_cos = freq.cos() # (..., N, D/2) + + emb_sin = emb_sin.squeeze(-1).repeat(1, 1, 1, 2) + emb_cos = emb_cos.squeeze(-1).repeat(1, 1, 1, 2) + + # -- + y = x.unflatten(-1, (-1, 2)) + y1, y2 = y.unbind(dim=-1) + + y = torch.stack((-y2, y1), dim=-1) + y = y.flatten(-2) + return (x * emb_cos) + (y * emb_sin) + + +class VJEPA2RopeAttention(nn.Module): + def __init__( + self, + config: VJEPA2Config, + hidden_size: int = 1024, + num_attention_heads: int = 16, + ): + super().__init__() + self.config = config + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"The hidden size {(hidden_size,)} is not a multiple of the number of attention " + f"heads {num_attention_heads}." + ) + + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias) + + self.proj = nn.Linear(hidden_size, hidden_size) + self.dropout_prob = config.attention_probs_dropout_prob + self.dropout = nn.Dropout(self.dropout_prob) + + self.grid_size = self.config.crop_size // self.config.patch_size + self.grid_depth = self.config.frames_per_clip // self.config.tubelet_size + + self.d_dim = int(2 * ((self.attention_head_size // 3) // 2)) + self.h_dim = int(2 * ((self.attention_head_size // 3) // 2)) + self.w_dim = int(2 * ((self.attention_head_size // 3) // 2)) + + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def _get_frame_pos(self, ids): + tokens_per_frame = int(self.grid_size * self.grid_size) + return ids // tokens_per_frame + + def _get_height_pos(self, ids): + # Remove frame component from ids + tokens_per_frame = int(self.grid_size * self.grid_size) + frame_ids = self._get_frame_pos(ids) + ids = ids - tokens_per_frame * frame_ids + # -- + tokens_per_row = self.grid_size + return ids // tokens_per_row + + def get_position_ids(self, x, masks=None): + device = x.device + token_size = x.size(1) + + # Note: when masks is none, we use a 1d id instead of Bxnum_attention_heads mask, + # as 1d vector is broadcasted to the correct shapes. + if masks is not None: + ids = masks.unsqueeze(1).repeat(1, self.num_attention_heads, 1) + else: + ids = torch.arange(token_size, device=device) + # change to allow for extrapolation + tokens_per_frame = int(self.grid_size * self.grid_size) + frame_ids = self._get_frame_pos(ids) + # -- + tokens_per_row = self.grid_size + height_ids = self._get_height_pos(ids) + # -- + # Remove frame component from ids (1st term) and height component (2nd term) + width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids + return frame_ids, height_ids, width_ids + + def apply_rotary_embeddings(self, qk, pos_ids): + d_mask, h_mask, w_mask = pos_ids + s = 0 + qkd = rotate_queries_or_keys(qk[..., s : s + self.d_dim], pos=d_mask) + s += self.d_dim + qkh = rotate_queries_or_keys(qk[..., s : s + self.h_dim], pos=h_mask) + s += self.h_dim + qkw = rotate_queries_or_keys(qk[..., s : s + self.w_dim], pos=w_mask) + s += self.w_dim + # Combine rotated dimension + if s < self.attention_head_size: + qkr = qk[..., s:] + qk = torch.cat([qkd, qkh, qkw, qkr], dim=-1) + else: + qk = torch.cat([qkd, qkh, qkw], dim=-1) + return qk + + def forward( + self, + hidden_states, + position_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + head_mask: Optional[torch.Tensor] = None, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + pos_ids = self.get_position_ids(hidden_states, masks=position_mask) + key_layer = self.apply_rotary_embeddings(key_layer, pos_ids) + query_layer = self.apply_rotary_embeddings(query_layer, pos_ids) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + head_mask, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = self.proj(context_layer.reshape(new_context_layer_shape)) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Adapted from transformers.models.beit.modeling_dinov2.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Adapted from transformers.models.beit.modeling_beit.BeitDropPath +class VJEPA2DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class VJEPA2MLP(nn.Module): + def __init__(self, config: VJEPA2Config, hidden_size: int = 1024, mlp_ratio: float = 4.0): + super().__init__() + in_features = out_features = hidden_size + hidden_features = int(hidden_size * mlp_ratio) + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + self.activation = ACT2FN[config.hidden_act] + self.fc2 = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class VJEPA2Layer(GradientCheckpointingLayer): + """This corresponds to the Block class in the original implementation.""" + + def __init__( + self, + config: VJEPA2Config, + drop_path_rate: float = 0.0, + hidden_size: int = 1024, + num_attention_heads: int = 16, + mlp_ratio: float = 4.0, + ): + super().__init__() + self.config = config + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.mlp_ratio = mlp_ratio + + self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + self.attention = VJEPA2RopeAttention(config, hidden_size, num_attention_heads) + self.drop_path = VJEPA2DropPath(drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + self.norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + self.mlp = VJEPA2MLP(config, hidden_size=hidden_size, mlp_ratio=mlp_ratio) + + def forward( + self, + hidden_states: torch.Tensor, + position_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + # Self-Attention + residual = hidden_states + hidden_states = self.norm1(hidden_states) + self_attention_outputs = self.attention( + hidden_states, + position_mask=position_mask, # position mask for context/target selection + head_mask=head_mask, # head mask is applied at F.scaled_dot_product_attention + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + hidden_states = self.drop_path(attention_output) + residual + + # MLP + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.drop_path(hidden_states) + residual + + # Add self attentions if we output attention weights + outputs = self_attention_outputs[1:] + outputs = (hidden_states,) + outputs + + return outputs + + +class VJEPA2Encoder(nn.Module): + def __init__(self, config: VJEPA2Config): + super().__init__() + self.config = config + + self.embeddings = VJEPA2Embeddings(config, hidden_size=config.hidden_size) + drop_path_rates = [ + (config.drop_path_rate * i / (config.num_hidden_layers - 1) if config.num_hidden_layers > 1 else 0.0) + for i in range(config.num_hidden_layers) + ] + self.layer = nn.ModuleList( + [ + VJEPA2Layer( + config, + drop_path_rate=drop_path_rates[i], + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + mlp_ratio=config.mlp_ratio, + ) + for i in range(config.num_hidden_layers) + ] + ) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.gradient_checkpointing = False + + @can_return_tuple + def forward( + self, + pixel_values_videos: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + ) -> BaseModelOutput: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + hidden_states = self.embeddings(pixel_values_videos) + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + layer_outputs = layer_module(hidden_states, None, layer_head_mask, output_attentions) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + hidden_states = self.layernorm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +def apply_masks(x, masks) -> torch.Tensor: + """ + :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] + :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep + """ + all_x = [] + for m in masks: + mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) + all_x += [torch.gather(x, dim=1, index=mask_keep)] + + return torch.cat(all_x, dim=0) + + +class VJEPA2PredictorEmbeddings(nn.Module): + """ + Construct mask token, position and patch embeddings. + """ + + def __init__(self, config: VJEPA2Config): + super().__init__() + + self.config = config + self.predictor_embeddings = nn.Linear(config.hidden_size, config.pred_hidden_size) + self.num_mask_tokens = 0 + self.zero_init_mask_tokens = config.pred_zero_init_mask_tokens + self.num_mask_tokens = config.pred_num_mask_tokens + self.mask_tokens = nn.Parameter(torch.zeros(self.num_mask_tokens, 1, 1, config.pred_hidden_size)) + + self.patch_size = config.patch_size + self.config = config + + @staticmethod + def num_patches(config): + if config.frames_per_clip > 1: + return ( + (config.frames_per_clip // config.tubelet_size) + * (config.crop_size // config.patch_size) + * (config.crop_size // config.patch_size) + ) + else: + return (config.crop_size // config.patch_size) * (config.crop_size // config.patch_size) + + def forward( + self, + hidden_states: torch.Tensor, + context_mask: List[torch.Tensor], + target_mask: List[torch.Tensor], + mask_index: int = 1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + hidden_states : encoder outputs (context) + context_mask: tokens of the context (outputs from the encoder) + target_mask: tokens to predict + mask_index: index of the target mask to choose (useful for multiclip?) + """ + + B = hidden_states.size(0) + context = self.predictor_embeddings(hidden_states) + + # Make target tokens + mask_index = mask_index % self.num_mask_tokens + target = self.mask_tokens[mask_index] + + # Note: this is problematic if the config isn't initialized with the right frames_per_clip value, + # e.g. for scenarios if we want to run predictor for more tokens than in the config. + # target = target.repeat(B, self.num_patches(self.config), 1) + # Remedy: use the provided target mask to get the max patch num + max_patch_num = target_mask[0].max() + 1 # one extra to include the last patch + target = target.repeat(B, max_patch_num, 1) + target = apply_masks(target, target_mask) + + # Concatenate context & target tokens + context = context.repeat(len(context_mask), 1, 1) + embeddings = torch.cat([context, target], dim=1) + + # Positions of context & target tokens + cm = torch.cat(context_mask, dim=0) + tm = torch.cat(target_mask, dim=0) + masks = torch.cat([cm, tm], dim=1) + + return embeddings, masks + + +class VJEPA2Predictor(nn.Module): + def __init__(self, config: VJEPA2Config): + super().__init__() + self.config = config + self.gradient_checkpointing = False + self.embeddings = VJEPA2PredictorEmbeddings(config) + drop_path_rates = [ + ( + config.drop_path_rate * i / (config.pred_num_hidden_layers - 1) + if config.pred_num_hidden_layers > 1 + else 0.0 + ) + for i in range(config.pred_num_hidden_layers) + ] + self.layer = nn.ModuleList( + [ + VJEPA2Layer( + config, + drop_path_rate=drop_path_rates[i], + hidden_size=config.pred_hidden_size, + num_attention_heads=config.pred_num_attention_heads, + mlp_ratio=config.pred_mlp_ratio, + ) + for i in range(config.pred_num_hidden_layers) + ] + ) + self.layernorm = nn.LayerNorm(config.pred_hidden_size, eps=config.layer_norm_eps) + self.proj = nn.Linear(config.pred_hidden_size, config.hidden_size, bias=True) + + def sort_tokens(self, hidden_states, position_masks, argsort, head_mask=None): + position_masks = torch.gather(position_masks, dim=1, index=argsort) + hidden_states = torch.gather( + hidden_states, + dim=1, + index=argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)), + ) + if head_mask is not None and head_mask[0] is not None: + head_mask = head_mask.permute(1, 0, 2, 3, 4) + argsort_4d = ( + argsort.unsqueeze(1) + .unsqueeze(1) + .expand(-1, head_mask.size(1), head_mask.size(2), -1) + .unsqueeze(-1) + .expand(-1, -1, -1, -1, head_mask.size(-1)) + ) + head_mask = torch.gather(head_mask, dim=3, index=argsort_4d) + argsort_5d = ( + argsort.unsqueeze(1) + .unsqueeze(1) + .unsqueeze(1) + .expand(-1, head_mask.size(1), head_mask.size(2), head_mask.size(3), -1) + ) + head_mask = torch.gather(head_mask, dim=4, index=argsort_5d) + head_mask = head_mask.permute(1, 0, 2, 3, 4) + return hidden_states, position_masks, head_mask + + def unsort_tokens(self, hidden_states, argsort): + reverse_argsort = torch.argsort(argsort, dim=1) + hidden_states = torch.gather( + hidden_states, + dim=1, + index=reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)), + ) + return hidden_states + + @can_return_tuple + def forward( + self, + encoder_hidden_states: torch.Tensor, + context_mask: List[torch.Tensor], + target_mask: List[torch.Tensor], + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + ) -> BaseModelOutput: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + # mask out the encoder hidden states + # this is implemented here as in VJEPA training a separate encoder is used for target + encoder_hidden_states = apply_masks(encoder_hidden_states, context_mask) + _, N_ctxt, D = encoder_hidden_states.shape + hidden_states, position_masks = self.embeddings(encoder_hidden_states, context_mask, target_mask) + + # Put tokens in sorted order + argsort = torch.argsort(position_masks, dim=1) # [B, N] + hidden_states, position_masks, head_mask = self.sort_tokens(hidden_states, position_masks, argsort, head_mask) + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + layer_outputs = layer_module(hidden_states, position_masks, layer_head_mask, output_attentions) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.layernorm(hidden_states) + # unsort and extract the predicted tokens + hidden_states = self.unsort_tokens(hidden_states, argsort) + hidden_states = hidden_states[:, N_ctxt:] + # projection + hidden_states = self.proj(hidden_states) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring +class VJEPA2PreTrainedModel(PreTrainedModel): + config_class = VJEPA2Config + base_model_prefix = "vjepa2" + main_input_name = "pixel_values_videos" + supports_gradient_checkpointing = True + _no_split_modules = ["VJEPA2Layer"] + _supports_sdpa = True + _supports_flash_attn_2 = True + + def _init_weights( + self, + module: Union[ + nn.Linear, + nn.Conv2d, + nn.LayerNorm, + VJEPA2Embeddings, + VJEPA2PredictorEmbeddings, + ], + ): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, VJEPA2PredictorEmbeddings): + if not module.zero_init_mask_tokens: + module.mask_token = nn.init.trunc_normal_( + module.mask_token.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.mask_token.dtype) + else: + module.mask_tokens.data.zero_() + + +def _convert_head_mask_to_5d(head_mask, num_hidden_layers): + """ + Inputs: + - head_mask: bsz x seq_length x seq_length | None + Returns + - [num_hidden_layers x batch x num_heads x seq_length x seq_length] | [num_hidden_layers] + """ + if head_mask is not None: + head_mask = head_mask.unsqueeze(1).unsqueeze(0) + head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) + else: + head_mask = [None] * num_hidden_layers + return head_mask + + +@auto_docstring +class VJEPA2Model(VJEPA2PreTrainedModel): + def __init__(self, config: VJEPA2Config): + super().__init__(config) + self.config = config + + self.encoder = VJEPA2Encoder(config) + self.predictor = VJEPA2Predictor(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> VJEPA2PatchEmbeddings3D: + return self.encoder.embeddings.patch_embeddings + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values_videos: torch.Tensor, + context_head_mask: Optional[torch.Tensor] = None, + context_mask: Optional[List[torch.Tensor]] = None, + target_head_mask: Optional[torch.Tensor] = None, + target_mask: Optional[List[torch.Tensor]] = None, + skip_predictor: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> VJEPA2WithMaskedInputModelOutput: + r""" + pixel_values_videos (`torch.Tensor` with shape `[batch size x num_frames x num_channels x height x width]`, required): + The input video pixels which is processed by VJEPA2VideoProcessor. + context_head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*): + The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard) for the context. + target_head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*): + The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard) for the target. + context_mask (`torch.Tensor` with shape `[batch_size, patch_size, 1]`, *optional*): + The mask position ids indicating which encoder output patches are going to be exposed to the predictor. + By default, this mask is created as torch.arange(N).unsqueeze(0).repeat(B,1), indicating full context + available to the predictor. + target_mask (`torch.Tensor` with shape `[batch_size, patch_size, 1]`, *optional*): + The mask position ids indicating which encoder output patches are going to be used as a prediction target + for the predictor. By default, this mask is created as torch.arange(N).unsqueeze(0).repeat(B,1), indicating + that the predictor should predict all encoder patches. + skip_predictor (bool): + flag to skip the predictor forward, useful if you just need the encoder outputs + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if pixel_values_videos is None: + raise ValueError("You have to specify pixel_values_videos") + + # Prepare head mask if needed + context_head_mask = _convert_head_mask_to_5d(context_head_mask, self.config.num_hidden_layers) + target_head_mask = _convert_head_mask_to_5d(target_head_mask, self.config.pred_num_hidden_layers) + + encoder_outputs: BaseModelOutput = self.encoder( + pixel_values_videos=pixel_values_videos, + head_mask=context_head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = encoder_outputs.last_hidden_state + + if context_mask is None and target_mask is None: + B = pixel_values_videos.size(0) + N = sequence_output.size(1) # ensure we are using dynamic patch size + context_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))] + target_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))] + + if not skip_predictor: + predictor_outputs: BaseModelOutput = self.predictor( + encoder_hidden_states=sequence_output, + context_mask=context_mask, + target_mask=target_mask, + head_mask=target_head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + predictor_output = VJEPA2WithMaskedInputPredictorOutput( + last_hidden_state=predictor_outputs.last_hidden_state, + target_hidden_state=apply_masks(sequence_output, target_mask), + hidden_states=predictor_outputs.hidden_states, + attentions=predictor_outputs.attentions, + ) + else: + predictor_output = None + + encoder_output = VJEPA2WithMaskedInputModelOutput( + last_hidden_state=sequence_output, + masked_hidden_state=apply_masks(sequence_output, context_mask), + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + predictor_output=predictor_output, + ) + + return encoder_output + + def get_vision_features(self, pixel_values_videos) -> torch.Tensor: + encoder_output = self.forward(pixel_values_videos) + return encoder_output.last_hidden_state + + +__all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel"] diff --git a/src/transformers/models/vjepa2/video_processing_vjepa2.py b/src/transformers/models/vjepa2/video_processing_vjepa2.py new file mode 100644 index 0000000000..2df100f7eb --- /dev/null +++ b/src/transformers/models/vjepa2/video_processing_vjepa2.py @@ -0,0 +1,59 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Video processor class for VJEPA2.""" + +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, +) +from ...processing_utils import Unpack, VideosKwargs +from ...utils import is_vision_available +from ...utils.import_utils import requires +from ...video_processing_utils import BaseVideoProcessor + + +if is_vision_available(): + from ...image_utils import PILImageResampling + + +class VJEPA2VideoProcessorInitKwargs(VideosKwargs): ... + + +@requires(backends=("torchvision",)) +class VJEPA2VideoProcessor(BaseVideoProcessor): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"shortest_edge": int(256 * 256 / 224)} + crop_size = 256 + do_resize = True + do_rescale = True + do_center_crop = True + do_normalize = True + valid_kwargs = VJEPA2VideoProcessorInitKwargs + model_input_names = ["pixel_values_videos"] + + def __init__(self, **kwargs: Unpack[VJEPA2VideoProcessorInitKwargs]): + crop_size = kwargs.get("crop_size", 256) + if not isinstance(crop_size, int): + if not isinstance(crop_size, dict) or "height" not in crop_size: + raise ValueError("crop_size must be an integer or a dictionary with a 'height' key") + crop_size = crop_size["height"] + resize_size = int(crop_size * 256 / 224) + kwargs["size"] = {"shortest_edge": resize_size} + super().__init__(**kwargs) + + +__all__ = ["VJEPA2VideoProcessor"] diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 28386b9f44..884dab4720 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -166,6 +166,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ "t5", "trocr", "vit", + "vjepa2", "xglm", "wav2vec2", # "xlnet", diff --git a/tests/models/vjepa2/__init__.py b/tests/models/vjepa2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/vjepa2/test_modeling_vjepa2.py b/tests/models/vjepa2/test_modeling_vjepa2.py new file mode 100644 index 0000000000..8a4b55ad6a --- /dev/null +++ b/tests/models/vjepa2/test_modeling_vjepa2.py @@ -0,0 +1,345 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch V-JEPA2 model.""" + +import unittest + +import numpy as np + +from transformers import VJEPA2Config +from transformers.testing_utils import ( + is_flaky, + require_torch, + require_vision, + slow, + torch_device, +) +from transformers.utils import cached_property, is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin +from ...test_video_processing_common import ( + prepare_video_inputs, +) + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import VJEPA2Model + + +if is_vision_available(): + from PIL import Image + + from transformers import AutoVideoProcessor + +VJEPA_HF_MODEL = "facebook/vjepa2-vitl-fpc64-256" + + +class VJEPA2ModelTester: + def __init__( + self, + parent, + batch_size=2, + image_size=16, + patch_size=16, + num_channels=3, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=2, + num_frames=2, + mlp_ratio=1, + pred_hidden_size=32, + pred_num_attention_heads=2, + pred_num_hidden_layers=2, + pred_num_mask_tokens=10, + is_training=False, + attn_implementation="sdpa", + mask_ratio=0.5, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_frames = num_frames + self.mlp_ratio = mlp_ratio + self.pred_hidden_size = pred_hidden_size + self.pred_num_attention_heads = pred_num_attention_heads + self.pred_num_hidden_layers = pred_num_hidden_layers + self.pred_num_mask_tokens = pred_num_mask_tokens + self.attn_implementation = attn_implementation + self.is_training = is_training + self.mask_ratio = mask_ratio + + num_patches = ((image_size // patch_size) ** 2) * (num_frames // 2) + self.seq_length = num_patches + self.num_masks = int(self.mask_ratio * self.seq_length) + self.mask_length = num_patches + + def prepare_config_and_inputs(self): + pixel_values_videos = floats_tensor( + [ + self.batch_size, + self.num_frames, + self.num_channels, + self.image_size, + self.image_size, + ] + ) + + config = self.get_config() + + return config, pixel_values_videos + + def get_config(self): + return VJEPA2Config( + crop_size=self.image_size, + frames_per_clip=self.num_frames, + hidden_size=self.hidden_size, + num_attention_heads=self.num_attention_heads, + num_hidden_layers=self.num_hidden_layers, + mlp_ratio=self.mlp_ratio, + pred_hidden_size=self.pred_hidden_size, + pred_num_attention_heads=self.pred_num_attention_heads, + pred_num_hidden_layers=self.pred_num_hidden_layers, + pred_num_mask_tokens=self.pred_num_mask_tokens, + ) + + def create_and_check_model(self, config, pixel_values_videos): + model = VJEPA2Model(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values_videos) + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.batch_size, self.seq_length, self.hidden_size), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + pixel_values_videos, + ) = config_and_inputs + inputs_dict = {"pixel_values_videos": pixel_values_videos} + return config, inputs_dict + + +@require_torch +class VJEPA2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as VJEPA2 does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + test_torch_exportable = True + + all_model_classes = (VJEPA2Model,) if is_torch_available() else () + + fx_compatible = True + + pipeline_model_mapping = {} + + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = VJEPA2ModelTester(self) + self.config_tester = ConfigTester(self, config_class=VJEPA2Config, has_text_modality=False, hidden_size=37) + + @is_flaky(max_attempts=3, description="`torch.nn.init.trunc_normal_` is flaky.") + def test_initialization(self): + super().test_initialization() + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="VJEPA2 does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip( + reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + def test_model_get_set_embeddings(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="VJEPA2 does not support feedforward chunking yet") + def test_feed_forward_chunking(self): + pass + + @slow + def test_model_from_pretrained(self): + model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +def prepare_random_video(image_size=256): + videos = prepare_video_inputs( + batch_size=1, + num_frames=16, + num_channels=3, + min_resolution=image_size, + max_resolution=image_size, + equal_resolution=True, + return_tensors="torch", + ) + return videos + + +@require_torch +@require_vision +class VJEPA2ModelIntegrationTest(unittest.TestCase): + @cached_property + def default_video_processor(self): + return AutoVideoProcessor.from_pretrained(VJEPA_HF_MODEL) if is_vision_available() else None + + @slow + def test_inference_image(self): + model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device) + + video_processor = self.default_video_processor + image = prepare_img() + inputs = video_processor(torch.Tensor(np.array(image)), return_tensors="pt").to(torch_device) + pixel_values_videos = inputs.pixel_values_videos + pixel_values_videos = pixel_values_videos.repeat(1, model.config.frames_per_clip, 1, 1, 1) + + # forward pass + with torch.no_grad(): + outputs = model(pixel_values_videos) + + # verify the last hidden states + expected_shape = torch.Size((1, 8192, 1024)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[-0.0061, -1.8365, 2.7343], [-2.5938, -2.7181, -0.1663], [-1.7993, -2.2430, -1.1388]], + device=torch_device, + ) + torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-3, atol=1e-3) + + @slow + def test_inference_video(self): + model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device) + + video_processor = self.default_video_processor + video = prepare_random_video() + inputs = video_processor(video, return_tensors="pt").to(torch_device) + pixel_values_videos = inputs.pixel_values_videos + + # forward pass + with torch.no_grad(): + outputs = model(pixel_values_videos) + + # verify the last hidden states + expected_shape = torch.Size((1, 2048, 1024)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) + + @slow + def test_predictor_outputs(self): + model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device) + + video_processor = self.default_video_processor + video = prepare_random_video() + inputs = video_processor(video, return_tensors="pt").to(torch_device) + pixel_values_videos = inputs.pixel_values_videos + + # forward pass + with torch.no_grad(): + outputs = model(pixel_values_videos) + + # verify the last hidden states + expected_shape = torch.Size((1, 2048, 1024)) + self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape) + + @slow + def test_predictor_full_mask(self): + model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device) + + video_processor = self.default_video_processor + video = prepare_random_video() + inputs = video_processor(video, return_tensors="pt").to(torch_device) + pixel_values_videos = inputs.pixel_values_videos + + # forward pass + with torch.no_grad(): + context_mask = [torch.arange(2048, device=pixel_values_videos.device).unsqueeze(0)] + predictor_mask = context_mask + outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask) + + # verify the last hidden states + expected_shape = torch.Size((1, 2048, 1024)) + self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape) + + @slow + def test_predictor_partial_mask(self): + model = VJEPA2Model.from_pretrained(VJEPA_HF_MODEL).to(torch_device) + + video_processor = self.default_video_processor + video = prepare_random_video() + inputs = video_processor(video, return_tensors="pt").to(torch_device) + pixel_values_videos = inputs.pixel_values_videos + + num_patches = 2048 + num_masks = 100 + # forward pass + with torch.no_grad(): + pos_ids = torch.arange(num_patches, device=pixel_values_videos.device) + context_mask = [pos_ids[0 : num_patches - num_masks].unsqueeze(0)] + predictor_mask = [pos_ids[num_patches - num_masks :].unsqueeze(0)] + outputs = model(pixel_values_videos, context_mask=context_mask, target_mask=predictor_mask) + + # verify the last hidden states + expected_shape = torch.Size((1, num_masks, 1024)) + self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index bc62d56894..4aab7a69eb 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1301,6 +1301,7 @@ class ModelTesterMixin: "input_values", "inputs_embeds", "pixel_values", + "pixel_values_videos", "token_type_ids", "visual_feats", "visual_pos",