Add padding-free to Granite hybrid moe models (#39677)
* start fixing kwarg handling * fmt * updates padding free tests * docs * add missing kwargs modeling_granitemoe.py * run modular util * rm unrelated changes from modular util
This commit is contained in:
@@ -48,6 +48,32 @@ for i in output:
|
|||||||
|
|
||||||
This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co/SukritiSharma) and [Alexander Brooks](https://huggingface.co/abrooks9944).
|
This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co/SukritiSharma) and [Alexander Brooks](https://huggingface.co/abrooks9944).
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- `GraniteMoeHybridForCausalLM` supports padding-free training, which concatenates distinct training examples while still processing the inputs as separate batches. It can significantly accelerate training by [~2x](https://github.com/huggingface/transformers/pull/35861#issue-2807873129) (depending on model and data distribution) and, when examples have varying lengths, reduce memory usage by avoiding the unnecessary compute and memory overhead of padding tokens.
|
||||||
|
|
||||||
|
Padding-free training requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d` packages. In addition to `input_ids` and `labels`, the following arguments must be passed to the model:
|
||||||
|
|
||||||
|
- `position_ids: torch.LongTensor`: the position index of each token in each sequence.
|
||||||
|
- `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
|
||||||
|
- Each of the [`FlashAttentionKwargs`]:
|
||||||
|
- `cu_seq_lens_q: torch.LongTensor`: the cumulative sequence lengths of all queries.
|
||||||
|
- `cu_seq_lens_k: torch.LongTensor`: the cumulative sequence lengths of all keys.
|
||||||
|
- `max_length_q: int`: the longest query length in the batch.
|
||||||
|
- `max_length_k: int`: the longest key length in the batch.
|
||||||
|
|
||||||
|
The `attention_mask` input should not be provided. The [`DataCollatorWithFlattening`] programmatically generates the set of additional arguments above when called with `return_seq_idx=True` and `return_flash_attn_kwargs=True`. See the [Improving Hugging Face Training Efficiency Through Packing with Flash Attention](https://huggingface.co/blog/packing-with-FA2) blog post for additional information.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers import DataCollatorWithFlattening
|
||||||
|
|
||||||
|
# Example of using padding-free training
|
||||||
|
data_collator = DataCollatorWithFlattening(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
return_seq_idx=True,
|
||||||
|
return_flash_attn_kwargs=True
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
## GraniteMoeHybridConfig
|
## GraniteMoeHybridConfig
|
||||||
|
|
||||||
@@ -61,4 +87,4 @@ This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co
|
|||||||
## GraniteMoeHybridForCausalLM
|
## GraniteMoeHybridForCausalLM
|
||||||
|
|
||||||
[[autodoc]] GraniteMoeHybridForCausalLM
|
[[autodoc]] GraniteMoeHybridForCausalLM
|
||||||
- forward
|
- forward
|
||||||
|
|||||||
@@ -641,6 +641,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
|
|||||||
output_router_logits: Optional[bool] = None,
|
output_router_logits: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs,
|
||||||
) -> Union[tuple, BaseModelOutputWithPast]:
|
) -> Union[tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -947,6 +948,7 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
|
|||||||
output_router_logits=output_router_logits,
|
output_router_logits=output_router_logits,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only compute necessary logits
|
# Only compute necessary logits
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, TypedDict, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -34,6 +34,7 @@ from ...modeling_layers import GradientCheckpointingLayer
|
|||||||
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
|
from ...processing_utils import Unpack
|
||||||
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||||
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
|
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
|
||||||
from .configuration_granitemoehybrid import GraniteMoeHybridConfig
|
from .configuration_granitemoehybrid import GraniteMoeHybridConfig
|
||||||
@@ -918,6 +919,31 @@ class GraniteMoeHybridMLP(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class GraniteFlashAttentionKwargs(TypedDict, total=False):
|
||||||
|
"""
|
||||||
|
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
|
||||||
|
Use cases include padding-free training and fewer `torch.compile` graph breaks.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
cu_seq_lens_q (`torch.LongTensor`):
|
||||||
|
Gets cumulative sequence length for query state.
|
||||||
|
cu_seq_lens_k (`torch.LongTensor`):
|
||||||
|
Gets cumulative sequence length for key state.
|
||||||
|
max_length_q (`int`):
|
||||||
|
Maximum sequence length for query state.
|
||||||
|
max_length_k (`int`):
|
||||||
|
Maximum sequence length for key state.
|
||||||
|
seq_idx (`torch.IntTensor`):
|
||||||
|
Index of each packed sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cu_seq_lens_q: torch.LongTensor
|
||||||
|
cu_seq_lens_k: torch.LongTensor
|
||||||
|
max_length_q: int
|
||||||
|
max_length_k: int
|
||||||
|
seq_idx: torch.IntTensor
|
||||||
|
|
||||||
|
|
||||||
class GraniteMoeHybridRMSNorm(nn.Module):
|
class GraniteMoeHybridRMSNorm(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
@@ -1125,7 +1151,7 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
|
|||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
output_router_logits: Optional[bool] = False,
|
output_router_logits: Optional[bool] = False,
|
||||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
**kwargs,
|
**kwargs: Unpack[GraniteFlashAttentionKwargs],
|
||||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -1149,8 +1175,8 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
|
|||||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||||
with `head_dim` being the embedding dimension of each attention head.
|
with `head_dim` being the embedding dimension of each attention head.
|
||||||
kwargs (`dict`, *optional*):
|
kwargs (`dict`, *optional*):
|
||||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
|
||||||
into the model
|
padding-free training and/or improve torch.compile performance.
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
@@ -1161,6 +1187,7 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
|
|||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
cache_params=past_key_value,
|
cache_params=past_key_value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
# No attention weights for state space layers
|
# No attention weights for state space layers
|
||||||
self_attn_weights = None
|
self_attn_weights = None
|
||||||
@@ -1303,6 +1330,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
|
|||||||
output_router_logits: Optional[bool] = None,
|
output_router_logits: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs: Unpack[GraniteFlashAttentionKwargs],
|
||||||
) -> Union[tuple, BaseModelOutputWithPast]:
|
) -> Union[tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -1374,6 +1402,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
|
|||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
output_router_logits=output_router_logits,
|
output_router_logits=output_router_logits,
|
||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
@@ -1706,6 +1735,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMix
|
|||||||
output_router_logits=output_router_logits,
|
output_router_logits=output_router_logits,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only compute necessary logits
|
# Only compute necessary logits
|
||||||
|
|||||||
@@ -20,10 +20,12 @@ from torch import nn
|
|||||||
|
|
||||||
from ...cache_utils import Cache
|
from ...cache_utils import Cache
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
|
||||||
|
from ...processing_utils import Unpack
|
||||||
from ...utils import auto_docstring, can_return_tuple, logging
|
from ...utils import auto_docstring, can_return_tuple, logging
|
||||||
from ..bamba.configuration_bamba import BambaConfig
|
from ..bamba.configuration_bamba import BambaConfig
|
||||||
from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache
|
from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache
|
||||||
from ..granitemoeshared.modeling_granitemoeshared import (
|
from ..granitemoeshared.modeling_granitemoeshared import (
|
||||||
|
GraniteFlashAttentionKwargs,
|
||||||
GraniteMoeSharedAttention,
|
GraniteMoeSharedAttention,
|
||||||
GraniteMoeSharedDecoderLayer,
|
GraniteMoeSharedDecoderLayer,
|
||||||
GraniteMoeSharedForCausalLM,
|
GraniteMoeSharedForCausalLM,
|
||||||
@@ -84,7 +86,7 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
|
|||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
output_router_logits: Optional[bool] = False,
|
output_router_logits: Optional[bool] = False,
|
||||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
**kwargs,
|
**kwargs: Unpack[GraniteFlashAttentionKwargs],
|
||||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -108,8 +110,8 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
|
|||||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||||
with `head_dim` being the embedding dimension of each attention head.
|
with `head_dim` being the embedding dimension of each attention head.
|
||||||
kwargs (`dict`, *optional*):
|
kwargs (`dict`, *optional*):
|
||||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
|
||||||
into the model
|
padding-free training and/or improve torch.compile performance.
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
@@ -120,6 +122,7 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
|
|||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
cache_params=past_key_value,
|
cache_params=past_key_value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
# No attention weights for state space layers
|
# No attention weights for state space layers
|
||||||
self_attn_weights = None
|
self_attn_weights = None
|
||||||
@@ -198,6 +201,7 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):
|
|||||||
output_router_logits: Optional[bool] = None,
|
output_router_logits: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs: Unpack[GraniteFlashAttentionKwargs],
|
||||||
) -> Union[tuple, BaseModelOutputWithPast]:
|
) -> Union[tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -269,6 +273,7 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):
|
|||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
output_router_logits=output_router_logits,
|
output_router_logits=output_router_logits,
|
||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Callable, Optional, Union
|
from typing import Callable, Optional, TypedDict, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -33,6 +33,7 @@ from ...modeling_layers import GradientCheckpointingLayer
|
|||||||
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
|
from ...processing_utils import Unpack
|
||||||
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
|
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
|
||||||
from .configuration_granitemoeshared import GraniteMoeSharedConfig
|
from .configuration_granitemoeshared import GraniteMoeSharedConfig
|
||||||
|
|
||||||
@@ -46,6 +47,31 @@ if is_torch_flex_attn_available():
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GraniteFlashAttentionKwargs(TypedDict, total=False):
|
||||||
|
"""
|
||||||
|
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
|
||||||
|
Use cases include padding-free training and fewer `torch.compile` graph breaks.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
cu_seq_lens_q (`torch.LongTensor`):
|
||||||
|
Gets cumulative sequence length for query state.
|
||||||
|
cu_seq_lens_k (`torch.LongTensor`):
|
||||||
|
Gets cumulative sequence length for key state.
|
||||||
|
max_length_q (`int`):
|
||||||
|
Maximum sequence length for query state.
|
||||||
|
max_length_k (`int`):
|
||||||
|
Maximum sequence length for key state.
|
||||||
|
seq_idx (`torch.IntTensor`):
|
||||||
|
Index of each packed sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cu_seq_lens_q: torch.LongTensor
|
||||||
|
cu_seq_lens_k: torch.LongTensor
|
||||||
|
max_length_q: int
|
||||||
|
max_length_k: int
|
||||||
|
seq_idx: torch.IntTensor
|
||||||
|
|
||||||
|
|
||||||
class GraniteMoeSharedMLP(nn.Module):
|
class GraniteMoeSharedMLP(nn.Module):
|
||||||
"""
|
"""
|
||||||
MLP layer for shared experts
|
MLP layer for shared experts
|
||||||
@@ -431,7 +457,7 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
|
|||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
output_router_logits: Optional[bool] = False,
|
output_router_logits: Optional[bool] = False,
|
||||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
**kwargs,
|
**kwargs: Unpack[GraniteFlashAttentionKwargs],
|
||||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -455,8 +481,8 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
|
|||||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||||
with `head_dim` being the embedding dimension of each attention head.
|
with `head_dim` being the embedding dimension of each attention head.
|
||||||
kwargs (`dict`, *optional*):
|
kwargs (`dict`, *optional*):
|
||||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
|
||||||
into the model
|
padding-free training and/or improve torch.compile performance.
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
@@ -593,6 +619,7 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
|
|||||||
output_router_logits: Optional[bool] = None,
|
output_router_logits: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs,
|
||||||
) -> Union[tuple, BaseModelOutputWithPast]:
|
) -> Union[tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -979,6 +1006,7 @@ class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMix
|
|||||||
output_router_logits=output_router_logits,
|
output_router_logits=output_router_logits,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only compute necessary logits
|
# Only compute necessary logits
|
||||||
|
|||||||
@@ -13,13 +13,14 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Optional
|
from typing import Optional, TypedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache
|
from ...cache_utils import Cache
|
||||||
|
from ...processing_utils import Unpack
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..granitemoe.modeling_granitemoe import (
|
from ..granitemoe.modeling_granitemoe import (
|
||||||
GraniteMoeDecoderLayer,
|
GraniteMoeDecoderLayer,
|
||||||
@@ -33,6 +34,31 @@ from .configuration_granitemoeshared import GraniteMoeSharedConfig
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GraniteFlashAttentionKwargs(TypedDict, total=False):
|
||||||
|
"""
|
||||||
|
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
|
||||||
|
Use cases include padding-free training and fewer `torch.compile` graph breaks.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
cu_seq_lens_q (`torch.LongTensor`):
|
||||||
|
Gets cumulative sequence length for query state.
|
||||||
|
cu_seq_lens_k (`torch.LongTensor`):
|
||||||
|
Gets cumulative sequence length for key state.
|
||||||
|
max_length_q (`int`):
|
||||||
|
Maximum sequence length for query state.
|
||||||
|
max_length_k (`int`):
|
||||||
|
Maximum sequence length for key state.
|
||||||
|
seq_idx (`torch.IntTensor`):
|
||||||
|
Index of each packed sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cu_seq_lens_q: torch.LongTensor
|
||||||
|
cu_seq_lens_k: torch.LongTensor
|
||||||
|
max_length_q: int
|
||||||
|
max_length_k: int
|
||||||
|
seq_idx: torch.IntTensor
|
||||||
|
|
||||||
|
|
||||||
class GraniteMoeSharedMLP(nn.Module):
|
class GraniteMoeSharedMLP(nn.Module):
|
||||||
"""
|
"""
|
||||||
MLP layer for shared experts
|
MLP layer for shared experts
|
||||||
@@ -75,7 +101,7 @@ class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer):
|
|||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
output_router_logits: Optional[bool] = False,
|
output_router_logits: Optional[bool] = False,
|
||||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
**kwargs,
|
**kwargs: Unpack[GraniteFlashAttentionKwargs],
|
||||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -99,8 +125,8 @@ class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer):
|
|||||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||||
with `head_dim` being the embedding dimension of each attention head.
|
with `head_dim` being the embedding dimension of each attention head.
|
||||||
kwargs (`dict`, *optional*):
|
kwargs (`dict`, *optional*):
|
||||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
|
||||||
into the model
|
padding-free training and/or improve torch.compile performance.
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
|
|||||||
@@ -551,6 +551,15 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||||
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||||
|
# Ensure inputs_dict also has labels in it, as their presence/absence can induce
|
||||||
|
# dtype conversions. This also lets us compare losses.
|
||||||
|
labels = inputs_dict["input_ids"].clone()
|
||||||
|
# Mask padding tokens
|
||||||
|
labels[~dummy_attention_mask.bool()] = -100
|
||||||
|
# Also need to mask the first non-trivial token to match the padding-free batch.
|
||||||
|
first_nonneg_idx = (labels >= 0).int().argmax(dim=1)
|
||||||
|
labels[torch.arange(labels.size(0), device=labels.device), first_nonneg_idx] = -100
|
||||||
|
inputs_dict["labels"] = labels
|
||||||
|
|
||||||
model = (
|
model = (
|
||||||
model_class.from_pretrained(
|
model_class.from_pretrained(
|
||||||
@@ -586,6 +595,10 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
tol = torch.finfo(torch.float16).eps
|
tol = torch.finfo(torch.float16).eps
|
||||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||||
|
|
||||||
|
loss_padded = res_padded.loss
|
||||||
|
loss_padfree = res_padfree.loss
|
||||||
|
torch.testing.assert_close(loss_padded, loss_padfree)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user