Fix Qwen2.5-Omni get_chunked_index chunking functionality (#37631)
* fix: qwen2.5 omni modular get_rope_index * test: add test for qwen2.5 omni rope index (video with audio input) * style * expected_position_ids readability * fix: use spatial_merge_size = 1 in unit test
This commit is contained in:
@@ -381,6 +381,143 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
|
||||
def test_custom_4d_attention_mask(self):
|
||||
pass
|
||||
|
||||
def test_get_rope_index_video_with_audio(self):
|
||||
image_grid_thw = torch.empty((0, 3), dtype=torch.long)
|
||||
|
||||
# 3 * 2 * 2 = 12 video tokens
|
||||
video_grid_thw = torch.tensor([[3, 2, 2]], dtype=torch.long)
|
||||
|
||||
# num_audio_tokens = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
|
||||
# i.e.: 300 audio_seqlen -> 75 audio tokens
|
||||
audio_seqlens = torch.tensor([300], dtype=torch.long)
|
||||
|
||||
second_per_grids = torch.tensor([1.0], dtype=torch.float)
|
||||
|
||||
use_audio_in_video = True
|
||||
|
||||
# fmt: off
|
||||
expected_position_ids = torch.tensor([
|
||||
[[
|
||||
0, 1, # text
|
||||
2, 2, # vision_bos + audio_bos
|
||||
|
||||
# video chunk
|
||||
3, 3, 3, 3,
|
||||
28, 28, 28, 28,
|
||||
|
||||
# audio chunk
|
||||
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
|
||||
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
|
||||
45, 46, 47, 48, 49, 50, 51, 52,
|
||||
|
||||
# video chunk
|
||||
53, 53, 53, 53,
|
||||
|
||||
# audio chunk
|
||||
53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
|
||||
67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
|
||||
|
||||
78, 78, # audio_eos + vision_eos
|
||||
79, 80, # text
|
||||
]],
|
||||
[[
|
||||
0, 1, # text
|
||||
2, 2, # vision_bos + audio_bos
|
||||
|
||||
# video chunk
|
||||
3, 3, 4, 4,
|
||||
3, 3, 4, 4,
|
||||
|
||||
# audio chunk
|
||||
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
|
||||
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
|
||||
45, 46, 47, 48, 49, 50, 51, 52,
|
||||
|
||||
# video chunk
|
||||
3, 3, 4, 4,
|
||||
|
||||
# audio chunk
|
||||
53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
|
||||
67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
|
||||
|
||||
78, 78, # audio_eos + vision_eos
|
||||
79, 80, # text
|
||||
]],
|
||||
[[
|
||||
0, 1, # text
|
||||
2, 2, # vision_bos + audio_bos
|
||||
|
||||
# video chunk
|
||||
3, 4, 3, 4,
|
||||
3, 4, 3, 4,
|
||||
|
||||
# audio chunk
|
||||
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
|
||||
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
|
||||
45, 46, 47, 48, 49, 50, 51, 52,
|
||||
|
||||
# video chunk
|
||||
3, 4, 3, 4,
|
||||
|
||||
# audio chunk
|
||||
53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
|
||||
67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
|
||||
|
||||
78, 78, # audio_eos + vision_eos
|
||||
79, 80, # text
|
||||
]],
|
||||
], dtype=torch.long)
|
||||
# fmt: on
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[
|
||||
100,
|
||||
101,
|
||||
]
|
||||
+ [
|
||||
config.vision_start_token_id,
|
||||
config.audio_start_token_id,
|
||||
]
|
||||
# 1st chunk: 8 video tokens, 50 audio tokens
|
||||
+ [config.video_token_id] * 2 * 2 * 2
|
||||
+ [config.audio_token_id] * 50
|
||||
+
|
||||
# 2nd chunk: 4 video tokens, 25 audio tokens
|
||||
[config.video_token_id] * 1 * 2 * 2
|
||||
+ [config.audio_token_id] * 25
|
||||
+ [
|
||||
config.audio_end_token_id,
|
||||
config.vision_end_token_id,
|
||||
]
|
||||
+ [
|
||||
102,
|
||||
103,
|
||||
]
|
||||
],
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
position_ids, mrope_position_deltas = model.get_rope_index(
|
||||
input_ids=input_ids,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
attention_mask=None,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
audio_seqlens=audio_seqlens,
|
||||
second_per_grids=second_per_grids,
|
||||
)
|
||||
|
||||
self.assertTrue(torch.equal(position_ids, expected_position_ids))
|
||||
|
||||
|
||||
@require_torch
|
||||
class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user