Qwen2-VL: fix rope delta calculation (#36013)

* fix rope delats calculation

* add test

* style
This commit is contained in:
Raushan Turganbay
2025-02-04 09:48:29 +01:00
committed by GitHub
parent e284c7e954
commit 5d75a25b03
4 changed files with 38 additions and 3 deletions

View File

@@ -284,6 +284,29 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
def test_forward_with_rope_deltas_cached(self):
"""
Tests that Qwen2-VL computes new rope deltas every forward pass with new set of inputs.
Rope deltas are cached when we generate and re-used for decoding phase, byt are not reset
automatically after generation ends. See https://github.com/huggingface/transformers/pull/36013 for more
"""
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
# Generate and make sure rope_deltas are not `None`
self.assertTrue(model.rope_deltas is None)
generation_output = model.generate(
**input_dict, max_new_tokens=4, return_dict_in_generate=True, output_logits=True
)
self.assertTrue(model.rope_deltas is not None)
# Now if we try to do forward pass, we should get new rope logits, because cache is not passed
forward_output = model(**input_dict)
torch.testing.assert_close(
generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4
)
@unittest.skip(reason="Feedforward chunking is not yet supported")
def test_feed_forward_chunking(self):
pass