Support passing flash_attn_kwargs when gradient_checkpointing is enabled (#37037)
* support passing flash_attn_kwargs when gradient_checkpointing is enabled * make modeling_deepspeek_v3.py consistent with modular_deepseek_v3.py
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_dummy.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -544,7 +545,7 @@ class DummyModel(DummyPreTrainedModel):
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
|
||||
Reference in New Issue
Block a user