Fix Qwen2.5-Omni get_chunked_index chunking functionality (#37631)

* fix: qwen2.5 omni modular get_rope_index

* test: add test for qwen2.5 omni rope index (video with audio input)

* style

* expected_position_ids readability

* fix: use spatial_merge_size = 1 in unit test
This commit is contained in:
Kero Liang
2025-04-22 17:15:37 +08:00
committed by GitHub
parent fee1190601
commit 5f791281c3
4 changed files with 150 additions and 11 deletions

View File

@@ -244,7 +244,8 @@ class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedMo
- the second chunk contains values >= 1000 and < 2000, and so on.
Parameters:
token_indices (`List[int]`): A monotonically increasing list of token index values.
token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
token index values.
t_ntoken_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
remove_index (`int`): An index id to subtract from `token_indices` before chunking
@@ -257,12 +258,12 @@ class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedMo
i, start_idx = 0, 0 # skip bos token
current_chunk = 1
while i < len(token_indices): # skip eos token
if token_indices[0][i] - remove_index >= current_chunk * tokens_per_chunk:
if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
yield (start_idx, i)
start_idx = i
current_chunk += 1
i += 1
yield (start_idx, token_indices.shape[1])
yield (start_idx, len(token_indices))
return list(_iter())
@@ -499,8 +500,8 @@ class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedMo
)
t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk)
video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids, t_ntoken_per_chunk, st_idx)
audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids, t_ntoken_per_chunk, st_idx)
video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
sub_len = 0
for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None

View File

@@ -1145,7 +1145,8 @@ class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedMo
- the second chunk contains values >= 1000 and < 2000, and so on.
Parameters:
token_indices (`List[int]`): A monotonically increasing list of token index values.
token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
token index values.
t_ntoken_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
remove_index (`int`): An index id to subtract from `token_indices` before chunking
@@ -1158,12 +1159,12 @@ class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedMo
i, start_idx = 0, 0 # skip bos token
current_chunk = 1
while i < len(token_indices): # skip eos token
if token_indices[0][i] - remove_index >= current_chunk * tokens_per_chunk:
if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
yield (start_idx, i)
start_idx = i
current_chunk += 1
i += 1
yield (start_idx, token_indices.shape[1])
yield (start_idx, len(token_indices))
return list(_iter())
@@ -1400,8 +1401,8 @@ class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedMo
)
t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk)
video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids, t_ntoken_per_chunk, st_idx)
audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids, t_ntoken_per_chunk, st_idx)
video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
sub_len = 0
for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None

View File

@@ -289,7 +289,7 @@ class Qwen2_5OmniProcessor(ProcessorMixin):
- the second chunk contains values >= 1000 and < 2000, and so on.
Parameters:
token_indices (`List[int]`): A monotonically increasing list of token index values.
token_indices (`np.ndarray`): A monotonically increasing list of token index values.
t_ntoken_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
Returns:

View File

@@ -381,6 +381,143 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
# NOTE(review): this test body is a deliberate no-op in the visible excerpt. Any skip
# decorator or rationale for disabling it sits above this hunk — TODO confirm.
def test_custom_4d_attention_mask(self):
pass
def test_get_rope_index_video_with_audio(self):
    """Check `get_rope_index` on a prompt where video and audio tokens are interleaved in chunks.

    The prompt consists of 2 text tokens, vision_bos + audio_bos, a first chunk
    (8 video tokens then 50 audio tokens), a second chunk (4 video tokens then
    25 audio tokens), audio_eos + vision_eos and 2 trailing text tokens.
    The three rows of the expected position ids agree everywhere except on the
    video token slots.
    """
    # No images; a single video with grid (3, 2, 2) -> 3 * 2 * 2 = 12 video tokens.
    image_grid_thw = torch.empty((0, 3), dtype=torch.long)
    video_grid_thw = torch.tensor([[3, 2, 2]], dtype=torch.long)
    # num_audio_tokens = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
    # i.e.: 300 audio_seqlen -> 75 audio tokens
    audio_seqlens = torch.tensor([300], dtype=torch.long)
    second_per_grids = torch.tensor([1.0], dtype=torch.float)
    use_audio_in_video = True

    def build_row(first_video_chunk, second_video_chunk):
        # Everything except the video slots is shared by the three rows.
        return (
            [0, 1]  # text
            + [2, 2]  # vision_bos + audio_bos
            + first_video_chunk  # 8 video tokens
            + list(range(3, 53))  # 50 audio tokens: 3 .. 52
            + second_video_chunk  # 4 video tokens
            + list(range(53, 78))  # 25 audio tokens: 53 .. 77
            + [78, 78]  # audio_eos + vision_eos
            + [79, 80]  # text
        )

    expected_position_ids = torch.tensor(
        [
            [build_row([3, 3, 3, 3, 28, 28, 28, 28], [53, 53, 53, 53])],
            [build_row([3, 3, 4, 4, 3, 3, 4, 4], [3, 3, 4, 4])],
            [build_row([3, 4, 3, 4, 3, 4, 3, 4], [3, 4, 3, 4])],
        ],
        dtype=torch.long,
    )

    for model_class in self.all_model_classes:
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        leading_text = [100, 101]
        media_start = [config.vision_start_token_id, config.audio_start_token_id]
        # 1st chunk: 8 video tokens, 50 audio tokens
        first_chunk = [config.video_token_id] * 8 + [config.audio_token_id] * 50
        # 2nd chunk: 4 video tokens, 25 audio tokens
        second_chunk = [config.video_token_id] * 4 + [config.audio_token_id] * 25
        media_end = [config.audio_end_token_id, config.vision_end_token_id]
        trailing_text = [102, 103]

        input_ids = torch.tensor(
            [leading_text + media_start + first_chunk + second_chunk + media_end + trailing_text],
            dtype=torch.long,
        )

        model = model_class(config)
        # The second return value (the mrope position deltas) is not checked here.
        position_ids, _ = model.get_rope_index(
            input_ids=input_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            attention_mask=None,
            use_audio_in_video=use_audio_in_video,
            audio_seqlens=audio_seqlens,
            second_per_grids=second_per_grids,
        )

        self.assertTrue(torch.equal(position_ids, expected_position_ids))
@require_torch
class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):