[modular] Follow global indexing and attribute setting, and their dependencies (#39180)

* export global indexing statements

* add example

* style

* examples
This commit is contained in:
Cyril Vallez
2025-07-07 14:36:43 +02:00
committed by GitHub
parent 8570bc29f3
commit 5348fbc005
8 changed files with 254 additions and 85 deletions

View File

@@ -11,9 +11,9 @@ import torch
from torch import nn
from ...cache_utils import Cache
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
from .configuration_switch_function import SwitchFunctionConfig
@@ -72,7 +72,7 @@ def eager_attention_forward(
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
-**kwargs,
+**kwargs: Unpack[TransformersKwargs],
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
@@ -123,8 +123,8 @@ class SwitchFunctionAttention(nn.Module):
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
-**kwargs: Unpack[FlashAttentionKwargs],
-) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+**kwargs: Unpack[TransformersKwargs],
+) -> tuple[torch.Tensor, torch.Tensor]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)