Support gradient checkpointing in Qwen2VL ViT (#34724)

* Support gradient checkpointing in Qwen2VL ViT

* Enable gradient checkpoint tests for Qwen2VL

* [run-slow] qwen2_vl
This commit is contained in:
Jiahao Li
2024-11-19 19:30:44 +08:00
committed by GitHub
parent 1a0cd69435
commit 0db91c3c8d
2 changed files with 7 additions and 19 deletions

View File

@@ -285,24 +285,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Feedforward chunking is not yet supported")
def test_feed_forward_chunking(self):
pass