Support gradient checkpointing in Qwen2VL ViT (#34724)
* Support gradient checkpointing in Qwen2VL ViT * Enable gradient checkpoint tests for Qwen2VL * [run-slow] qwen2_vl
This commit is contained in:
@@ -1000,6 +1000,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
||||
self.merger = PatchMerger(
|
||||
dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
|
||||
)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def get_dtype(self) -> torch.dtype:
|
||||
return self.blocks[0].mlp.fc2.weight.dtype
|
||||
@@ -1046,6 +1047,11 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
|
||||
for blk in self.blocks:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb
|
||||
)
|
||||
else:
|
||||
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
|
||||
|
||||
return self.merger(hidden_states)
|
||||
|
||||
@@ -285,24 +285,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
|
||||
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user