Add padding-free to Granite hybrid moe models (#39677)

* start fixing kwarg handling

* fmt

* updates padding free tests

* docs

* add missing kwargs modeling_granitemoe.py

* run modular util

* rm unrelated changes from modular util
This commit is contained in:
Garrett Goon
2025-07-25 14:10:50 -04:00
committed by GitHub
parent d6e9f71a6e
commit 97f8c71f52
7 changed files with 146 additions and 16 deletions

View File

@@ -48,6 +48,32 @@ for i in output:
This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co/SukritiSharma) and [Alexander Brooks](https://huggingface.co/abrooks9944). This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co/SukritiSharma) and [Alexander Brooks](https://huggingface.co/abrooks9944).
## Notes
- `GraniteMoeHybridForCausalLM` supports padding-free training, which concatenates distinct training examples while still processing inputs as separate batches. It can significantly accelerate inference by [~2x](https://github.com/huggingface/transformers/pull/35861#issue-2807873129) (depending on model and data distribution) and reduce memory usage if there are examples of varying lengths by avoiding unnecessary compute and memory overhead from padding tokens.
Padding-free training requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d` packages and the following arguments must be passed to the model in addition to `input_ids` and `labels`.
- `position_ids: torch.LongTensor`: the position index of each token in each sequence.
- `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
- Each of the [`FlashAttentionKwargs`]:
- `cu_seq_lens_q: torch.LongTensor`: the cumulative sequence lengths of all queries.
- `cu_seq_lens_k: torch.LongTensor`: the cumulative sequence lengths of all keys.
- `max_length_q: int`: the longest query length in the batch.
- `max_length_k: int`: the longest key length in the batch.
The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] programmatically generates the set of additional arguments above using `return_seq_idx=True` and `return_flash_attn_kwargs=True`. See the [Improving Hugging Face Training Efficiency Through Packing with Flash Attention](https://huggingface.co/blog/packing-with-FA2) blog post for additional information.
```python
from transformers import DataCollatorWithFlattening
# Example of using padding-free training
data_collator = DataCollatorWithFlattening(
tokenizer=tokenizer,
return_seq_idx=True,
return_flash_attn_kwargs=True
)
```
## GraniteMoeHybridConfig ## GraniteMoeHybridConfig

View File

@@ -641,6 +641,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
output_router_logits: Optional[bool] = None, output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[tuple, BaseModelOutputWithPast]: ) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@@ -947,6 +948,7 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
output_router_logits=output_router_logits, output_router_logits=output_router_logits,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position, cache_position=cache_position,
**kwargs,
) )
# Only compute necessary logits # Only compute necessary logits

View File

@@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, TypedDict, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -34,6 +34,7 @@ from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from .configuration_granitemoehybrid import GraniteMoeHybridConfig from .configuration_granitemoehybrid import GraniteMoeHybridConfig
@@ -918,6 +919,31 @@ class GraniteMoeHybridMLP(nn.Module):
return hidden_states return hidden_states
class GraniteFlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
    Use cases include padding-free training and fewer `torch.compile` graph breaks.

    Every key is optional (the TypedDict is declared with `total=False`).

    Attributes:
        cu_seq_lens_q (`torch.LongTensor`):
            Cumulative sequence lengths for the query state.
        cu_seq_lens_k (`torch.LongTensor`):
            Cumulative sequence lengths for the key state.
        max_length_q (`int`):
            Maximum sequence length for query state.
        max_length_k (`int`):
            Maximum sequence length for key state.
        seq_idx (`torch.IntTensor`):
            Index of each packed sequence.
    """

    cu_seq_lens_q: torch.LongTensor
    cu_seq_lens_k: torch.LongTensor
    max_length_q: int
    max_length_k: int
    seq_idx: torch.IntTensor
class GraniteMoeHybridRMSNorm(nn.Module): class GraniteMoeHybridRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -1125,7 +1151,7 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False, output_router_logits: Optional[bool] = False,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs, **kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -1149,8 +1175,8 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head. with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*): kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
into the model padding-free training and/or improve torch.compile performance.
""" """
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
@@ -1161,6 +1187,7 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
cache_position=cache_position, cache_position=cache_position,
cache_params=past_key_value, cache_params=past_key_value,
attention_mask=attention_mask, attention_mask=attention_mask,
**kwargs,
) )
# No attention weights for state space layers # No attention weights for state space layers
self_attn_weights = None self_attn_weights = None
@@ -1303,6 +1330,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
output_router_logits: Optional[bool] = None, output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> Union[tuple, BaseModelOutputWithPast]: ) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@@ -1374,6 +1402,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel):
cache_position=cache_position, cache_position=cache_position,
output_router_logits=output_router_logits, output_router_logits=output_router_logits,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
**kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
@@ -1706,6 +1735,7 @@ class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMix
output_router_logits=output_router_logits, output_router_logits=output_router_logits,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position, cache_position=cache_position,
**kwargs,
) )
# Only compute necessary logits # Only compute necessary logits

View File

@@ -20,10 +20,12 @@ from torch import nn
from ...cache_utils import Cache from ...cache_utils import Cache
from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, logging from ...utils import auto_docstring, can_return_tuple, logging
from ..bamba.configuration_bamba import BambaConfig from ..bamba.configuration_bamba import BambaConfig
from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache
from ..granitemoeshared.modeling_granitemoeshared import ( from ..granitemoeshared.modeling_granitemoeshared import (
GraniteFlashAttentionKwargs,
GraniteMoeSharedAttention, GraniteMoeSharedAttention,
GraniteMoeSharedDecoderLayer, GraniteMoeSharedDecoderLayer,
GraniteMoeSharedForCausalLM, GraniteMoeSharedForCausalLM,
@@ -84,7 +86,7 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False, output_router_logits: Optional[bool] = False,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs, **kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -108,8 +110,8 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head. with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*): kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
into the model padding-free training and/or improve torch.compile performance.
""" """
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
@@ -120,6 +122,7 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
cache_position=cache_position, cache_position=cache_position,
cache_params=past_key_value, cache_params=past_key_value,
attention_mask=attention_mask, attention_mask=attention_mask,
**kwargs,
) )
# No attention weights for state space layers # No attention weights for state space layers
self_attn_weights = None self_attn_weights = None
@@ -198,6 +201,7 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):
output_router_logits: Optional[bool] = None, output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> Union[tuple, BaseModelOutputWithPast]: ) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@@ -269,6 +273,7 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel):
cache_position=cache_position, cache_position=cache_position,
output_router_logits=output_router_logits, output_router_logits=output_router_logits,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
**kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]

View File

@@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Callable, Optional, Union from typing import Callable, Optional, TypedDict, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -33,6 +33,7 @@ from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, is_torch_flex_attn_available, logging from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from .configuration_granitemoeshared import GraniteMoeSharedConfig from .configuration_granitemoeshared import GraniteMoeSharedConfig
@@ -46,6 +47,31 @@ if is_torch_flex_attn_available():
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class GraniteFlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
    Use cases include padding-free training and fewer `torch.compile` graph breaks.

    Every key is optional (the TypedDict is declared with `total=False`).

    Attributes:
        cu_seq_lens_q (`torch.LongTensor`):
            Cumulative sequence lengths for the query state.
        cu_seq_lens_k (`torch.LongTensor`):
            Cumulative sequence lengths for the key state.
        max_length_q (`int`):
            Maximum sequence length for query state.
        max_length_k (`int`):
            Maximum sequence length for key state.
        seq_idx (`torch.IntTensor`):
            Index of each packed sequence.
    """

    cu_seq_lens_q: torch.LongTensor
    cu_seq_lens_k: torch.LongTensor
    max_length_q: int
    max_length_k: int
    seq_idx: torch.IntTensor
