Introduce GradientCheckpointingLayer (#37223)
* GradientCheckpointingLayer * trigger * Move GC layer to a separate file * Update import * Expose and document GC layer * Fix dummy * Apply to llama-based models * Update modulars * Update a few more models for consistency * Update glm4 * Update Janus
This commit is contained in:
committed by
GitHub
parent
413f9bbf80
commit
9167fadab9
@@ -19,7 +19,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
from ...activations import ACT2FN
|
||||
@@ -28,6 +27,7 @@ from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
@@ -590,7 +590,7 @@ class AriaTextAttention(nn.Module):
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class AriaTextDecoderLayer(nn.Module):
|
||||
class AriaTextDecoderLayer(GradientCheckpointingLayer):
|
||||
"""
|
||||
Aria Text Decoder Layer.
|
||||
|
||||
@@ -940,30 +940,17 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
partial(decoder_layer.__call__, **flash_attn_kwargs),
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user