Support gradient checkpointing in Qwen2VL ViT (#34724)

* Support gradient checkpointing in Qwen2VL ViT

* Enable gradient checkpoint tests for Qwen2VL

* [run-slow] qwen2_vl
This commit is contained in:
Jiahao Li
2024-11-19 19:30:44 +08:00
committed by GitHub
parent 1a0cd69435
commit 0db91c3c8d
2 changed files with 7 additions and 19 deletions

View File

@@ -1000,6 +1000,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
self.merger = PatchMerger( self.merger = PatchMerger(
dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
) )
self.gradient_checkpointing = False
def get_dtype(self) -> torch.dtype: def get_dtype(self) -> torch.dtype:
return self.blocks[0].mlp.fc2.weight.dtype return self.blocks[0].mlp.fc2.weight.dtype
@@ -1046,7 +1047,12 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
for blk in self.blocks: for blk in self.blocks:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb
)
else:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
return self.merger(hidden_states) return self.merger(hidden_states)

View File

@@ -285,24 +285,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0) image_grid_thw = torch.cat([image_grid_thw, image_grid_thw], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw) _ = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw)
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Feedforward chunking is not yet supported") @unittest.skip(reason="Feedforward chunking is not yet supported")
def test_feed_forward_chunking(self): def test_feed_forward_chunking(self):
pass pass