Support passing flash_attn_kwargs when gradient_checkpointing is enabled (#37037)
* support passing flash_attn_kwargs when gradient_checkpointing is enabled * make modeling_deepspeek_v3.py consistent with modular_deepseek_v3.py
This commit is contained in:
@@ -19,6 +19,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
from ...activations import ACT2FN
|
||||
@@ -963,7 +964,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
|
||||
Reference in New Issue
Block a user