Qwen2-VL: fix rope delta calculation (#36013)
* fix rope delats calculation * add test * style
This commit is contained in:
committed by
GitHub
parent
e284c7e954
commit
5d75a25b03
@@ -284,6 +284,29 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
|
||||
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
|
||||
|
||||
def test_forward_with_rope_deltas_cached(self):
|
||||
"""
|
||||
Tests that Qwen2-VL computes new rope deltas every forward pass with new set of inputs.
|
||||
Rope deltas are cached when we generate and re-used for decoding phase, byt are not reset
|
||||
automatically after generation ends. See https://github.com/huggingface/transformers/pull/36013 for more
|
||||
"""
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
|
||||
# Generate and make sure rope_deltas are not `None`
|
||||
self.assertTrue(model.rope_deltas is None)
|
||||
generation_output = model.generate(
|
||||
**input_dict, max_new_tokens=4, return_dict_in_generate=True, output_logits=True
|
||||
)
|
||||
self.assertTrue(model.rope_deltas is not None)
|
||||
|
||||
# Now if we try to do forward pass, we should get new rope logits, because cache is not passed
|
||||
forward_output = model(**input_dict)
|
||||
torch.testing.assert_close(
|
||||
generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user