fix Glm4v batch videos forward (#39172)
* changes for video * update modular * change get_video_features * update video token replacement * update modular * add test and fix typo * lint * fix order * lint * fix * remove dependency * lint * lint * remove todo * resize video for test * lint.. * fix test * new a processor for video_test * fix test
This commit is contained in:
@@ -1052,6 +1052,7 @@ class Glm4vModel(Glm4vPreTrainedModel):
|
|||||||
device=input_ids.device,
|
device=input_ids.device,
|
||||||
)
|
)
|
||||||
image_index, video_index = 0, 0
|
image_index, video_index = 0, 0
|
||||||
|
video_group_index = 0
|
||||||
attention_mask = attention_mask.to(total_input_ids.device)
|
attention_mask = attention_mask.to(total_input_ids.device)
|
||||||
for i, input_ids in enumerate(total_input_ids):
|
for i, input_ids in enumerate(total_input_ids):
|
||||||
input_ids = input_ids[attention_mask[i] == 1]
|
input_ids = input_ids[attention_mask[i] == 1]
|
||||||
@@ -1081,7 +1082,6 @@ class Glm4vModel(Glm4vPreTrainedModel):
|
|||||||
|
|
||||||
llm_pos_ids_list = []
|
llm_pos_ids_list = []
|
||||||
video_frame_num = 1
|
video_frame_num = 1
|
||||||
|
|
||||||
for modality_type, start_idx, end_idx in input_type_group:
|
for modality_type, start_idx, end_idx in input_type_group:
|
||||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
|
|
||||||
@@ -1125,7 +1125,11 @@ class Glm4vModel(Glm4vPreTrainedModel):
|
|||||||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
|
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
|
||||||
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
|
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
|
||||||
|
|
||||||
video_index += 1
|
video_group_index += 1
|
||||||
|
|
||||||
|
if video_group_index >= video_grid_thw[video_index][0]:
|
||||||
|
video_index += 1
|
||||||
|
video_group_index = 0
|
||||||
|
|
||||||
video_frame_num += 1
|
video_frame_num += 1
|
||||||
|
|
||||||
@@ -1174,7 +1178,13 @@ class Glm4vModel(Glm4vPreTrainedModel):
|
|||||||
The temporal, height and width of feature shape of each video in LLM.
|
The temporal, height and width of feature shape of each video in LLM.
|
||||||
"""
|
"""
|
||||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
# reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
|
||||||
|
temp_frames_hw = []
|
||||||
|
for t, h, w in video_grid_thw:
|
||||||
|
repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
|
||||||
|
temp_frames_hw.append(repeated_row)
|
||||||
|
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
|
||||||
|
video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
|
||||||
split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
|
split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
|
||||||
video_embeds = torch.split(video_embeds, split_sizes)
|
video_embeds = torch.split(video_embeds, split_sizes)
|
||||||
return video_embeds
|
return video_embeds
|
||||||
|
|||||||
@@ -1064,6 +1064,7 @@ class Glm4vModel(Qwen2_5_VLModel):
|
|||||||
device=input_ids.device,
|
device=input_ids.device,
|
||||||
)
|
)
|
||||||
image_index, video_index = 0, 0
|
image_index, video_index = 0, 0
|
||||||
|
video_group_index = 0
|
||||||
attention_mask = attention_mask.to(total_input_ids.device)
|
attention_mask = attention_mask.to(total_input_ids.device)
|
||||||
for i, input_ids in enumerate(total_input_ids):
|
for i, input_ids in enumerate(total_input_ids):
|
||||||
input_ids = input_ids[attention_mask[i] == 1]
|
input_ids = input_ids[attention_mask[i] == 1]
|
||||||
@@ -1093,7 +1094,6 @@ class Glm4vModel(Qwen2_5_VLModel):
|
|||||||
|
|
||||||
llm_pos_ids_list = []
|
llm_pos_ids_list = []
|
||||||
video_frame_num = 1
|
video_frame_num = 1
|
||||||
|
|
||||||
for modality_type, start_idx, end_idx in input_type_group:
|
for modality_type, start_idx, end_idx in input_type_group:
|
||||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
|
|
||||||
@@ -1137,7 +1137,11 @@ class Glm4vModel(Qwen2_5_VLModel):
|
|||||||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
|
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
|
||||||
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
|
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
|
||||||
|
|
||||||
video_index += 1
|
video_group_index += 1
|
||||||
|
|
||||||
|
if video_group_index >= video_grid_thw[video_index][0]:
|
||||||
|
video_index += 1
|
||||||
|
video_group_index = 0
|
||||||
|
|
||||||
video_frame_num += 1
|
video_frame_num += 1
|
||||||
|
|
||||||
@@ -1173,6 +1177,30 @@ class Glm4vModel(Qwen2_5_VLModel):
|
|||||||
|
|
||||||
return position_ids, mrope_position_deltas
|
return position_ids, mrope_position_deltas
|
||||||
|
|
||||||
|
def get_video_features(
|
||||||
|
self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Encodes videos into continuous embeddings that can be forwarded to the language model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||||
|
The tensors corresponding to the input videos.
|
||||||
|
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
||||||
|
The temporal, height and width of feature shape of each video in LLM.
|
||||||
|
"""
|
||||||
|
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||||
|
# reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
|
||||||
|
temp_frames_hw = []
|
||||||
|
for t, h, w in video_grid_thw:
|
||||||
|
repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
|
||||||
|
temp_frames_hw.append(repeated_row)
|
||||||
|
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
|
||||||
|
video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
|
||||||
|
split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
|
||||||
|
video_embeds = torch.split(video_embeds, split_sizes)
|
||||||
|
return video_embeds
|
||||||
|
|
||||||
@auto_docstring
|
@auto_docstring
|
||||||
@can_return_tuple
|
@can_return_tuple
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1664,32 +1692,38 @@ class Glm4vProcessor(Qwen2_5_VLProcessor):
|
|||||||
video_index = 0
|
video_index = 0
|
||||||
for i in range(len(text)):
|
for i in range(len(text)):
|
||||||
while self.video_token in text[i]:
|
while self.video_token in text[i]:
|
||||||
num_frames = len(video_grid_thw)
|
num_frames = video_grid_thw[video_index][0]
|
||||||
video_structure = ""
|
video_structure = ""
|
||||||
|
|
||||||
if hasattr(timestamps, "tolist"):
|
if hasattr(timestamps, "tolist"):
|
||||||
timestamps_list = timestamps.tolist()[0]
|
timestamps_list = timestamps.tolist()[0]
|
||||||
else:
|
else:
|
||||||
timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps
|
timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps
|
||||||
|
|
||||||
unique_timestamps = []
|
unique_timestamps = []
|
||||||
for idx in range(0, len(timestamps_list)):
|
for idx in range(0, len(timestamps_list)):
|
||||||
unique_timestamps.append(timestamps_list[idx])
|
unique_timestamps.append(timestamps_list[idx])
|
||||||
|
|
||||||
selected_timestamps = unique_timestamps[:num_frames]
|
selected_timestamps = unique_timestamps[:num_frames]
|
||||||
while len(selected_timestamps) < num_frames:
|
while len(selected_timestamps) < num_frames:
|
||||||
selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
|
selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
|
||||||
|
|
||||||
for frame_idx in range(num_frames):
|
for frame_idx in range(num_frames):
|
||||||
timestamp_sec = selected_timestamps[frame_idx]
|
timestamp_sec = selected_timestamps[frame_idx]
|
||||||
frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
|
frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
|
||||||
video_structure += frame_structure
|
video_structure += frame_structure
|
||||||
|
|
||||||
text[i] = text[i].replace(self.video_token, video_structure, 1)
|
text[i] = text[i].replace(self.video_token, video_structure, 1)
|
||||||
|
num_image_tokens = (
|
||||||
|
video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
|
||||||
|
)
|
||||||
|
for frame_idx in range(num_frames):
|
||||||
|
if self.image_token in text[i]:
|
||||||
|
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
|
||||||
|
|
||||||
video_index += 1
|
video_index += 1
|
||||||
|
|
||||||
for frame_idx in range(len(video_grid_thw)):
|
|
||||||
if self.image_token in text[i]:
|
|
||||||
num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
|
|
||||||
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
|
|
||||||
text[i] = text[i].replace("<|placeholder|>", self.image_token)
|
text[i] = text[i].replace("<|placeholder|>", self.image_token)
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||||
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
|
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
|
||||||
|
|||||||
@@ -167,32 +167,38 @@ class Glm4vProcessor(ProcessorMixin):
|
|||||||
video_index = 0
|
video_index = 0
|
||||||
for i in range(len(text)):
|
for i in range(len(text)):
|
||||||
while self.video_token in text[i]:
|
while self.video_token in text[i]:
|
||||||
num_frames = len(video_grid_thw)
|
num_frames = video_grid_thw[video_index][0]
|
||||||
video_structure = ""
|
video_structure = ""
|
||||||
|
|
||||||
if hasattr(timestamps, "tolist"):
|
if hasattr(timestamps, "tolist"):
|
||||||
timestamps_list = timestamps.tolist()[0]
|
timestamps_list = timestamps.tolist()[0]
|
||||||
else:
|
else:
|
||||||
timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps
|
timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps
|
||||||
|
|
||||||
unique_timestamps = []
|
unique_timestamps = []
|
||||||
for idx in range(0, len(timestamps_list)):
|
for idx in range(0, len(timestamps_list)):
|
||||||
unique_timestamps.append(timestamps_list[idx])
|
unique_timestamps.append(timestamps_list[idx])
|
||||||
|
|
||||||
selected_timestamps = unique_timestamps[:num_frames]
|
selected_timestamps = unique_timestamps[:num_frames]
|
||||||
while len(selected_timestamps) < num_frames:
|
while len(selected_timestamps) < num_frames:
|
||||||
selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
|
selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
|
||||||
|
|
||||||
for frame_idx in range(num_frames):
|
for frame_idx in range(num_frames):
|
||||||
timestamp_sec = selected_timestamps[frame_idx]
|
timestamp_sec = selected_timestamps[frame_idx]
|
||||||
frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
|
frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
|
||||||
video_structure += frame_structure
|
video_structure += frame_structure
|
||||||
|
|
||||||
text[i] = text[i].replace(self.video_token, video_structure, 1)
|
text[i] = text[i].replace(self.video_token, video_structure, 1)
|
||||||
|
num_image_tokens = (
|
||||||
|
video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
|
||||||
|
)
|
||||||
|
for frame_idx in range(num_frames):
|
||||||
|
if self.image_token in text[i]:
|
||||||
|
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
|
||||||
|
|
||||||
video_index += 1
|
video_index += 1
|
||||||
|
|
||||||
for frame_idx in range(len(video_grid_thw)):
|
|
||||||
if self.image_token in text[i]:
|
|
||||||
num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
|
|
||||||
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
|
|
||||||
text[i] = text[i].replace("<|placeholder|>", self.image_token)
|
text[i] = text[i].replace("<|placeholder|>", self.image_token)
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||||
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
|
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
|
||||||
|
|||||||
@@ -249,10 +249,6 @@ class Glm4vVideoProcessor(BaseVideoProcessor):
|
|||||||
processed_grids = reorder_videos(processed_grids, grouped_videos_index)
|
processed_grids = reorder_videos(processed_grids, grouped_videos_index)
|
||||||
pixel_values_videos = torch.cat(processed_videos, dim=0)
|
pixel_values_videos = torch.cat(processed_videos, dim=0)
|
||||||
video_grid_thw = torch.tensor(processed_grids)
|
video_grid_thw = torch.tensor(processed_grids)
|
||||||
total_frames = video_grid_thw[0][0].item()
|
|
||||||
h = video_grid_thw[0][1].item()
|
|
||||||
w = video_grid_thw[0][2].item()
|
|
||||||
video_grid_thw = [[1, h, w] for _ in range(total_frames)]
|
|
||||||
data = {
|
data = {
|
||||||
"pixel_values_videos": pixel_values_videos,
|
"pixel_values_videos": pixel_values_videos,
|
||||||
"video_grid_thw": video_grid_thw,
|
"video_grid_thw": video_grid_thw,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Testing suite for the PyTorch GLM-4.1V model."""
|
"""Testing suite for the PyTorch GLM-4.1V model."""
|
||||||
|
|
||||||
|
import copy
|
||||||
import gc
|
import gc
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -236,7 +237,26 @@ class Glm4vModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
|||||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# RoPE index doesn't match when using embeddings
|
def test_inputs_embeds(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
input_ids = inputs["input_ids"]
|
||||||
|
del inputs["input_ids"]
|
||||||
|
del inputs["pixel_values"]
|
||||||
|
del inputs["image_grid_thw"]
|
||||||
|
|
||||||
|
wte = model.get_input_embeddings()
|
||||||
|
inputs["inputs_embeds"] = wte(input_ids)
|
||||||
|
with torch.no_grad():
|
||||||
|
model(**inputs)[0]
|
||||||
|
|
||||||
def test_inputs_embeds_matches_input_ids(self):
|
def test_inputs_embeds_matches_input_ids(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
@@ -350,6 +370,44 @@ class Glm4vIntegrationTest(unittest.TestCase):
|
|||||||
EXPECTED_DECODED_TEXT,
|
EXPECTED_DECODED_TEXT,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_small_model_integration_test_with_video(self):
|
||||||
|
processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking", max_image_size={"longest_edge": 50176})
|
||||||
|
model = Glm4vForConditionalGeneration.from_pretrained(
|
||||||
|
"THUDM/GLM-4.1V-9B-Thinking", torch_dtype=torch.float16, device_map="auto"
|
||||||
|
)
|
||||||
|
questions = ["Describe this video."] * 2
|
||||||
|
video_urls = [
|
||||||
|
"https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4"
|
||||||
|
] * 2
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "video",
|
||||||
|
"video": video_url,
|
||||||
|
},
|
||||||
|
{"type": "text", "text": question},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
for question, video_url in zip(questions, video_urls)
|
||||||
|
]
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", padding=True
|
||||||
|
).to(torch_device)
|
||||||
|
output = model.generate(**inputs, max_new_tokens=30)
|
||||||
|
EXPECTED_DECODED_TEXT = [
|
||||||
|
"\n012345Describe this video.\n<think>Got it, let's analyze the video. First, the scene is a room with a wooden floor, maybe a traditional Japanese room with tatami",
|
||||||
|
"\n012345Describe this video.\n<think>Got it, let's analyze the video. First, the scene is a room with a wooden floor, maybe a traditional Japanese room with tatami"
|
||||||
|
] # fmt: skip
|
||||||
|
self.assertEqual(
|
||||||
|
processor.batch_decode(output, skip_special_tokens=True),
|
||||||
|
EXPECTED_DECODED_TEXT,
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_small_model_integration_test_expand(self):
|
def test_small_model_integration_test_expand(self):
|
||||||
model = Glm4vForConditionalGeneration.from_pretrained(
|
model = Glm4vForConditionalGeneration.from_pretrained(
|
||||||
|
|||||||
@@ -228,7 +228,7 @@ class Glm4vVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):
|
|||||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
|
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
|
||||||
self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
|
self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
|
||||||
|
|
||||||
@unittest.skip("Skip for now, the test needs adjustment fo GLM-4.1V")
|
@unittest.skip("Skip for now, the test needs adjustment for GLM-4.1V")
|
||||||
def test_call_numpy_4_channels(self):
|
def test_call_numpy_4_channels(self):
|
||||||
for video_processing_class in self.video_processor_list:
|
for video_processing_class in self.video_processor_list:
|
||||||
# Test that can process videos which have an arbitrary number of channels
|
# Test that can process videos which have an arbitrary number of channels
|
||||||
|
|||||||
Reference in New Issue
Block a user