From 0f49deacbff3e57cde45222842c0db6375e4fa43 Mon Sep 17 00:00:00 2001 From: laurentd-lunit <103402801+laurentd-lunit@users.noreply.github.com> Date: Tue, 15 Oct 2024 23:19:18 +0900 Subject: [PATCH] [feat] LlavaNext add feature size check to avoid CUDA Runtime Error (#33608) * [feat] add feature size check to avoid CUDA Runtime Error * [minor] add error handling to all llava models * [minor] avoid nested if else * [minor] add error message to Qwen2-vl and chameleon * [fix] token dimension for check * [minor] add feature dim check for videos too * [fix] dimension check * [fix] test reference values --------- Co-authored-by: Raushan Turganbay --- .../models/chameleon/modeling_chameleon.py | 6 ++++++ src/transformers/models/llava/modeling_llava.py | 6 ++++++ .../models/llava_next/modeling_llava_next.py | 6 ++++++ .../llava_next_video/modeling_llava_next_video.py | 12 ++++++++++++ .../llava_next_video/modular_llava_next_video.py | 12 ++++++++++++ .../llava_onevision/modeling_llava_onevision.py | 14 ++++++++++++-- .../models/qwen2_vl/modeling_qwen2_vl.py | 12 ++++++++++++ .../models/video_llava/modeling_video_llava.py | 13 ++++++++++++- .../models/vipllava/modeling_vipllava.py | 6 ++++++ tests/models/llava/test_modeling_llava.py | 4 ++-- tests/models/vipllava/test_modeling_vipllava.py | 4 ++-- 11 files changed, 88 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index fd76c0b115..20dbfc317e 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1287,6 +1287,12 @@ class ChameleonModel(ChameleonPreTrainedModel): if pixel_values is not None: image_tokens = self.get_image_tokens(pixel_values) + n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item() + n_image_features = image_tokens.shape[0] + if n_image_tokens_in_text != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}" + ) special_image_mask = input_ids == self.vocabulary_mapping.image_token_id image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index e793ca61c7..411b96f5c5 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -518,6 +518,12 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): # TODO: @raushan retain only the new behavior after v4.47 else: + n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item() + n_image_features = image_features.shape[1] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) special_image_mask = ( (input_ids == self.config.image_token_index) .unsqueeze(-1) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 705821c2b7..75dfcf5393 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -895,6 +895,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi # TODO: @raushan retain only the new behavior after v4.47 else: + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) special_image_mask = ( (input_ids == self.config.image_token_index) .unsqueeze(-1) diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 7df4cf2037..30257b8439 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -967,6 +967,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene # TODO: @raushan retain only the new behavior after v4.47 else: if image_features is not None: + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) special_image_mask = ( (input_ids == self.config.image_token_index) .unsqueeze(-1) @@ -976,6 +982,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) if video_features is not None: + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) special_image_mask = ( (input_ids == self.config.video_token_index) .unsqueeze(-1) diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 4b6be407dc..e7de66de44 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -482,6 +482,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): # TODO: @raushan retain only the new behavior after v4.47 else: if image_features is not None: + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) special_image_mask = ( (input_ids == self.config.image_token_index) .unsqueeze(-1) @@ -491,6 +497,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) if video_features is not None: + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) special_image_mask = ( (input_ids == self.config.video_token_index) .unsqueeze(-1) diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index f65c0fe7cf..3eefb517b1 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -619,7 +619,12 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene image_newline=self.image_newline, vision_aspect_ratio=vision_aspect_ratio, ) - + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) special_image_mask = ( (input_ids == self.config.image_token_index) .unsqueeze(-1) @@ -647,7 +652,12 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene image_newline = self.image_newline[None, None, :].repeat(batch_size, 1, 1).to(video_features.device) video_features = torch.cat((video_features, image_newline), dim=1) video_features = video_features.flatten(0, 1) - + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) special_video_mask = ( (input_ids == self.config.video_token_index) .unsqueeze(-1) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 283e38d3a7..e014a6da6b 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1710,6 +1710,12 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): if pixel_values is not None: pixel_values = pixel_values.type(self.visual.get_dtype()) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) image_mask = ( (input_ids == self.config.image_token_id) .unsqueeze(-1) @@ -1722,6 +1728,12 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): if pixel_values_videos is not None: pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) video_mask = ( (input_ids == self.config.video_token_id) .unsqueeze(-1) diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 5711433c36..20fa0166b8 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -618,6 +618,12 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi # TODO: @raushan retain only the new behavior after v4.47 else: if image_outputs is not None: + n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item() + n_image_features = image_features.shape[1] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) special_image_mask = ( (input_ids == self.config.image_token_index) .unsqueeze(-1) @@ -626,8 +632,13 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - if video_outputs is not None: + n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item() + n_video_features = video_features.shape[1] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) special_image_mask = ( (input_ids == self.config.video_token_index) .unsqueeze(-1) diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 26d92b9ac3..7634822847 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -511,6 +511,12 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) # TODO: @raushan retain only the new behavior after v4.47 else: + n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item() + n_image_features = image_features.shape[1] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) special_image_mask = ( (input_ids == self.config.image_token_index) .unsqueeze(-1) diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index e183c38a59..07415900bb 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -118,8 +118,8 @@ class LlavaVisionText2TextModelTester: self.batch_size = 3 self.num_channels = 3 self.image_size = 336 - self.encoder_seq_length = 231 - self.num_image_tokens = 224 + self.encoder_seq_length = 232 + self.num_image_tokens = 225 self.seq_length = seq_length + self.num_image_tokens def get_config(self): diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py index b12f2c30c7..862e144ecd 100644 --- a/tests/models/vipllava/test_modeling_vipllava.py +++ b/tests/models/vipllava/test_modeling_vipllava.py @@ -111,8 +111,8 @@ class VipLlavaVisionText2TextModelTester: self.batch_size = 3 self.num_channels = 3 self.image_size = 336 - self.encoder_seq_length = 231 - self.num_image_tokens = 224 + self.encoder_seq_length = 232 + self.num_image_tokens = 225 self.seq_length = seq_length + self.num_image_tokens def get_config(self):