Introduce GradientCheckpointingLayer (#37223)
* GradientCheckpointingLayer * trigger * Move GC layer to a separate file * Update import * Expose and document GC layer * Fix dummy * Apply to llama-based models * Update modulars * Update a few more models for consistency * Update glm4 * Update Janus
This commit is contained in:
committed by
GitHub
parent
413f9bbf80
commit
9167fadab9
@@ -20,6 +20,10 @@ This page lists all the custom layers used by the library, as well as the utilit
|
|||||||
|
|
||||||
Most of those are only useful if you are studying the code of the models in the library.
|
Most of those are only useful if you are studying the code of the models in the library.
|
||||||
|
|
||||||
|
## Layers
|
||||||
|
|
||||||
|
[[autodoc]] GradientCheckpointingLayer
|
||||||
|
|
||||||
## Attention Functions
|
## Attention Functions
|
||||||
|
|
||||||
[[autodoc]] AttentionInterface
|
[[autodoc]] AttentionInterface
|
||||||
|
|||||||
@@ -438,6 +438,7 @@ else:
|
|||||||
]
|
]
|
||||||
|
|
||||||
_import_structure["modeling_flash_attention_utils"] = []
|
_import_structure["modeling_flash_attention_utils"] = []
|
||||||
|
_import_structure["modeling_layers"] = ["GradientCheckpointingLayer"]
|
||||||
_import_structure["modeling_outputs"] = []
|
_import_structure["modeling_outputs"] = []
|
||||||
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"]
|
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"]
|
||||||
_import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
|
_import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
|
||||||
@@ -911,6 +912,7 @@ if TYPE_CHECKING:
|
|||||||
from .model_debugging_utils import (
|
from .model_debugging_utils import (
|
||||||
model_addition_debugger_context,
|
model_addition_debugger_context,
|
||||||
)
|
)
|
||||||
|
from .modeling_layers import GradientCheckpointingLayer
|
||||||
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||||
from .modeling_utils import AttentionInterface, PreTrainedModel
|
from .modeling_utils import AttentionInterface, PreTrainedModel
|
||||||
|
|
||||||
|
|||||||
48
src/transformers/modeling_layers.py
Normal file
48
src/transformers/modeling_layers.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class GradientCheckpointingLayer(nn.Module):
|
||||||
|
"""Base class for layers with gradient checkpointing.
|
||||||
|
|
||||||
|
This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled
|
||||||
|
(`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is
|
||||||
|
enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`.
|
||||||
|
|
||||||
|
Important:
|
||||||
|
|
||||||
|
When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states)
|
||||||
|
must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> # Correct - hidden_states passed as positional arg
|
||||||
|
>>> out = self.layer(hidden_states, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
>>> # Incorrect - hidden_states passed as keyword arg
|
||||||
|
>>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
gradient_checkpointing = False
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
|
||||||
|
return super().__call__(*args, **kwargs)
|
||||||
@@ -19,7 +19,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
@@ -28,6 +27,7 @@ from ...generation import GenerationMixin
|
|||||||
from ...integrations import use_kernel_forward_from_hub
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
@@ -590,7 +590,7 @@ class AriaTextAttention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class AriaTextDecoderLayer(nn.Module):
|
class AriaTextDecoderLayer(GradientCheckpointingLayer):
|
||||||
"""
|
"""
|
||||||
Aria Text Decoder Layer.
|
Aria Text Decoder Layer.
|
||||||
|
|
||||||
@@ -940,30 +940,17 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,6 @@
|
|||||||
# This file is based on the LLama model definition file in transformers
|
# This file is based on the LLama model definition file in transformers
|
||||||
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -38,6 +37,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache
|
|||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
@@ -301,7 +301,7 @@ class CohereAttention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class CohereDecoderLayer(nn.Module):
|
class CohereDecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: CohereConfig, layer_idx: int):
|
def __init__(self, config: CohereConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -589,30 +589,17 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from torch import nn
|
|||||||
|
|
||||||
from ...cache_utils import Cache
|
from ...cache_utils import Cache
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
from ...modeling_rope_utils import dynamic_rope_update
|
from ...modeling_rope_utils import dynamic_rope_update
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
@@ -209,7 +210,7 @@ class CohereAttention(LlamaAttention):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class CohereDecoderLayer(nn.Module):
|
class CohereDecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: CohereConfig, layer_idx: int):
|
def __init__(self, config: CohereConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|||||||
@@ -19,7 +19,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -29,6 +28,7 @@ from ...activations import ACT2FN
|
|||||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
@@ -290,7 +290,7 @@ class Cohere2MLP(nn.Module):
|
|||||||
return down_proj
|
return down_proj
|
||||||
|
|
||||||
|
|
||||||
class Cohere2DecoderLayer(nn.Module):
|
class Cohere2DecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: Cohere2Config, layer_idx: int):
|
def __init__(self, config: Cohere2Config, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -612,28 +612,16 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
position_embeddings=position_embeddings,
|
||||||
hidden_states,
|
attention_mask=causal_mask,
|
||||||
position_embeddings,
|
past_key_value=past_key_values,
|
||||||
causal_mask,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
**flash_attn_kwargs,
|
||||||
cache_position,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, Optional, Tuple
|
from typing import Callable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -526,28 +525,16 @@ class Cohere2Model(Gemma2Model):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
position_embeddings=position_embeddings,
|
||||||
hidden_states,
|
attention_mask=causal_mask,
|
||||||
position_embeddings,
|
past_key_value=past_key_values,
|
||||||
causal_mask,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
**flash_attn_kwargs,
|
||||||
cache_position,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
# modular_deepseek_v3.py file directly. One of our CI enforces this.
|
# modular_deepseek_v3.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -18,6 +17,7 @@ from ...generation import GenerationMixin
|
|||||||
from ...integrations import use_kernel_forward_from_hub
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
@@ -454,7 +454,7 @@ class DeepseekV3Attention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3DecoderLayer(nn.Module):
|
class DeepseekV3DecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: DeepseekV3Config, layer_idx: int):
|
def __init__(self, config: DeepseekV3Config, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -734,30 +734,17 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -38,6 +37,7 @@ from ...modeling_flash_attention_utils import (
|
|||||||
_flash_attention_forward,
|
_flash_attention_forward,
|
||||||
flash_attn_supports_top_left_mask,
|
flash_attn_supports_top_left_mask,
|
||||||
)
|
)
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@@ -526,7 +526,7 @@ DIFFLLAMA_ATTENTION_CLASSES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class DiffLlamaDecoderLayer(nn.Module):
|
class DiffLlamaDecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: DiffLlamaConfig, layer_idx: int):
|
def __init__(self, config: DiffLlamaConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -837,30 +837,17 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from functools import cached_property, partial
|
from functools import cached_property
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -34,6 +34,7 @@ from ...generation import GenerationMixin
|
|||||||
from ...integrations import use_kernel_forward_from_hub
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
@@ -248,7 +249,7 @@ class Emu3Attention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class Emu3DecoderLayer(nn.Module):
|
class Emu3DecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: Emu3Config, layer_idx: int):
|
def __init__(self, config: Emu3Config, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -1422,30 +1423,17 @@ class Emu3TextModel(Emu3PreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache
|
|||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@@ -282,7 +283,7 @@ class GemmaAttention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class GemmaDecoderLayer(nn.Module):
|
class GemmaDecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: GemmaConfig, layer_idx: int):
|
def __init__(self, config: GemmaConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -557,29 +558,16 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
decoder_layer.__call__,
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
)
|
||||||
position_embeddings,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -434,29 +434,16 @@ class GemmaModel(LlamaModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
decoder_layer.__call__,
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
)
|
||||||
position_embeddings,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -31,6 +30,7 @@ from ...generation import GenerationMixin
|
|||||||
from ...integrations import use_kernel_forward_from_hub
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@@ -300,7 +300,7 @@ class GlmRotaryEmbedding(nn.Module):
|
|||||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
class GlmDecoderLayer(nn.Module):
|
class GlmDecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: GlmConfig, layer_idx: int):
|
def __init__(self, config: GlmConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -572,30 +572,17 @@ class GlmModel(GlmPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -31,6 +30,7 @@ from ...generation import GenerationMixin
|
|||||||
from ...integrations import use_kernel_forward_from_hub
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@@ -83,7 +83,7 @@ class Glm4MLP(nn.Module):
|
|||||||
return self.down_proj(up_states)
|
return self.down_proj(up_states)
|
||||||
|
|
||||||
|
|
||||||
class Glm4DecoderLayer(nn.Module):
|
class Glm4DecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: Glm4Config, layer_idx: int):
|
def __init__(self, config: Glm4Config, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -580,30 +580,17 @@ class Glm4Model(Glm4PreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -15,11 +15,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
|
|
||||||
from ...cache_utils import Cache
|
from ...cache_utils import Cache
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import CausalLMOutputWithPast
|
from ...modeling_outputs import CausalLMOutputWithPast
|
||||||
from ...processing_utils import Unpack
|
from ...processing_utils import Unpack
|
||||||
from ...utils import LossKwargs, logging
|
from ...utils import LossKwargs, logging
|
||||||
@@ -43,7 +43,7 @@ class Glm4MLP(Phi3MLP):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Glm4DecoderLayer(nn.Module):
|
class Glm4DecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: Glm4Config, layer_idx: int):
|
def __init__(self, config: Glm4Config, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|||||||
@@ -19,7 +19,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -31,6 +30,7 @@ from ...generation import GenerationMixin
|
|||||||
from ...integrations import use_kernel_forward_from_hub
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
@@ -243,7 +243,7 @@ class GraniteMLP(nn.Module):
|
|||||||
return down_proj
|
return down_proj
|
||||||
|
|
||||||
|
|
||||||
class GraniteDecoderLayer(nn.Module):
|
class GraniteDecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: GraniteConfig, layer_idx: int):
|
def __init__(self, config: GraniteConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -572,30 +572,17 @@ class GraniteModel(GranitePreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from functools import partial
|
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -192,30 +191,17 @@ class GraniteModel(LlamaModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -31,6 +30,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache
|
|||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@@ -285,7 +285,7 @@ class HeliumAttention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class HeliumDecoderLayer(nn.Module):
|
class HeliumDecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
|
def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -557,30 +557,17 @@ class HeliumModel(HeliumPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from ...cache_utils import Cache
|
|||||||
from ...generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMixin, GenerationMode, LogitsProcessorList
|
from ...generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMixin, GenerationMode, LogitsProcessorList
|
||||||
from ...generation.utils import GenerateDecoderOnlyOutput
|
from ...generation.utils import GenerateDecoderOnlyOutput
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
|
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
from ...processing_utils import Unpack
|
from ...processing_utils import Unpack
|
||||||
@@ -429,7 +430,7 @@ class JanusVisionMLP(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class JanusVisionEncoderLayer(nn.Module):
|
class JanusVisionEncoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: JanusVisionConfig):
|
def __init__(self, config: JanusVisionConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
@@ -536,19 +537,12 @@ class JanusVisionEncoder(nn.Module):
|
|||||||
for encoder_layer in self.layers:
|
for encoder_layer in self.layers:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = encoder_layer(
|
||||||
encoder_layer.__call__,
|
hidden_states,
|
||||||
hidden_states,
|
attention_mask,
|
||||||
attention_mask,
|
output_attentions=output_attentions,
|
||||||
output_attentions,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = encoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -29,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache
|
|||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@@ -290,7 +290,7 @@ class LlamaAttention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class LlamaDecoderLayer(nn.Module):
|
class LlamaDecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: LlamaConfig, layer_idx: int):
|
def __init__(self, config: LlamaConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -562,30 +562,17 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_mistral.py file directly. One of our CI enforces this.
|
# modular_mistral.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -16,6 +15,7 @@ from ...generation import GenerationMixin
|
|||||||
from ...integrations import use_kernel_forward_from_hub
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@@ -228,7 +228,7 @@ class MistralRMSNorm(nn.Module):
|
|||||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||||
|
|
||||||
|
|
||||||
class MistralDecoderLayer(nn.Module):
|
class MistralDecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: MistralConfig, layer_idx: int):
|
def __init__(self, config: MistralConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -532,30 +532,17 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -34,6 +33,7 @@ from ...modeling_attn_mask_utils import (
|
|||||||
_prepare_4d_attention_mask_for_sdpa,
|
_prepare_4d_attention_mask_for_sdpa,
|
||||||
)
|
)
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
@@ -351,7 +351,7 @@ class MoonshineRotaryEmbedding(nn.Module):
|
|||||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
class MoonshineEncoderLayer(nn.Module):
|
class MoonshineEncoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: MoonshineConfig, layer_idx: int):
|
def __init__(self, config: MoonshineConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -410,7 +410,7 @@ class MoonshineEncoderLayer(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class MoonshineDecoderLayer(nn.Module):
|
class MoonshineDecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None):
|
def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -668,27 +668,14 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = encoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
encoder_layer.__call__,
|
attention_mask=attention_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
attention_mask,
|
output_attentions=output_attentions,
|
||||||
position_ids,
|
position_embeddings=position_embeddings,
|
||||||
None,
|
**flash_attn_kwargs,
|
||||||
output_attentions,
|
)
|
||||||
False,
|
|
||||||
None,
|
|
||||||
position_embeddings,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = encoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
@@ -912,33 +899,19 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
causal_mask,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_hidden_states,
|
position_ids=position_ids,
|
||||||
position_ids,
|
past_key_value=past_key_values,
|
||||||
past_key_values,
|
output_attentions=output_attentions,
|
||||||
output_attentions,
|
use_cache=use_cache,
|
||||||
use_cache,
|
cache_position=cache_position,
|
||||||
cache_position,
|
position_embeddings=position_embeddings,
|
||||||
position_embeddings,
|
**flash_attn_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -27,6 +26,7 @@ from ...modeling_attn_mask_utils import (
|
|||||||
_prepare_4d_attention_mask_for_sdpa,
|
_prepare_4d_attention_mask_for_sdpa,
|
||||||
)
|
)
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
@@ -428,7 +428,7 @@ class MoonshineEncoderLayer(LlamaDecoderLayer):
|
|||||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
|
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
|
||||||
|
|
||||||
|
|
||||||
class MoonshineDecoderLayer(nn.Module):
|
class MoonshineDecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None):
|
def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -686,27 +686,14 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = encoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
encoder_layer.__call__,
|
attention_mask=attention_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
attention_mask,
|
output_attentions=output_attentions,
|
||||||
position_ids,
|
position_embeddings=position_embeddings,
|
||||||
None,
|
**flash_attn_kwargs,
|
||||||
output_attentions,
|
)
|
||||||
False,
|
|
||||||
None,
|
|
||||||
position_embeddings,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = encoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
@@ -832,33 +819,19 @@ class MoonshineDecoder(LlamaModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
causal_mask,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_hidden_states,
|
position_ids=position_ids,
|
||||||
position_ids,
|
past_key_value=past_key_values,
|
||||||
past_key_values,
|
output_attentions=output_attentions,
|
||||||
output_attentions,
|
use_cache=use_cache,
|
||||||
use_cache,
|
cache_position=cache_position,
|
||||||
cache_position,
|
position_embeddings=position_embeddings,
|
||||||
position_embeddings,
|
**flash_attn_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_olmo.py file directly. One of our CI enforces this.
|
# modular_olmo.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -16,6 +15,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache
|
|||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
@@ -229,7 +229,7 @@ class OlmoAttention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class OlmoDecoderLayer(nn.Module):
|
class OlmoDecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: OlmoConfig, layer_idx: int):
|
def __init__(self, config: OlmoConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -532,30 +532,17 @@ class OlmoModel(OlmoPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_olmo2.py file directly. One of our CI enforces this.
|
# modular_olmo2.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -16,6 +15,7 @@ from ...generation import GenerationMixin
|
|||||||
from ...integrations import use_kernel_forward_from_hub
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
@@ -233,7 +233,7 @@ class Olmo2MLP(nn.Module):
|
|||||||
return down_proj
|
return down_proj
|
||||||
|
|
||||||
|
|
||||||
class Olmo2DecoderLayer(nn.Module):
|
class Olmo2DecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: Olmo2Config, layer_idx: int):
|
def __init__(self, config: Olmo2Config, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -538,30 +538,17 @@ class Olmo2Model(Olmo2PreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -32,6 +31,7 @@ from ...generation import GenerationMixin
|
|||||||
from ...integrations import use_kernel_forward_from_hub
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@@ -257,7 +257,7 @@ class Phi3RMSNorm(nn.Module):
|
|||||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||||
|
|
||||||
|
|
||||||
class Phi3DecoderLayer(nn.Module):
|
class Phi3DecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: Phi3Config, layer_idx: int):
|
def __init__(self, config: Phi3Config, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -587,30 +587,17 @@ class Phi3Model(Phi3PreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from ...generation import GenerationMixin
|
|||||||
from ...integrations import use_kernel_forward_from_hub
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
@@ -154,7 +155,7 @@ class Phi4MultimodalVisionAttention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class Phi4MultimodalVisionEncoderLayer(nn.Module):
|
class Phi4MultimodalVisionEncoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: Phi4MultimodalVisionConfig):
|
def __init__(self, config: Phi4MultimodalVisionConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
@@ -262,19 +263,12 @@ class Phi4MultimodalVisionEncoder(nn.Module):
|
|||||||
for encoder_layer in self.layers:
|
for encoder_layer in self.layers:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = encoder_layer(
|
||||||
encoder_layer.__call__,
|
hidden_states,
|
||||||
hidden_states,
|
attention_mask,
|
||||||
attention_mask,
|
output_attentions=output_attentions,
|
||||||
output_attentions,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = encoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
@@ -1213,14 +1207,7 @@ class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel):
|
|||||||
attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias
|
attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias
|
||||||
|
|
||||||
for layer in self.encoders:
|
for layer in self.encoders:
|
||||||
if self.gradient_checkpointing and self.training:
|
hidden_states = layer(hidden_states, attention_mask)
|
||||||
hidden_states = self._gradient_checkpointing_func(
|
|
||||||
layer.__call__,
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states = layer(hidden_states, attention_mask)
|
|
||||||
|
|
||||||
if unfolded:
|
if unfolded:
|
||||||
embed_dim = hidden_states.shape[-1]
|
embed_dim = hidden_states.shape[-1]
|
||||||
@@ -1483,7 +1470,7 @@ class Phi4MultimodalAttention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class Phi4MultimodalDecoderLayer(nn.Module):
|
class Phi4MultimodalDecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: Phi4MultimodalConfig, layer_idx: int):
|
def __init__(self, config: Phi4MultimodalConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -1885,30 +1872,17 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
decoder_layer.__call__,
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -1265,14 +1265,7 @@ class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel):
|
|||||||
attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias
|
attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias
|
||||||
|
|
||||||
for layer in self.encoders:
|
for layer in self.encoders:
|
||||||
if self.gradient_checkpointing and self.training:
|
hidden_states = layer(hidden_states, attention_mask)
|
||||||
hidden_states = self._gradient_checkpointing_func(
|
|
||||||
layer.__call__,
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states = layer(hidden_states, attention_mask)
|
|
||||||
|
|
||||||
if unfolded:
|
if unfolded:
|
||||||
embed_dim = hidden_states.shape[-1]
|
embed_dim = hidden_states.shape[-1]
|
||||||
@@ -1655,30 +1648,17 @@ class Phi4MultimodalModel(Phi3Model, nn.Module):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
decoder_layer.__call__,
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_qwen2.py file directly. One of our CI enforces this.
|
# modular_qwen2.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -16,6 +15,7 @@ from ...generation import GenerationMixin
|
|||||||
from ...integrations import use_kernel_forward_from_hub
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@@ -236,7 +236,7 @@ class Qwen2RMSNorm(nn.Module):
|
|||||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||||
|
|
||||||
|
|
||||||
class Qwen2DecoderLayer(nn.Module):
|
class Qwen2DecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: Qwen2Config, layer_idx: int):
|
def __init__(self, config: Qwen2Config, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -545,30 +545,17 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -31,6 +30,7 @@ from ...generation import GenerationMixin
|
|||||||
from ...integrations import use_kernel_forward_from_hub
|
from ...integrations import use_kernel_forward_from_hub
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@@ -261,7 +261,7 @@ class Qwen3Attention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class Qwen3DecoderLayer(nn.Module):
|
class Qwen3DecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: Qwen3Config, layer_idx: int):
|
def __init__(self, config: Qwen3Config, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -572,30 +572,17 @@ class Qwen3Model(Qwen3PreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
layer_outputs = decoder_layer(
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
hidden_states,
|
||||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
attention_mask=causal_mask,
|
||||||
hidden_states,
|
position_ids=position_ids,
|
||||||
causal_mask,
|
past_key_value=past_key_values,
|
||||||
position_ids,
|
output_attentions=output_attentions,
|
||||||
past_key_values,
|
use_cache=use_cache,
|
||||||
output_attentions,
|
cache_position=cache_position,
|
||||||
use_cache,
|
position_embeddings=position_embeddings,
|
||||||
cache_position,
|
**flash_attn_kwargs,
|
||||||
position_embeddings,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
position_embeddings=position_embeddings,
|
|
||||||
**flash_attn_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from torch.nn.init import _calculate_fan_in_and_fan_out
|
|||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
@@ -465,7 +466,7 @@ class SiglipMLP(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class SiglipEncoderLayer(nn.Module):
|
class SiglipEncoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]):
|
def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
@@ -743,19 +744,12 @@ class SiglipEncoder(nn.Module):
|
|||||||
for encoder_layer in self.layers:
|
for encoder_layer in self.layers:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = encoder_layer(
|
||||||
encoder_layer.__call__,
|
hidden_states,
|
||||||
hidden_states,
|
attention_mask,
|
||||||
attention_mask,
|
output_attentions=output_attentions,
|
||||||
output_attentions,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = encoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from torch.nn.init import _calculate_fan_in_and_fan_out
|
|||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
@@ -356,7 +357,7 @@ class Siglip2MLP(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class Siglip2EncoderLayer(nn.Module):
|
class Siglip2EncoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]):
|
def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
@@ -462,19 +463,12 @@ class Siglip2Encoder(nn.Module):
|
|||||||
for encoder_layer in self.layers:
|
for encoder_layer in self.layers:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = encoder_layer(
|
||||||
encoder_layer.__call__,
|
hidden_states,
|
||||||
hidden_states,
|
attention_mask,
|
||||||
attention_mask,
|
output_attentions=output_attentions,
|
||||||
output_attentions,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = encoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
|||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@@ -230,7 +231,7 @@ class Starcoder2Attention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class Starcoder2DecoderLayer(nn.Module):
|
class Starcoder2DecoderLayer(GradientCheckpointingLayer):
|
||||||
def __init__(self, config: Starcoder2Config, layer_idx: int):
|
def __init__(self, config: Starcoder2Config, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|||||||
@@ -542,6 +542,13 @@ def model_addition_debugger_context(*args, **kwargs):
|
|||||||
requires_backends(model_addition_debugger_context, ["torch"])
|
requires_backends(model_addition_debugger_context, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class GradientCheckpointingLayer(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
ROPE_INIT_FUNCTIONS = None
|
ROPE_INIT_FUNCTIONS = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user