Add padding-free to Granite hybrid moe models (#39677)
* start fixing kwarg handling * fmt * updates padding free tests * docs * add missing kwargs modeling_granitemoe.py * run modular util * rm unrelated changes from modular util
This commit is contained in:
@@ -48,6 +48,32 @@ for i in output:
|
||||
|
||||
This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co/SukritiSharma) and [Alexander Brooks](https://huggingface.co/abrooks9944).
|
||||
|
||||
## Notes
|
||||
|
||||
- `GraniteMoeHybridForCausalLM` supports padding-free training which concatenates distinct training examples while still processing inputs as separate batches. It can significantly accelerate inference by [~2x](https://github.com/huggingface/transformers/pull/35861#issue-2807873129) (depending on model and data distribution) and reduce memory-usage if there are examples of varying lengths by avoiding unnecessary compute and memory overhead from padding tokens.
|
||||
|
||||
Padding-free training requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d` packages and the following arguments must be passed to the model in addition to `input_ids` and `labels`.
|
||||
|
||||
- `position_ids: torch.LongTensor`: the position index of each token in each sequence.
|
||||
- `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
|
||||
- Each of the [`FlashAttentionKwargs`]
|
||||
- `cu_seq_lens_q: torch.LongTensor`: the cumulative sequence lengths of all queries.
|
||||
- `cu_seq_lens_k: torch.LongTensor`: the cumulative sequence lengths of all keys.
|
||||
- `max_length_q: int`: the longest query length in the batch.
|
||||
- `max_length_k: int`: the longest key length in the batch.
|
||||
|
||||
The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] programmatically generates the set of additional arguments above using `return_seq_idx=True` and `return_flash_attn_kwargs=True`. See the [Improving Hugging Face Training Efficiency Through Packing with Flash Attention](https://huggingface.co/blog/packing-with-FA2) blog post for additional information.
|
||||
|
||||
```python
|
||||
from transformers import DataCollatorWithFlattening
|
||||
|
||||
# Example of using padding-free training
|
||||
data_collator = DataCollatorWithFlattening(
|
||||
tokenizer=tokenizer,
|
||||
return_seq_idx=True,
|
||||
return_flash_attn_kwargs=True
|
||||
)
|
||||
```
|
||||
|
||||
## GraniteMoeHybridConfig
|
||||
|
||||
@@ -61,4 +87,4 @@ This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co
|
||||
## GraniteMoeHybridForCausalLM
|
||||
|
||||
[[autodoc]] GraniteMoeHybridForCausalLM
|
||||
- forward
|
||||
- forward
|
||||
|
||||
@@ -641,6 +641,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -947,6 +948,7 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
|
||||
output_router_logits=output_router_logits,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Only compute necessary logits
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Callable, Optional, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -34,6 +34,7 @@ from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
|
||||
from .configuration_granitemoehybrid import GraniteMoeHybridConfig
|
||||
@@ -918,6 +919,31 @@ class GraniteMoeHybridMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GraniteFlashAttentionKwargs(TypedDict, total=False):
|
||||
"""
|
||||
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
|
||||
Use cases include padding-free training and fewer `torch.compile` graph breaks.
|
||||
|
||||
Attributes:
|
||||
cu_seq_lens_q (`torch.LongTensor`)
|
||||
Gets cumulative sequence length for query state.
|
||||
cu_seq_lens_k (`torch.LongTensor`)
|
||||
Gets cumulative sequence length for key state.
|
||||
max_length_q (`int`):
|
||||
Maximum sequence length for query state.
|
||||
max_length_k (`int`):
|
||||
Maximum sequence length for key state.
|
||||
seq_idx (`torch.IntTensor):
|
||||
Index of each packed sequence.
|
||||
"""
|
||||
|
||||
cu_seq_lens_q: torch.LongTensor
|
||||
cu_seq_lens_k: torch.LongTensor
|
||||
max_length_q: int
|
||||
max_length_k: int
|
||||
seq_idx: torch.IntTensor
|
||||
|
||||
|
||||
class GraniteMoeHybridRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
@@ -1125,7 +1151,7 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
output_router_logits: Optional[bool] = False,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
**kwargs: Unpack[GraniteFlashAttentionKwargs],
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -1149,8 +1175,8 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
Arbitrary kwargs.Can be used to provide `GraniteFlashAttentionKwargs` for
|
||||
padding-free training and/or improve torch.compile performance.
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@@ -1161,6 +1187,7 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
|
||||
cache_position=cache_position,
|
||||
cache_params=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
# No attention weights for state space layers
|
||||
self_attn_weights = None
|
||||
@@ -1303,6 +1330,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[GraniteFlashAttentionKwargs],
|
||||
) -> Union[tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -1374,6 +1402,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
|
||||
cache_position=cache_position,
|
||||
output_router_logits=output_router_logits,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@@ -1706,6 +1735,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMix
|
||||
output_router_logits=output_router_logits,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Only compute necessary logits
|
||||
|
||||
@@ -20,10 +20,12 @@ from torch import nn
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ..bamba.configuration_bamba import BambaConfig
|
||||
from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache
|
||||
from ..granitemoeshared.modeling_granitemoeshared import (
|
||||
GraniteFlashAttentionKwargs,
|
||||
GraniteMoeSharedAttention,
|
||||
GraniteMoeSharedDecoderLayer,
|
||||
GraniteMoeSharedForCausalLM,
|
||||
@@ -84,7 +86,7 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
output_router_logits: Optional[bool] = False,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
**kwargs: Unpack[GraniteFlashAttentionKwargs],
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -108,8 +110,8 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
Arbitrary kwargs.Can be used to provide `GraniteFlashAttentionKwargs` for
|
||||
padding-free training and/or improve torch.compile performance.
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@@ -120,6 +122,7 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
|
||||
cache_position=cache_position,
|
||||
cache_params=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
# No attention weights for state space layers
|
||||
self_attn_weights = None
|
||||
@@ -198,6 +201,7 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[GraniteFlashAttentionKwargs],
|
||||
) -> Union[tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -269,6 +273,7 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):
|
||||
cache_position=cache_position,
|
||||
output_router_logits=output_router_logits,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import Callable, Optional, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -33,6 +33,7 @@ from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
|
||||
from .configuration_granitemoeshared import GraniteMoeSharedConfig
|
||||
|
||||
@@ -46,6 +47,31 @@ if is_torch_flex_attn_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class GraniteFlashAttentionKwargs(TypedDict, total=False):
|
||||
"""
|
||||
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
|
||||
Use cases include padding-free training and fewer `torch.compile` graph breaks.
|
||||
|
||||
Attributes:
|
||||
cu_seq_lens_q (`torch.LongTensor`)
|
||||
Gets cumulative sequence length for query state.
|
||||
cu_seq_lens_k (`torch.LongTensor`)
|
||||
Gets cumulative sequence length for key state.
|
||||
max_length_q (`int`):
|
||||
Maximum sequence length for query state.
|
||||
max_length_k (`int`):
|
||||
Maximum sequence length for key state.
|
||||
seq_idx (`torch.IntTensor):
|
||||
Index of each packed sequence.
|
||||
"""
|
||||
|
||||
cu_seq_lens_q: torch.LongTensor
|
||||
cu_seq_lens_k: torch.LongTensor
|
||||
max_length_q: int
|
||||
max_length_k: int
|
||||
seq_idx: torch.IntTensor
|
||||
|
||||
|
||||
class GraniteMoeSharedMLP(nn.Module):
|
||||
"""
|
||||
MLP layer for shared experts
|
||||
@@ -431,7 +457,7 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
output_router_logits: Optional[bool] = False,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
**kwargs: Unpack[GraniteFlashAttentionKwargs],
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -455,8 +481,8 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
|
||||
padding-free training and/or improve torch.compile performance.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
@@ -593,6 +619,7 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -979,6 +1006,7 @@ class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMix
|
||||
output_router_logits=output_router_logits,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Only compute necessary logits
|
||||
|
||||
@@ -13,13 +13,14 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import logging
|
||||
from ..granitemoe.modeling_granitemoe import (
|
||||
GraniteMoeDecoderLayer,
|
||||
@@ -33,6 +34,31 @@ from .configuration_granitemoeshared import GraniteMoeSharedConfig
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class GraniteFlashAttentionKwargs(TypedDict, total=False):
|
||||
"""
|
||||
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
|
||||
Use cases include padding-free training and fewer `torch.compile` graph breaks.
|
||||
|
||||
Attributes:
|
||||
cu_seq_lens_q (`torch.LongTensor`)
|
||||
Gets cumulative sequence length for query state.
|
||||
cu_seq_lens_k (`torch.LongTensor`)
|
||||
Gets cumulative sequence length for key state.
|
||||
max_length_q (`int`):
|
||||
Maximum sequence length for query state.
|
||||
max_length_k (`int`):
|
||||
Maximum sequence length for key state.
|
||||
seq_idx (`torch.IntTensor):
|
||||
Index of each packed sequence.
|
||||
"""
|
||||
|
||||
cu_seq_lens_q: torch.LongTensor
|
||||
cu_seq_lens_k: torch.LongTensor
|
||||
max_length_q: int
|
||||
max_length_k: int
|
||||
seq_idx: torch.IntTensor
|
||||
|
||||
|
||||
class GraniteMoeSharedMLP(nn.Module):
|
||||
"""
|
||||
MLP layer for shared experts
|
||||
@@ -75,7 +101,7 @@ class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
output_router_logits: Optional[bool] = False,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
**kwargs: Unpack[GraniteFlashAttentionKwargs],
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -99,8 +125,8 @@ class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
|
||||
padding-free training and/or improve torch.compile performance.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
|
||||
@@ -551,6 +551,15 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||
# Ensure inputs_dict also has labels in it, as their presence/absence can induce
|
||||
# dtype conversions. This also lets us compare losses.
|
||||
labels = inputs_dict["input_ids"].clone()
|
||||
# Mask padding tokens
|
||||
labels[~dummy_attention_mask.bool()] = -100
|
||||
# Also need to mask the first non-trivial token to match the padding-free batch.
|
||||
first_nonneg_idx = (labels >= 0).int().argmax(dim=1)
|
||||
labels[torch.arange(labels.size(0), device=labels.device), first_nonneg_idx] = -100
|
||||
inputs_dict["labels"] = labels
|
||||
|
||||
model = (
|
||||
model_class.from_pretrained(
|
||||
@@ -586,6 +595,10 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
tol = torch.finfo(torch.float16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
loss_padded = res_padded.loss
|
||||
loss_padfree = res_padfree.loss
|
||||
torch.testing.assert_close(loss_padded, loss_padfree)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user