Add padding-free to Granite hybrid moe models (#39677)

* start fixing kwarg handling

* fmt

* updates padding free tests

* docs

* add missing kwargs modeling_granitemoe.py

* run modular util

* rm unrelated changes from modular util
This commit is contained in:
Garrett Goon
2025-07-25 14:10:50 -04:00
committed by GitHub
parent d6e9f71a6e
commit 97f8c71f52
7 changed files with 146 additions and 16 deletions

View File

@@ -48,6 +48,32 @@ for i in output:
This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co/SukritiSharma) and [Alexander Brooks](https://huggingface.co/abrooks9944).
## Notes
- `GraniteMoeHybridForCausalLM` supports padding-free training which concatenates distinct training examples while still processing inputs as separate batches. It can significantly accelerate inference by [~2x](https://github.com/huggingface/transformers/pull/35861#issue-2807873129) (depending on model and data distribution) and reduce memory-usage if there are examples of varying lengths by avoiding unnecessary compute and memory overhead from padding tokens.
Padding-free training requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d` packages and the following arguments must be passed to the model in addition to `input_ids` and `labels`.
- `position_ids: torch.LongTensor`: the position index of each token in each sequence.
- `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
- Each of the [`FlashAttentionKwargs`]
- `cu_seq_lens_q: torch.LongTensor`: the cumulative sequence lengths of all queries.
- `cu_seq_lens_k: torch.LongTensor`: the cumulative sequence lengths of all keys.
- `max_length_q: int`: the longest query length in the batch.
- `max_length_k: int`: the longest key length in the batch.
The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] programmatically generates the set of additional arguments above using `return_seq_idx=True` and `return_flash_attn_kwargs=True`. See the [Improving Hugging Face Training Efficiency Through Packing with Flash Attention](https://huggingface.co/blog/packing-with-FA2) blog post for additional information.
```python
from transformers import DataCollatorWithFlattening
# Example of using padding-free training
data_collator = DataCollatorWithFlattening(
tokenizer=tokenizer,
return_seq_idx=True,
return_flash_attn_kwargs=True
)
```
## GraniteMoeHybridConfig
@@ -61,4 +87,4 @@ This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co
## GraniteMoeHybridForCausalLM
[[autodoc]] GraniteMoeHybridForCausalLM
- forward
- forward

View File

@@ -641,6 +641,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -947,6 +948,7 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
output_router_logits=output_router_logits,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
# Only compute necessary logits

View File

@@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, TypedDict, Union
import torch
import torch.nn.functional as F
@@ -34,6 +34,7 @@ from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from .configuration_granitemoehybrid import GraniteMoeHybridConfig
@@ -918,6 +919,31 @@ class GraniteMoeHybridMLP(nn.Module):
return hidden_states
class GraniteFlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
Use cases include padding-free training and fewer `torch.compile` graph breaks.
Attributes:
cu_seq_lens_q (`torch.LongTensor`)
Gets cumulative sequence length for query state.
cu_seq_lens_k (`torch.LongTensor`)
Gets cumulative sequence length for key state.
max_length_q (`int`):
Maximum sequence length for query state.
max_length_k (`int`):
Maximum sequence length for key state.
seq_idx (`torch.IntTensor):
Index of each packed sequence.
"""
cu_seq_lens_q: torch.LongTensor
cu_seq_lens_k: torch.LongTensor
max_length_q: int
max_length_k: int
seq_idx: torch.IntTensor
class GraniteMoeHybridRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
@@ -1125,7 +1151,7 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
**kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -1149,8 +1175,8 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
Arbitrary kwargs.Can be used to provide `GraniteFlashAttentionKwargs` for
padding-free training and/or improve torch.compile performance.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -1161,6 +1187,7 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
cache_position=cache_position,
cache_params=past_key_value,
attention_mask=attention_mask,
**kwargs,
)
# No attention weights for state space layers
self_attn_weights = None
@@ -1303,6 +1330,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -1374,6 +1402,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
cache_position=cache_position,
output_router_logits=output_router_logits,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]
@@ -1706,6 +1735,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMix
output_router_logits=output_router_logits,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
# Only compute necessary logits

View File

@@ -20,10 +20,12 @@ from torch import nn
from ...cache_utils import Cache
from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, logging
from ..bamba.configuration_bamba import BambaConfig
from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache
from ..granitemoeshared.modeling_granitemoeshared import (
GraniteFlashAttentionKwargs,
GraniteMoeSharedAttention,
GraniteMoeSharedDecoderLayer,
GraniteMoeSharedForCausalLM,
@@ -84,7 +86,7 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
**kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -108,8 +110,8 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
Arbitrary kwargs.Can be used to provide `GraniteFlashAttentionKwargs` for
padding-free training and/or improve torch.compile performance.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -120,6 +122,7 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
cache_position=cache_position,
cache_params=past_key_value,
attention_mask=attention_mask,
**kwargs,
)
# No attention weights for state space layers
self_attn_weights = None
@@ -198,6 +201,7 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -269,6 +273,7 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):
cache_position=cache_position,
output_router_logits=output_router_logits,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]

View File

@@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union
from typing import Callable, Optional, TypedDict, Union
import torch
import torch.nn.functional as F
@@ -33,6 +33,7 @@ from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from .configuration_granitemoeshared import GraniteMoeSharedConfig
@@ -46,6 +47,31 @@ if is_torch_flex_attn_available():
logger = logging.get_logger(__name__)
class GraniteFlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
Use cases include padding-free training and fewer `torch.compile` graph breaks.
Attributes:
cu_seq_lens_q (`torch.LongTensor`)
Gets cumulative sequence length for query state.
cu_seq_lens_k (`torch.LongTensor`)
Gets cumulative sequence length for key state.
max_length_q (`int`):
Maximum sequence length for query state.
max_length_k (`int`):
Maximum sequence length for key state.
seq_idx (`torch.IntTensor):
Index of each packed sequence.
"""
cu_seq_lens_q: torch.LongTensor
cu_seq_lens_k: torch.LongTensor
max_length_q: int
max_length_k: int
seq_idx: torch.IntTensor
class GraniteMoeSharedMLP(nn.Module):
"""
MLP layer for shared experts
@@ -431,7 +457,7 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
**kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -455,8 +481,8 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
padding-free training and/or improve torch.compile performance.
"""
residual = hidden_states
@@ -593,6 +619,7 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -979,6 +1006,7 @@ class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMix
output_router_logits=output_router_logits,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
# Only compute necessary logits

View File

@@ -13,13 +13,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, TypedDict
import torch
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...processing_utils import Unpack
from ...utils import logging
from ..granitemoe.modeling_granitemoe import (
GraniteMoeDecoderLayer,
@@ -33,6 +34,31 @@ from .configuration_granitemoeshared import GraniteMoeSharedConfig
logger = logging.get_logger(__name__)
class GraniteFlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
Use cases include padding-free training and fewer `torch.compile` graph breaks.
Attributes:
cu_seq_lens_q (`torch.LongTensor`)
Gets cumulative sequence length for query state.
cu_seq_lens_k (`torch.LongTensor`)
Gets cumulative sequence length for key state.
max_length_q (`int`):
Maximum sequence length for query state.
max_length_k (`int`):
Maximum sequence length for key state.
seq_idx (`torch.IntTensor):
Index of each packed sequence.
"""
cu_seq_lens_q: torch.LongTensor
cu_seq_lens_k: torch.LongTensor
max_length_q: int
max_length_k: int
seq_idx: torch.IntTensor
class GraniteMoeSharedMLP(nn.Module):
"""
MLP layer for shared experts
@@ -75,7 +101,7 @@ class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer):
cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
**kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -99,8 +125,8 @@ class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
padding-free training and/or improve torch.compile performance.
"""
residual = hidden_states

View File

@@ -551,6 +551,15 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
# Ensure inputs_dict also has labels in it, as their presence/absence can induce
# dtype conversions. This also lets us compare losses.
labels = inputs_dict["input_ids"].clone()
# Mask padding tokens
labels[~dummy_attention_mask.bool()] = -100
# Also need to mask the first non-trivial token to match the padding-free batch.
first_nonneg_idx = (labels >= 0).int().argmax(dim=1)
labels[torch.arange(labels.size(0), device=labels.device), first_nonneg_idx] = -100
inputs_dict["labels"] = labels
model = (
model_class.from_pretrained(
@@ -586,6 +595,10 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
loss_padded = res_padded.loss
loss_padfree = res_padfree.loss
torch.testing.assert_close(loss_padded, loss_padfree)
@slow
@require_torch