Support passing flash_attn_kwargs when gradient_checkpointing is enabled (#37037)

* support passing flash_attn_kwargs when gradient_checkpointing is enabled

* make modeling_deepspeek_v3.py consistent with modular_deepseek_v3.py
This commit is contained in:
efsotr
2025-03-31 16:53:02 +08:00
committed by GitHub
parent bd41b9c1ac
commit 2b4734bd49
29 changed files with 58 additions and 30 deletions

View File

@@ -4,6 +4,7 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_dummy.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
@@ -544,7 +545,7 @@ class DummyModel(DummyPreTrainedModel):
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,