Add V-JEPA for video classification model (#38788)
* adding model and conversion scripts * add imports to test vjepa conversion * fix imports and make conversion work * fix computation for short side * replace attention with library attention function * cleanup more attention classes * remove config overrides * add test cases, fix some of the failing ones * fix the model outputs * fix outputs of the model per review * fix too big model test case * fix styling __init__.py * fix initialization test * remove all asserts per review * update sorting unsorting logic as per feedback * remove is_video per review * remove another is_video segment * remove unwanted stuff * small fixes * add docstrings for the model * revert adding vjepa2 config here * update styling * add config docstrings (wip) * fix dpr issue * removed test failing issues * update styles * merge predictor configs into main config * remove processing code, add video processor * remove permute which is not necessary now * fix styles * updated vjepa2 to be in video_processing_auto * update comment for preprocessing * test integration test and fix the outputs * update test values, change test to look at repeated frames for a given image * add a simple video processing test * refactoring pixel_values_videos and upload ckpts to original * fix torch_fx test cases * remove unused config * add all config docstrings * add more integration tests * add basic doc * revert unwanted styling changes * working make fixup * Fix model_type in config * Add ForVideoClassification model * update attention implementation to fit new hf standards * fix the preprocessing logic, ensure it matches the original model * remove use_rope logic, cleanup * fix docstrings * Further cleanup, update doc * Fix model prefix * fix get_vision_features * VJEPA2Embeddings style refactor * nit, style comment * change modules default values * Only `str` activation in config * GradientCheckpointingLayer * fixup * fix conversion script * Remove return_dict * remove None return typehint * Refactor 
VJEPA2Layer, remove use_SiLU * Fix fx tests * dpr -> drop_path_rates * move *ModelOutput on top * format docs bit * update docs * update docs * update doc example * remove prune_heads from model * remove unused config params * refactor embed signature * Add vjepa to docs * Fix config docstring * attention head * update defaults * Update docs/source/en/model_doc/vjepa2.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update docs/source/en/model_doc/vjepa2.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Fix import * Min refactoring * Update HUB_SOURCE and HUB_REPO in conversion script * Add missing headers * VJEPA -> V-JEPA in docs * Add image to doc * fix style * fix init weights * change checkpoint name in modeling tests * Initial cls head setup * remove rop attention from head (not needed) * remove swigluffn - not needed * Add siglip layer * Replace with siglip layer * Rename Siglip - VJEPA2 * remove unused modules * remove siglip mlp * nit * remove MLP * Refactor head cross attention * refactor VJEPA2HeadCrossAttentionLayer * nit renaming * fixup * remove commented code * Add cls head params to config * depth from config * move pooler + classifier to the model * Update for cls model signature * move layers, rename a bit * fix docs * update weights init * remove typehint for init * add to auto-mapping * enable tests * Add conversion script * fixup * add to docs * fix docs * nit * refactor for mapping * clean * Add integration test * Fixing multi gpu test * update not-split-modules * update video cls test tolerance * Increase test_inference_image tolerance * Update no-split modules for multi gpu * Apply suggestions from code review * fixing multi-gpu * fix docstring * Add cls snippet to docs * Update checkpoint
This commit is contained in:
committed by
GitHub
parent
2ff964bcb4
commit
9bec2654ed
@@ -38,7 +38,7 @@ This model was contributed by [koustuvs](https://huggingface.co/koustuvs), [yoni
|
|||||||
|
|
||||||
## Usage example
|
## Usage example
|
||||||
|
|
||||||
The snippet below shows how to load the V-JEPA 2 model using the `AutoModel` class.
|
The snippet below shows how to load the V-JEPA 2 model for feature extraction using the `AutoModel` class.
|
||||||
|
|
||||||
```py
|
```py
|
||||||
import torch
|
import torch
|
||||||
@@ -68,6 +68,43 @@ encoder_outputs = outputs.last_hidden_state
|
|||||||
predictor_outputs = outputs.predictor_output.last_hidden_state
|
predictor_outputs = outputs.predictor_output.last_hidden_state
|
||||||
```
|
```
|
||||||
|
|
||||||
|
V-JEPA 2 can also be fine-tuned for video classification. The following snippet shows how to use a model fine-tuned on Something-Something-V2 for video classification.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from torchcodec.decoders import VideoDecoder
|
||||||
|
from transformers import AutoVideoProcessor, AutoModelForVideoClassification
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
# Load model and video preprocessor
|
||||||
|
hf_repo = "facebook/vjepa2-vitl-fpc16-256-ssv2"
|
||||||
|
|
||||||
|
model = AutoModelForVideoClassification.from_pretrained(hf_repo).to(device)
|
||||||
|
processor = AutoVideoProcessor.from_pretrained(hf_repo)
|
||||||
|
|
||||||
|
# To load a video, sample the number of frames according to the model.
|
||||||
|
video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/bowling/-WH-lxmGJVY_000005_000015.mp4"
|
||||||
|
vr = VideoDecoder(video_url)
|
||||||
|
frame_idx = np.arange(0, model.config.frames_per_clip, 8) # you can define more complex sampling strategy
|
||||||
|
video = vr.get_frames_at(indices=frame_idx).data # frames x channels x height x width
|
||||||
|
|
||||||
|
# Preprocess and run inference
|
||||||
|
inputs = processor(video, return_tensors="pt").to(model.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
logits = outputs.logits
|
||||||
|
|
||||||
|
print("Top 5 predicted class names:")
|
||||||
|
top5_indices = logits.topk(5).indices[0]
|
||||||
|
top5_probs = torch.softmax(logits, dim=-1).topk(5).values[0]
|
||||||
|
for idx, prob in zip(top5_indices, top5_probs):
|
||||||
|
text_label = model.config.id2label[idx.item()]
|
||||||
|
print(f" - {text_label}: {prob:.2f}")
|
||||||
|
```
|
||||||
|
|
||||||
## VJEPA2Config
|
## VJEPA2Config
|
||||||
|
|
||||||
[[autodoc]] VJEPA2Config
|
[[autodoc]] VJEPA2Config
|
||||||
@@ -77,6 +114,11 @@ predictor_outputs = outputs.predictor_output.last_hidden_state
|
|||||||
[[autodoc]] VJEPA2Model
|
[[autodoc]] VJEPA2Model
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## VJEPA2ForVideoClassification
|
||||||
|
|
||||||
|
[[autodoc]] VJEPA2ForVideoClassification
|
||||||
|
- forward
|
||||||
|
|
||||||
## VJEPA2VideoProcessor
|
## VJEPA2VideoProcessor
|
||||||
|
|
||||||
[[autodoc]] VJEPA2VideoProcessor
|
[[autodoc]] VJEPA2VideoProcessor
|
||||||
|
|||||||
@@ -150,6 +150,7 @@ LOSS_MAPPING = {
|
|||||||
"ForQuestionAnswering": ForQuestionAnsweringLoss,
|
"ForQuestionAnswering": ForQuestionAnsweringLoss,
|
||||||
"ForSequenceClassification": ForSequenceClassificationLoss,
|
"ForSequenceClassification": ForSequenceClassificationLoss,
|
||||||
"ForImageClassification": ForSequenceClassificationLoss,
|
"ForImageClassification": ForSequenceClassificationLoss,
|
||||||
|
"ForVideoClassification": ForSequenceClassificationLoss,
|
||||||
"ForTokenClassification": ForTokenClassification,
|
"ForTokenClassification": ForTokenClassification,
|
||||||
"ForSegmentation": ForSegmentationLoss,
|
"ForSegmentation": ForSegmentationLoss,
|
||||||
"ForObjectDetection": ForObjectDetectionLoss,
|
"ForObjectDetection": ForObjectDetectionLoss,
|
||||||
|
|||||||
@@ -844,6 +844,7 @@ MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
("timesformer", "TimesformerForVideoClassification"),
|
("timesformer", "TimesformerForVideoClassification"),
|
||||||
("videomae", "VideoMAEForVideoClassification"),
|
("videomae", "VideoMAEForVideoClassification"),
|
||||||
("vivit", "VivitForVideoClassification"),
|
("vivit", "VivitForVideoClassification"),
|
||||||
|
("vjepa2", "VJEPA2ForVideoClassification"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -60,6 +60,10 @@ class VJEPA2Config(PretrainedConfig):
|
|||||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout probability for attentions.
|
||||||
|
num_pooler_layers (`int`, *optional*, defaults to 3):
|
||||||
|
The number of self-attention layers in the pooler.
|
||||||
pred_hidden_size (`int`, *optional*, defaults to 384):
|
pred_hidden_size (`int`, *optional*, defaults to 384):
|
||||||
Dimensionality of the predictor layers
|
Dimensionality of the predictor layers
|
||||||
pred_num_attention_heads (`int`, *optional*, defaults to 12):
|
pred_num_attention_heads (`int`, *optional*, defaults to 12):
|
||||||
@@ -107,6 +111,8 @@ class VJEPA2Config(PretrainedConfig):
|
|||||||
attention_probs_dropout_prob=0.0,
|
attention_probs_dropout_prob=0.0,
|
||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
num_pooler_layers=3,
|
||||||
# predictor params
|
# predictor params
|
||||||
pred_hidden_size=384,
|
pred_hidden_size=384,
|
||||||
pred_num_attention_heads=12,
|
pred_num_attention_heads=12,
|
||||||
@@ -134,6 +140,8 @@ class VJEPA2Config(PretrainedConfig):
|
|||||||
self.hidden_act = hidden_act
|
self.hidden_act = hidden_act
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
self.image_size = crop_size
|
self.image_size = crop_size
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.num_pooler_layers = num_pooler_layers
|
||||||
# predictor params
|
# predictor params
|
||||||
self.pred_hidden_size = pred_hidden_size
|
self.pred_hidden_size = pred_hidden_size
|
||||||
self.pred_num_attention_heads = pred_num_attention_heads
|
self.pred_num_attention_heads = pred_num_attention_heads
|
||||||
|
|||||||
@@ -0,0 +1,220 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from decord import VideoReader
|
||||||
|
from huggingface_hub import HfApi, hf_hub_download
|
||||||
|
|
||||||
|
from transformers import VJEPA2ForVideoClassification, VJEPA2VideoProcessor
|
||||||
|
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def get_video():
    """Download the sample bowling clip and open it for frame access.

    The clip is fetched from the `nateraw/kinetics-mini` dataset repository on the Hub.

    Returns:
        A `decord.VideoReader` opened on the downloaded file.
    """
    local_path = hf_hub_download(
        repo_type="dataset",
        repo_id="nateraw/kinetics-mini",
        filename="val/bowling/-WH-lxmGJVY_000005_000015.mp4",
    )
    return VideoReader(local_path)
|
||||||
|
|
||||||
|
|
||||||
|
# Conversion recipes, keyed by the name of the converted model. Fields of each entry:
#   base_model      - Hub repo passed to `from_pretrained` for the model and video processor
#   checkpoint      - URL of the original classifier checkpoint
#   num_labels      - number of classes of the classifier
#   frames_per_clip - number of frames per clip set on the model config
#   dataset         - prefix of the `{dataset}-id2label.json` label file
#   result          - (class id, probability, label) expected on the verification video
CLASSIFIERS = {
    # Something-Something-v2 dataset
    "vjepa2-vitl-fpc16-256-ssv2": {
        "base_model": "facebook/vjepa2-vitl-fpc64-256",
        "checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitl-16x2x3.pt",
        "num_labels": 174,
        "frames_per_clip": 16,
        "dataset": "something-something-v2",
        "result": (145, 0.30867, "Stuffing [something] into [something]"),
    },
    "vjepa2-vitg-fpc64-384-ssv2": {
        "base_model": "facebook/vjepa2-vitg-fpc64-384",
        "checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt",
        "num_labels": 174,
        "frames_per_clip": 64,
        "dataset": "something-something-v2",
        "result": (112, 0.26408, "Putting [something] onto [something]"),
    },
    # Diving48 dataset
    "vjepa2-vitl-fpc32-256-diving48": {
        "base_model": "facebook/vjepa2-vitl-fpc64-256",
        "checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/diving48-vitl-256.pt",
        "num_labels": 48,
        "frames_per_clip": 32,
        "dataset": "diving48",
        "result": (35, 0.32875, "['Inward', '35som', 'NoTwis', 'TUCK']"),
    },
    "vjepa2-vitg-fpc32-384-diving48": {
        "base_model": "facebook/vjepa2-vitg-fpc64-384",
        "checkpoint": "https://dl.fbaipublicfiles.com/vjepa2/evals/diving48-vitg-384-32x4x3.pt",
        "num_labels": 48,
        "frames_per_clip": 32,
        "dataset": "diving48",
        "result": (22, 0.35351, "['Forward', '25som', '2Twis', 'PIKE']"),
    },
}
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
# Regex pattern -> replacement used to rename original checkpoint keys to the converted layout.
# Dots are escaped so that they only match the literal `.` separating module names.
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"module\.pooler\.query_tokens":                         r"pooler.query_tokens",
    r"module\.pooler\.cross_attention_block\.norm(\d+)\.":    r"pooler.cross_attention_layer.layer_norm\1.",
    r"module\.pooler\.cross_attention_block\.xattn\.(q|k|v)\.": r"pooler.cross_attention_layer.cross_attn.\1_proj.",
    r"module\.pooler\.cross_attention_block\.mlp\.fc(\d+)\.":  r"pooler.cross_attention_layer.mlp.fc\1.",
    r"module\.pooler\.blocks\.(\d+)\.norm(\d+)\.":            r"pooler.self_attention_layers.\1.layer_norm\2.",
    r"module\.pooler\.blocks\.(\d+)\.attn\.(q|k|v)\.":         r"pooler.self_attention_layers.\1.self_attn.\2_proj.",
    r"module\.pooler\.blocks\.(\d+)\.attn\.proj\.":            r"pooler.self_attention_layers.\1.self_attn.out_proj.",
    r"module\.pooler\.blocks\.(\d+)\.mlp\.fc(\d+)\.":          r"pooler.self_attention_layers.\1.mlp.fc\2.",
    r"module\.linear\.":                                    r"classifier.",
}
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
def get_id2label_mapping(dataset_name: str) -> dict[int, str]:
    """Fetch the class-id to label mapping for a dataset.

    Downloads `{dataset_name}-id2label.json` from the `huggingface/label-files` dataset
    repository and converts its string keys to integers.

    Args:
        dataset_name: Prefix of the label file to download.

    Returns:
        A dictionary mapping integer class ids to label strings.
    """
    path = hf_hub_download(
        repo_id="huggingface/label-files",
        filename=f"{dataset_name}-id2label.json",
        repo_type="dataset",
    )
    # Read as UTF-8 explicitly so decoding does not depend on the platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        id2label = json.load(f)
    # JSON object keys are always strings; the model config expects integer ids.
    return {int(k): v for k, v in id2label.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def split_qkv(state_dict):
    """Return a copy of `state_dict` in which fused attention projections are split apart.

    Entries whose key contains `.qkv.` are cut into three equal chunks along dim 0 and stored
    under `.q.`, `.k.` and `.v.`; entries whose key contains `.kv.` are cut into two chunks
    stored under `.k.` and `.v.`. The fused entry is removed and the input mapping is left
    untouched.
    """
    unfused = state_dict.copy()
    # Iterate over a snapshot of the keys because entries are removed and added in the loop.
    for name in list(unfused):
        if ".qkv." in name:
            query, key, value = torch.chunk(unfused.pop(name), 3, dim=0)
            unfused[name.replace(".qkv.", ".q.")] = query
            unfused[name.replace(".qkv.", ".k.")] = key
            unfused[name.replace(".qkv.", ".v.")] = value
            continue
        if ".kv." in name:
            key, value = torch.chunk(unfused.pop(name), 2, dim=0)
            unfused[name.replace(".kv.", ".k.")] = key
            unfused[name.replace(".kv.", ".v.")] = value
    return unfused
|
||||||
|
|
||||||
|
|
||||||
|
def convert_old_keys_to_new_keys(state_dict):
    """
    Build a mapping from original checkpoint keys to converted keys.

    This function should be applied only once, on the concatenated keys to efficiently rename using
    the key mappings.

    Args:
        state_dict: Iterable of original key names (iterating a dict yields its keys).

    Returns:
        A dictionary mapping every original key to its renamed counterpart. An empty input
        yields an empty dictionary.
    """
    old_keys = list(state_dict)
    if not old_keys:
        # Joining and re-splitting an empty sequence would otherwise produce a bogus {"": ""} entry.
        return {}

    # All keys are joined into one newline-separated text so each pattern is applied in a single pass.
    old_text = "\n".join(old_keys)
    new_text = old_text
    for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
        if replacement is None:
            new_text = re.sub(pattern, "", new_text)  # an empty line
            continue
        new_text = re.sub(pattern, replacement, new_text)
    return dict(zip(old_text.split("\n"), new_text.split("\n")))
|
||||||
|
|
||||||
|
|
||||||
|
def main(args: argparse.Namespace):
    """Convert an original V-JEPA 2 classifier checkpoint and save it in the converted format.

    Steps: load the base model and video processor, download the original classifier
    checkpoint, split fused projections and rename its keys, load the weights into the model,
    optionally verify the top-1 prediction on a sample video, save everything under
    `base_dir/model_name`, and optionally upload that folder to the Hub.

    Args:
        args: Parsed command-line arguments providing `model_name`, `base_dir`, `repo_org`,
            `push_to_hub` and `skip_verification`.

    Raises:
        KeyError: If `args.model_name` is not an entry of `CLASSIFIERS`.
        ValueError: If the number of labels does not match, if the converted checkpoint has
            keys the model does not expect, or if verification does not reproduce the
            expected prediction.
    """
    model_params = CLASSIFIERS[args.model_name]
    id2label = get_id2label_mapping(model_params["dataset"])

    if len(id2label) != model_params["num_labels"]:
        raise ValueError(
            f"Number of labels in id2label mapping ({len(id2label)}) does not "
            f"match number of labels in model ({model_params['num_labels']})"
        )

    model = VJEPA2ForVideoClassification.from_pretrained(
        model_params["base_model"],
        num_labels=model_params["num_labels"],
        id2label=id2label,
        frames_per_clip=model_params["frames_per_clip"],
    )
    processor = VJEPA2VideoProcessor.from_pretrained(model_params["base_model"])

    # load and convert classifier checkpoint
    # `map_location="cpu"` lets checkpoints saved from GPU tensors load on CPU-only machines.
    # NOTE(review): this deserializes a remote file; keep the URLs restricted to trusted sources.
    checkpoint = torch.hub.load_state_dict_from_url(model_params["checkpoint"], map_location="cpu")
    state_dict = checkpoint["classifiers"][0]

    state_dict_qkv_split = split_qkv(state_dict)
    key_mapping = convert_old_keys_to_new_keys(state_dict_qkv_split.keys())
    converted_state_dict = {key_mapping[k]: v for k, v in state_dict_qkv_split.items()}

    # strict=False: the checkpoint only holds part of the model's weights, so missing keys are
    # tolerated, but keys the model does not know indicate a broken renaming.
    result = model.load_state_dict(converted_state_dict, strict=False)
    if result.unexpected_keys:
        raise ValueError(f"Error loading state dict: {result.unexpected_keys}")

    if not args.skip_verification:
        # get inputs
        video_reader = get_video()
        # Evenly spaced frames over the first 128 frames, cast to integer frame indices.
        frame_indexes = np.arange(0, 128, 128 / model_params["frames_per_clip"]).astype(np.int64)
        video = video_reader.get_batch(frame_indexes).asnumpy()
        inputs = processor(video, return_tensors="pt").to(device)

        # run model
        model.to(device).eval()
        with torch.no_grad():
            outputs = model(**inputs)

        # compare results
        probs = torch.softmax(outputs.logits, dim=-1)
        top_prob, top_idx = probs.topk(1)
        top_prob, top_idx = top_prob.item(), top_idx.item()
        label = id2label[top_idx]
        expected_id, expected_prob, expected_label = model_params["result"]

        if top_idx != expected_id:
            raise ValueError(f"Expected id {expected_id} but got {top_idx}")
        if label != expected_label:
            raise ValueError(f"Expected label {expected_label} but got {label}")
        if not np.isclose(top_prob, expected_prob, atol=1e-3):
            raise ValueError(f"Expected prob {expected_prob} but got {top_prob}")
        print("Verification passed")

    output_dir = os.path.join(args.base_dir, args.model_name)
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    if args.push_to_hub:
        api = HfApi()
        repo_id = f"{args.repo_org}/{args.model_name}"
        if not api.repo_exists(repo_id):
            api.create_repo(repo_id, repo_type="model")
        api.upload_folder(folder_path=output_dir, repo_id=repo_id, repo_type="model")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: pick a classifier by name, choose where to save it,
    # and optionally push to the Hub and/or skip the output verification.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model_name", required=True, type=str)
    cli.add_argument("--base_dir", default="converted_models/", type=str)
    cli.add_argument("--repo_org", default="qubvel-hf", type=str)
    for switch in ("--push_to_hub", "--skip_verification"):
        cli.add_argument(switch, action="store_true")

    main(cli.parse_args())
|
||||||
@@ -12,6 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -19,7 +20,7 @@ from torch import nn
|
|||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_layers import GradientCheckpointingLayer
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutput, dataclass
|
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
|
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
|
||||||
from .configuration_vjepa2 import VJEPA2Config
|
from .configuration_vjepa2 import VJEPA2Config
|
||||||
@@ -536,17 +537,21 @@ class VJEPA2Encoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def apply_masks(x, masks) -> torch.Tensor:
|
def apply_masks(tensor: torch.Tensor, masks: List[torch.Tensor]) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
:param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
|
Args:
|
||||||
:param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep
|
tensor (`torch.Tensor`):
|
||||||
|
Tensor of shape [batch_size, num_patches, feature_dim]
|
||||||
|
masks (`List[torch.Tensor]`):
|
||||||
|
List of tensors of shape [batch_size, num_patches] containing indices of patches to keep
|
||||||
"""
|
"""
|
||||||
all_x = []
|
all_masked_tensors = []
|
||||||
for m in masks:
|
for mask in masks:
|
||||||
mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
|
mask = mask.to(tensor.device)
|
||||||
all_x += [torch.gather(x, dim=1, index=mask_keep)]
|
mask_keep = mask.unsqueeze(-1).repeat(1, 1, tensor.size(-1))
|
||||||
|
all_masked_tensors += [torch.gather(tensor, dim=1, index=mask_keep)]
|
||||||
|
|
||||||
return torch.cat(all_x, dim=0)
|
return torch.cat(all_masked_tensors, dim=0)
|
||||||
|
|
||||||
|
|
||||||
class VJEPA2PredictorEmbeddings(nn.Module):
|
class VJEPA2PredictorEmbeddings(nn.Module):
|
||||||
@@ -649,13 +654,18 @@ class VJEPA2Predictor(nn.Module):
|
|||||||
self.proj = nn.Linear(config.pred_hidden_size, config.hidden_size, bias=True)
|
self.proj = nn.Linear(config.pred_hidden_size, config.hidden_size, bias=True)
|
||||||
|
|
||||||
def sort_tokens(self, hidden_states, position_masks, argsort, head_mask=None):
|
def sort_tokens(self, hidden_states, position_masks, argsort, head_mask=None):
|
||||||
|
# gather position masks
|
||||||
|
argsort = argsort.to(position_masks.device)
|
||||||
position_masks = torch.gather(position_masks, dim=1, index=argsort)
|
position_masks = torch.gather(position_masks, dim=1, index=argsort)
|
||||||
hidden_states = torch.gather(
|
|
||||||
hidden_states,
|
# gather hidden states
|
||||||
dim=1,
|
argsort = argsort.to(hidden_states.device)
|
||||||
index=argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)),
|
hidden_states_argsort = argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
|
||||||
)
|
hidden_states = torch.gather(hidden_states, dim=1, index=hidden_states_argsort)
|
||||||
|
|
||||||
|
# gather head mask
|
||||||
if head_mask is not None and head_mask[0] is not None:
|
if head_mask is not None and head_mask[0] is not None:
|
||||||
|
argsort = argsort.to(head_mask.device)
|
||||||
head_mask = head_mask.permute(1, 0, 2, 3, 4)
|
head_mask = head_mask.permute(1, 0, 2, 3, 4)
|
||||||
argsort_4d = (
|
argsort_4d = (
|
||||||
argsort.unsqueeze(1)
|
argsort.unsqueeze(1)
|
||||||
@@ -673,15 +683,14 @@ class VJEPA2Predictor(nn.Module):
|
|||||||
)
|
)
|
||||||
head_mask = torch.gather(head_mask, dim=4, index=argsort_5d)
|
head_mask = torch.gather(head_mask, dim=4, index=argsort_5d)
|
||||||
head_mask = head_mask.permute(1, 0, 2, 3, 4)
|
head_mask = head_mask.permute(1, 0, 2, 3, 4)
|
||||||
|
|
||||||
return hidden_states, position_masks, head_mask
|
return hidden_states, position_masks, head_mask
|
||||||
|
|
||||||
def unsort_tokens(self, hidden_states, argsort):
|
def unsort_tokens(self, hidden_states, argsort):
|
||||||
|
argsort = argsort.to(hidden_states.device)
|
||||||
reverse_argsort = torch.argsort(argsort, dim=1)
|
reverse_argsort = torch.argsort(argsort, dim=1)
|
||||||
hidden_states = torch.gather(
|
reverse_argsort = reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
|
||||||
hidden_states,
|
hidden_states = torch.gather(hidden_states, dim=1, index=reverse_argsort)
|
||||||
dim=1,
|
|
||||||
index=reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)),
|
|
||||||
)
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@can_return_tuple
|
@can_return_tuple
|
||||||
@@ -735,49 +744,304 @@ class VJEPA2Predictor(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VJEPA2PoolerSelfAttention(nn.Module):
    """Multi-head self-attention with separate q/k/v projections and an output projection."""

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        # Projections are created in the same order as before (k, v, q, out).
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply self-attention to `hidden_states` of shape (batch, seq_len, embed_dim).

        Returns:
            The projected attention output and, only when `output_attentions` is truthy,
            the attention weights (otherwise `None`).
        """
        batch_size, seq_length, embed_dim = hidden_states.shape
        per_head_shape = (batch_size, seq_length, self.num_heads, self.head_dim)

        # Project, then move the head axis in front of the sequence axis.
        queries = self.q_proj(hidden_states).view(*per_head_shape).transpose(1, 2)
        keys = self.k_proj(hidden_states).view(*per_head_shape).transpose(1, 2)
        values = self.v_proj(hidden_states).view(*per_head_shape).transpose(1, 2)

        implementation = self.config._attn_implementation
        attention_interface: Callable = eager_attention_forward
        if implementation == "sdpa" and output_attentions:
            # Keep the eager implementation selected above.
            logger.warning_once(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
        elif implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=self.dropout if self.training else 0.0,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, (attn_weights if output_attentions else None)
|
||||||
|
|
||||||
|
|
||||||
|
class VJEPA2PoolerCrossAttention(nn.Module):
    """Multi-head cross-attention that, unlike other cross-attention layers, has no output projection (o_proj)."""

    # in case of modular refactoring - o_proj can be replaced with nn.Identity()

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        # Projections are created in the same order as before (k, v, q).
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend from `queries` to `keys`/`values`, each of shape (batch, seq_len, embed_dim).

        The key/value sequence length is read from `keys` and may differ from the query length.

        Returns:
            The attention output reshaped to the query shape (no output projection is applied)
            and, only when `output_attentions` is truthy, the attention weights (otherwise `None`).
        """
        batch_size, q_seq_length, embed_dim = queries.shape
        kv_seq_length = keys.shape[1]
        query_head_shape = (batch_size, q_seq_length, self.num_heads, self.head_dim)
        kv_head_shape = (batch_size, kv_seq_length, self.num_heads, self.head_dim)

        # Project, then move the head axis in front of the sequence axis.
        queries = self.q_proj(queries).view(*query_head_shape).transpose(1, 2)
        keys = self.k_proj(keys).view(*kv_head_shape).transpose(1, 2)
        values = self.v_proj(values).view(*kv_head_shape).transpose(1, 2)

        implementation = self.config._attn_implementation
        attention_interface: Callable = eager_attention_forward
        if implementation == "sdpa" and output_attentions:
            # Keep the eager implementation selected above.
            logger.warning_once(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
        elif implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=self.dropout if self.training else 0.0,
        )

        attn_output = attn_output.reshape(batch_size, q_seq_length, embed_dim).contiguous()

        return attn_output, (attn_weights if output_attentions else None)
|
||||||
|
|
||||||
|
|
||||||
|
# Modified from SiglipEncoderLayer, but we have to propagate proper hidden_size to VJEPA2MLP
|
||||||
|
class VJEPA2PoolerSelfAttentionLayer(GradientCheckpointingLayer):
    """Self-attention block used by the attentive pooler.

    The block normalises its input, runs self-attention and adds the result back
    to the un-normalised input; it then repeats the same norm -> sublayer ->
    residual pattern with an MLP.
    """

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.self_attn = VJEPA2PoolerSelfAttention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # The MLP width is passed explicitly rather than left to VJEPA2MLP's own default.
        self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by
                very large negative values. May be `None`; the attentive pooler calls this layer without a mask.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            A tuple whose first element is the transformed hidden states. The attention weights are appended as a
            second element only when `output_attentions` is truthy.
        """
        # Attention sub-block: the residual is taken before normalisation.
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        # MLP sub-block, same residual pattern.
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class VJEPA2PoolerCrossAttentionLayer(GradientCheckpointingLayer):
    """Cross-attention block of the attentive pooler.

    `queries` attend over a layer-normalised copy of `hidden_state`; the result
    is added back onto the raw queries and then refined by a residual MLP.
    """

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cross_attn = VJEPA2PoolerCrossAttention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size)

    def forward(
        self,
        queries: torch.Tensor,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """Run cross-attention followed by the MLP.

        Returns a tuple holding the updated query states, followed by whatever
        extra values the attention module returned when `output_attentions` is set.
        """
        # Only the key/value stream is normalised; the un-normalised queries
        # serve both as the attention queries and as the residual branch.
        context = self.layer_norm1(hidden_state)
        attended, *extra_outputs = self.cross_attn(
            queries,
            context,
            context,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        pooled = queries + attended

        # Residual MLP on the attended queries.
        mlp_out = self.mlp(self.layer_norm2(pooled))
        pooled = pooled + mlp_out

        if output_attentions:
            return (pooled, *extra_outputs)
        return (pooled,)
|
||||||
|
|
||||||
|
|
||||||
|
class VJEPA2AttentivePooler(nn.Module):
    """Attentive Pooler

    Runs the token sequence through `config.num_pooler_layers` self-attention
    layers, then lets a single learned query token cross-attend over the result
    to produce one vector per batch element.
    """

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        # One learned query token, created as zeros here.
        self.query_tokens = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.cross_attention_layer = VJEPA2PoolerCrossAttentionLayer(config)
        self_attention_blocks = [VJEPA2PoolerSelfAttentionLayer(config) for _ in range(config.num_pooler_layers)]
        self.self_attention_layers = nn.ModuleList(self_attention_blocks)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        tokens = hidden_state
        for block in self.self_attention_layers:
            # Each layer returns a tuple; element 0 is the updated hidden states.
            tokens = block(tokens, attention_mask=None)[0]

        # Replicate the query token along the batch dimension.
        batch_size = tokens.shape[0]
        queries = self.query_tokens.repeat(batch_size, 1, 1)

        pooled = self.cross_attention_layer(queries, tokens)[0]
        # Drop the length-1 query dimension.
        return pooled.squeeze(1)
|
||||||
|
|
||||||
|
|
||||||
@auto_docstring
class VJEPA2PreTrainedModel(PreTrainedModel):
    # Class-level settings shared by every V-JEPA 2 model class.
    config_class = VJEPA2Config
    base_model_prefix = "vjepa2"
    main_input_name = "pixel_values_videos"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "VJEPA2Layer",
        "VJEPA2PoolerSelfAttentionLayer",
        "VJEPA2PoolerCrossAttentionLayer",
        "VJEPA2PredictorEmbeddings",
    ]
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialize the weights

        Dispatches on the type of `module`:

        - `VJEPA2AttentivePooler`: truncated-normal init for the query tokens, and for the output projections
          (`self_attn.out_proj`, `mlp.fc2`) of each self-attention layer with the std divided by the square root
          of the 1-based layer index; the cross-attention layer's `mlp.fc2` uses the next index.
        - `VJEPA2PredictorEmbeddings`: mask tokens are zeroed when `zero_init_mask_tokens` is set, otherwise
          truncated-normal initialised.
        - `nn.Linear` / `nn.Conv2d` / `nn.Conv3d`: truncated-normal weights, zero bias.
        - `nn.LayerNorm`: zero bias, unit weight.
        """

        init_std = self.config.initializer_range

        # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
        # `trunc_normal_cpu` not implemented in `half` issues
        def trunc_normal_f32_(weight, std):
            data_float_32 = weight.data.to(torch.float32)
            data_init = nn.init.trunc_normal_(data_float_32, mean=0.0, std=std)
            # Write through `.data` so the parameter object itself is kept.
            weight.data = data_init.to(weight.dtype)

        if isinstance(module, VJEPA2AttentivePooler):
            trunc_normal_f32_(module.query_tokens, std=init_std)
            for i, layer in enumerate(module.self_attention_layers, 1):
                # Shrink the std with depth: init_std / sqrt(layer index).
                std = init_std / (i**0.5)
                trunc_normal_f32_(layer.self_attn.out_proj.weight, std=std)
                trunc_normal_f32_(layer.mlp.fc2.weight, std=std)
            # The cross-attention layer is treated as the layer after the last self-attention layer.
            std = init_std / (len(module.self_attention_layers) + 1) ** 0.5
            trunc_normal_f32_(module.cross_attention_layer.mlp.fc2.weight, std=std)
        elif isinstance(module, VJEPA2PredictorEmbeddings):
            if module.zero_init_mask_tokens:
                module.mask_tokens.data.zero_()
            else:
                trunc_normal_f32_(module.mask_tokens, std=init_std)
        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            trunc_normal_f32_(module.weight, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_head_mask_to_5d(head_mask, num_hidden_layers):
|
def _convert_head_mask_to_5d(head_mask, num_hidden_layers):
|
||||||
@@ -900,4 +1164,92 @@ class VJEPA2Model(VJEPA2PreTrainedModel):
|
|||||||
return encoder_output.last_hidden_state
|
return encoder_output.last_hidden_state
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel"]
|
@auto_docstring(
    custom_intro="""
    V-JEPA 2 Model transformer with a video classification head on top (a linear layer on top of the attentive pooler).
    """
)
class VJEPA2ForVideoClassification(VJEPA2PreTrainedModel):
    def __init__(self, config: VJEPA2Config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vjepa2 = VJEPA2Model(config)

        # Classification head: attentive pooling over the encoder tokens, then a linear projection to the labels.
        self.pooler = VJEPA2AttentivePooler(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=True)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values_videos: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutput]:
        r"""
        pixel_values_videos (`torch.Tensor` with shape `[batch size x num_frames x num_channels x height x width]`):
            The input video pixels which is processed by VJEPA2VideoProcessor.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the video classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> import torch
        >>> import numpy as np
        >>> from transformers import AutoVideoProcessor, VJEPA2ForVideoClassification

        >>> device = "cuda"

        >>> video_processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2")
        >>> model = VJEPA2ForVideoClassification.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2").to(device)

        >>> video = np.ones((64, 256, 256, 3))  # 64 frames, 256x256 RGB
        >>> inputs = video_processor(video, return_tensors="pt").to(device)

        >>> # For inference
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> logits = outputs.logits

        >>> predicted_label = logits.argmax(-1).item()
        >>> print(model.config.id2label[predicted_label])

        >>> # For training
        >>> labels = torch.ones(1, dtype=torch.long, device=device)
        >>> loss = model(**inputs, labels=labels).loss

        ```"""

        # Only the encoder output is consumed below, so the predictor is skipped.
        encoder_outputs = self.vjepa2(
            pixel_values_videos=pixel_values_videos,
            skip_predictor=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Pool the token sequence and project it to class logits.
        pooled_features = self.pooler(encoder_outputs.last_hidden_state)
        logits = self.classifier(pooled_features)

        # The loss is only computed when labels are supplied.
        if labels is None:
            loss = None
        else:
            loss = self.loss_function(pooled_logits=logits, labels=labels, config=self.config)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel", "VJEPA2ForVideoClassification"]
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ from ..models.auto.modeling_auto import (
|
|||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES,
|
||||||
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
||||||
MODEL_MAPPING_NAMES,
|
MODEL_MAPPING_NAMES,
|
||||||
)
|
)
|
||||||
@@ -194,6 +195,7 @@ _SPECIAL_SUPPORTED_MODELS = [
|
|||||||
"TrOCRDecoder",
|
"TrOCRDecoder",
|
||||||
"PeftModelForCausalLM",
|
"PeftModelForCausalLM",
|
||||||
"PeftModelForSeq2SeqLM",
|
"PeftModelForSeq2SeqLM",
|
||||||
|
"VJEPA2ForVideoClassification",
|
||||||
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
||||||
# XLNetForQuestionAnswering,
|
# XLNetForQuestionAnswering,
|
||||||
]
|
]
|
||||||
@@ -904,6 +906,7 @@ class HFTracer(Tracer):
|
|||||||
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
|
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
|
||||||
*get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
|
*get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
|
||||||
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
|
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
|
||||||
|
*get_values(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES),
|
||||||
*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
|
*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
|
||||||
*get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
|
*get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
|
||||||
]:
|
]:
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import VJEPA2Model
|
from transformers import VJEPA2ForVideoClassification, VJEPA2Model
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -153,7 +153,7 @@ class VJEPA2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
test_torch_exportable = True
|
test_torch_exportable = True
|
||||||
|
|
||||||
all_model_classes = (VJEPA2Model,) if is_torch_available() else ()
|
all_model_classes = (VJEPA2Model, VJEPA2ForVideoClassification) if is_torch_available() else ()
|
||||||
|
|
||||||
fx_compatible = True
|
fx_compatible = True
|
||||||
|
|
||||||
@@ -267,7 +267,7 @@ class VJEPA2ModelIntegrationTest(unittest.TestCase):
|
|||||||
[[-0.0061, -1.8365, 2.7343], [-2.5938, -2.7181, -0.1663], [-1.7993, -2.2430, -1.1388]],
|
[[-0.0061, -1.8365, 2.7343], [-2.5938, -2.7181, -0.1663], [-1.7993, -2.2430, -1.1388]],
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-3, atol=1e-3)
|
torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=8e-2, atol=8e-2)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_inference_video(self):
|
def test_inference_video(self):
|
||||||
@@ -343,3 +343,22 @@ class VJEPA2ModelIntegrationTest(unittest.TestCase):
|
|||||||
# verify the last hidden states
|
# verify the last hidden states
|
||||||
expected_shape = torch.Size((1, num_masks, 1024))
|
expected_shape = torch.Size((1, num_masks, 1024))
|
||||||
self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)
|
self.assertEqual(outputs.predictor_output.last_hidden_state.shape, expected_shape)
|
||||||
|
|
||||||
|
@slow
def test_video_classification(self):
    """Check the logits shape and a slice of logit values of the SSv2 classification checkpoint."""
    checkpoint = "facebook/vjepa2-vitl-fpc16-256-ssv2"

    classifier = VJEPA2ForVideoClassification.from_pretrained(checkpoint).to(torch_device)
    processor = AutoVideoProcessor.from_pretrained(checkpoint)

    # Constant all-ones clip used as the input.
    frames = np.ones((16, 3, 256, 256))
    batch = processor(frames, return_tensors="pt").to(torch_device)

    with torch.no_grad():
        result = classifier(**batch)

    # One clip, 174 classes.
    self.assertEqual(result.logits.shape, (1, 174))

    # Compare three logits against reference values.
    reference = torch.tensor([0.8814, -0.1195, -0.6389], device=torch_device)
    observed = result.logits[0, 100:103]
    torch.testing.assert_close(observed, reference, rtol=1e-2, atol=1e-2)
|
||||||
|
|||||||
Reference in New Issue
Block a user