Support passing flash_attn_kwargs when gradient_checkpointing is enabled (#37037)

* support passing flash_attn_kwargs when gradient_checkpointing is enabled

* make modeling_deepseek_v3.py consistent with modular_deepseek_v3.py
This commit is contained in:
efsotr
2025-03-31 16:53:02 +08:00
committed by GitHub
parent bd41b9c1ac
commit 2b4734bd49
29 changed files with 58 additions and 30 deletions

View File

@@ -19,6 +19,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
from ...activations import ACT2FN
@@ -963,7 +964,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
-decoder_layer.__call__,
+partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,