From 520b9dcb42cef21662c304583368ff6645116a45 Mon Sep 17 00:00:00 2001 From: Kingsley Date: Thu, 10 Jul 2025 16:44:28 +0800 Subject: [PATCH] fix Glm4v batch videos forward (#39172) * changes for video * update modular * change get_video_features * update video token replacement * update modular * add test and fix typo * lint * fix order * lint * fix * remove dependency * lint * lint * remove todo * resize video for test * lint.. * fix test * new a processor for video_test * fix test --- .../models/glm4v/modeling_glm4v.py | 16 ++++- .../models/glm4v/modular_glm4v.py | 50 +++++++++++++--- .../models/glm4v/processing_glm4v.py | 18 ++++-- .../models/glm4v/video_processing_glm4v.py | 4 -- tests/models/glm4v/test_modeling_glm4v.py | 60 ++++++++++++++++++- .../glm4v/test_video_processing_glm4v.py | 2 +- 6 files changed, 127 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 86bf2b56ff..8ab64acfff 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -1052,6 +1052,7 @@ class Glm4vModel(Glm4vPreTrainedModel): device=input_ids.device, ) image_index, video_index = 0, 0 + video_group_index = 0 attention_mask = attention_mask.to(total_input_ids.device) for i, input_ids in enumerate(total_input_ids): input_ids = input_ids[attention_mask[i] == 1] @@ -1081,7 +1082,6 @@ class Glm4vModel(Glm4vPreTrainedModel): llm_pos_ids_list = [] video_frame_num = 1 - for modality_type, start_idx, end_idx in input_type_group: st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 @@ -1125,7 +1125,11 @@ class Glm4vModel(Glm4vPreTrainedModel): w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten() llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) - video_index += 1 + video_group_index += 1 + + if video_group_index >= video_grid_thw[video_index][0]: + video_index += 1 + video_group_index = 0 video_frame_num += 1 @@ -1174,7 +1178,13 @@ class Glm4vModel(Glm4vPreTrainedModel): The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames + temp_frames_hw = [] + for t, h, w in video_grid_thw: + repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1) + temp_frames_hw.append(repeated_row) + flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(video_embeds, split_sizes) return video_embeds diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 3f8f593c20..7a5b6efa2e 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -1064,6 +1064,7 @@ class Glm4vModel(Qwen2_5_VLModel): device=input_ids.device, ) image_index, video_index = 0, 0 + video_group_index = 0 attention_mask = attention_mask.to(total_input_ids.device) for i, input_ids in enumerate(total_input_ids): input_ids = input_ids[attention_mask[i] == 1] @@ -1093,7 +1094,6 @@ class Glm4vModel(Qwen2_5_VLModel): llm_pos_ids_list = [] video_frame_num = 1 - for modality_type, start_idx, end_idx in input_type_group: st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 @@ -1137,7 +1137,11 @@ class Glm4vModel(Qwen2_5_VLModel): w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten() llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) - video_index += 1 + video_group_index += 1 + + if video_group_index >= video_grid_thw[video_index][0]: + video_index += 1 + video_group_index = 0 video_frame_num += 1 @@ -1173,6 +1177,30 @@ class Glm4vModel(Qwen2_5_VLModel): return position_ids, mrope_position_deltas + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames + temp_frames_hw = [] + for t, h, w in video_grid_thw: + repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1) + temp_frames_hw.append(repeated_row) + flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw) + split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + video_embeds = torch.split(video_embeds, split_sizes) + return video_embeds + @auto_docstring @can_return_tuple def forward( @@ -1664,32 +1692,38 @@ class Glm4vProcessor(Qwen2_5_VLProcessor): video_index = 0 for i in range(len(text)): while self.video_token in text[i]: - num_frames = len(video_grid_thw) + num_frames = video_grid_thw[video_index][0] video_structure = "" if hasattr(timestamps, "tolist"): timestamps_list = timestamps.tolist()[0] else: timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps + unique_timestamps = [] for idx in range(0, len(timestamps_list)): unique_timestamps.append(timestamps_list[idx]) + selected_timestamps = unique_timestamps[:num_frames] while len(selected_timestamps) < num_frames: selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0) + for frame_idx in range(num_frames): timestamp_sec = selected_timestamps[frame_idx] frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}" video_structure += frame_structure + text[i] = text[i].replace(self.video_token, video_structure, 1) + num_image_tokens = ( + video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0] + ) + for frame_idx in range(num_frames): + if self.image_token in text[i]: + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + video_index += 1 - for frame_idx in range(len(video_grid_thw)): - if self.image_token in text[i]: - num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length - text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) text[i] = text[i].replace("<|placeholder|>", self.image_token) - return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) diff --git a/src/transformers/models/glm4v/processing_glm4v.py b/src/transformers/models/glm4v/processing_glm4v.py index 534f744fc3..c71804fc11 100644 --- a/src/transformers/models/glm4v/processing_glm4v.py +++ b/src/transformers/models/glm4v/processing_glm4v.py @@ -167,32 +167,38 @@ class Glm4vProcessor(ProcessorMixin): video_index = 0 for i in range(len(text)): while self.video_token in text[i]: - num_frames = len(video_grid_thw) + num_frames = video_grid_thw[video_index][0] video_structure = "" if hasattr(timestamps, "tolist"): timestamps_list = timestamps.tolist()[0] else: timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps + unique_timestamps = [] for idx in range(0, len(timestamps_list)): unique_timestamps.append(timestamps_list[idx]) + selected_timestamps = unique_timestamps[:num_frames] while len(selected_timestamps) < num_frames: selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0) + for frame_idx in range(num_frames): timestamp_sec = selected_timestamps[frame_idx] frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}" video_structure += frame_structure + text[i] = text[i].replace(self.video_token, video_structure, 1) + num_image_tokens = ( + video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0] + ) + for frame_idx in range(num_frames): + if self.image_token in text[i]: + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + video_index += 1 - for frame_idx in range(len(video_grid_thw)): - if self.image_token in text[i]: - num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length - text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) text[i] = text[i].replace("<|placeholder|>", self.image_token) - return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) diff --git a/src/transformers/models/glm4v/video_processing_glm4v.py b/src/transformers/models/glm4v/video_processing_glm4v.py index 4f3ffed27e..68915bf586 100644 --- a/src/transformers/models/glm4v/video_processing_glm4v.py +++ b/src/transformers/models/glm4v/video_processing_glm4v.py @@ -249,10 +249,6 @@ class Glm4vVideoProcessor(BaseVideoProcessor): processed_grids = reorder_videos(processed_grids, grouped_videos_index) pixel_values_videos = torch.cat(processed_videos, dim=0) video_grid_thw = torch.tensor(processed_grids) - total_frames = video_grid_thw[0][0].item() - h = video_grid_thw[0][1].item() - w = video_grid_thw[0][2].item() - video_grid_thw = [[1, h, w] for _ in range(total_frames)] data = { "pixel_values_videos": pixel_values_videos, "video_grid_thw": video_grid_thw, diff --git a/tests/models/glm4v/test_modeling_glm4v.py b/tests/models/glm4v/test_modeling_glm4v.py index e5211e59f2..1e962e4c33 100644 --- a/tests/models/glm4v/test_modeling_glm4v.py +++ b/tests/models/glm4v/test_modeling_glm4v.py @@ -13,6 +13,7 @@ # limitations under the License. """Testing suite for the PyTorch GLM-4.1V model.""" +import copy import gc import unittest @@ -236,7 +237,26 @@ class Glm4vModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) def test_generate_from_inputs_embeds_with_static_cache(self): pass - # RoPE index doesn't match when using embeddings + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + del inputs["image_grid_thw"] + + wte = model.get_input_embeddings() + inputs["inputs_embeds"] = wte(input_ids) + with torch.no_grad(): + model(**inputs)[0] + def test_inputs_embeds_matches_input_ids(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -350,6 +370,44 @@ class Glm4vIntegrationTest(unittest.TestCase): EXPECTED_DECODED_TEXT, ) + @slow + def test_small_model_integration_test_with_video(self): + processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking", max_image_size={"longest_edge": 50176}) + model = Glm4vForConditionalGeneration.from_pretrained( + "THUDM/GLM-4.1V-9B-Thinking", torch_dtype=torch.float16, device_map="auto" + ) + questions = ["Describe this video."] * 2 + video_urls = [ + "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4" + ] * 2 + messages = [ + [ + { + "role": "user", + "content": [ + { + "type": "video", + "video": video_url, + }, + {"type": "text", "text": question}, + ], + } + ] + for question, video_url in zip(questions, video_urls) + ] + inputs = processor.apply_chat_template( + messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", padding=True + ).to(torch_device) + output = model.generate(**inputs, max_new_tokens=30) + EXPECTED_DECODED_TEXT = [ + "\n012345Describe this video.\nGot it, let's analyze the video. First, the scene is a room with a wooden floor, maybe a traditional Japanese room with tatami", + "\n012345Describe this video.\nGot it, let's analyze the video. First, the scene is a room with a wooden floor, maybe a traditional Japanese room with tatami" + ] # fmt: skip + self.assertEqual( + processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + @slow def test_small_model_integration_test_expand(self): model = Glm4vForConditionalGeneration.from_pretrained( diff --git a/tests/models/glm4v/test_video_processing_glm4v.py b/tests/models/glm4v/test_video_processing_glm4v.py index 717b853ac2..b629e61eb5 100644 --- a/tests/models/glm4v/test_video_processing_glm4v.py +++ b/tests/models/glm4v/test_video_processing_glm4v.py @@ -228,7 +228,7 @@ class Glm4vVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase): expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs) self.assertEqual(list(encoded_videos.shape), expected_output_video_shape) - @unittest.skip("Skip for now, the test needs adjustment fo GLM-4.1V") + @unittest.skip("Skip for now, the test needs adjustment for GLM-4.1V") def test_call_numpy_4_channels(self): for video_processing_class in self.video_processor_list: # Test that can process videos which have an arbitrary number of channels