class GraniteMoeSharedMLP(nn.Module): class GraniteMoeSharedMLP(nn.Module):
""" """
MLP layer for shared experts MLP layer for shared experts
@@ -431,7 +457,7 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False, output_router_logits: Optional[bool] = False,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs, **kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -455,8 +481,8 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head. with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*): kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
into the model padding-free training and/or improve torch.compile performance.
""" """
residual = hidden_states residual = hidden_states
@@ -593,6 +619,7 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
output_router_logits: Optional[bool] = None, output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[tuple, BaseModelOutputWithPast]: ) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@@ -979,6 +1006,7 @@ class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMix
output_router_logits=output_router_logits, output_router_logits=output_router_logits,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position, cache_position=cache_position,
**kwargs,
) )
# Only compute necessary logits # Only compute necessary logits

View File

@@ -13,13 +13,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Optional, TypedDict
import torch import torch
from torch import nn from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache from ...cache_utils import Cache
from ...processing_utils import Unpack
from ...utils import logging from ...utils import logging
from ..granitemoe.modeling_granitemoe import ( from ..granitemoe.modeling_granitemoe import (
GraniteMoeDecoderLayer, GraniteMoeDecoderLayer,
@@ -33,6 +34,31 @@ from .configuration_granitemoeshared import GraniteMoeSharedConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class GraniteFlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
    Use cases include padding-free training and fewer `torch.compile` graph breaks.

    Every key is optional (the TypedDict is declared with `total=False`).

    Attributes:
        cu_seq_lens_q (`torch.LongTensor`):
            Cumulative sequence lengths for the query state.
        cu_seq_lens_k (`torch.LongTensor`):
            Cumulative sequence lengths for the key state.
        max_length_q (`int`):
            Maximum sequence length for query state.
        max_length_k (`int`):
            Maximum sequence length for key state.
        seq_idx (`torch.IntTensor`):
            Index of each packed sequence.
    """

    cu_seq_lens_q: torch.LongTensor
    cu_seq_lens_k: torch.LongTensor
    max_length_q: int
    max_length_k: int
    seq_idx: torch.IntTensor
class GraniteMoeSharedMLP(nn.Module): class GraniteMoeSharedMLP(nn.Module):
""" """
MLP layer for shared experts MLP layer for shared experts
@@ -75,7 +101,7 @@ class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer):
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False, output_router_logits: Optional[bool] = False,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs, **kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -99,8 +125,8 @@ class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head. with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*): kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
into the model padding-free training and/or improve torch.compile performance.
""" """
residual = hidden_states residual = hidden_states

View File

@@ -551,6 +551,15 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1) inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"] dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
# Ensure inputs_dict also has labels in it, as their presence/absence can induce
# dtype conversions. This also lets us compare losses.
labels = inputs_dict["input_ids"].clone()
# Mask padding tokens
labels[~dummy_attention_mask.bool()] = -100
# Also need to mask the first non-trivial token to match the padding-free batch.
first_nonneg_idx = (labels >= 0).int().argmax(dim=1)
labels[torch.arange(labels.size(0), device=labels.device), first_nonneg_idx] = -100
inputs_dict["labels"] = labels
model = ( model = (
model_class.from_pretrained( model_class.from_pretrained(
@@ -586,6 +595,10 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
tol = torch.finfo(torch.float16).eps tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
loss_padded = res_padded.loss
loss_padfree = res_padfree.loss
torch.testing.assert_close(loss_padded, loss_padfree)
@slow @slow
@require_torch @require_torch