Add GraniteMoeHybrid support for 4.0 (#37658)
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled

* initial config and MLA layer

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* first pass at decoder

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* completion of layers

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* modeling class

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* adding hybrid class to imports

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix imports granitemoehybrid

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix granitehybrid imports

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix granitehybrid import

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix generated modeling file

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* add some comments

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* minor fixes in layers

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* add sharedMLP layer

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* correct layer names

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fixes in mamba config

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix mamba config

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* change name of MLP layer

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix seq mizer layers

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* correct mamba config

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fixes in param names

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* enable hybrid model

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* update config

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix config granite hybrid

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix attention layer

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* cleanup to re-use mamba code

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* keep layer types

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* attention bias cleanup

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* update mamba layer name

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* first pass at tests

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* first pass at tests

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* use granite attention

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix: self attn weights

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* pass at making pos_emb optional

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* initialize self_attn only as needed

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* overwrite forward to create HybridMambaCache

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* Log invalid layer types

* Add attention outputs test

* Only emit attentions/logits if not None

* Fix config test hidden size divisibility

* mark granitmoehybrid as stateful

* Initialize mamba convolutional layers

* Formatting fixes

* config docstring, removed some unused attrs

* Fix missing arg in models test

* Fix create and check decoder model test

* support logits to keep in granitemoe

* regen to pass logits_to_keep

* Allow None or rope

* Fix gradient checkpointing

* Add granitemoehybrid as special cache for generate check

* Remove unused MLA refs

* Fix mamba layer mask

* Remove logits to keep from config

* Minor docstring nits

* Update licenses

* Enable cache by default

* map layer types to layer block type

* First pass at granite moe hybrid docs

* Ignore granite moe hybrid in valid checkpoint check

* Align attention interfaces

* regenerate modular granitemoeshared attention interface

* Align granite moe hybrid attn interface

* run formatting

* Handle mamba initialization

* avoid conditional attr defs

* Move hybrid layer validation to config

* Add placeholder integration tests

* Docs nits / Update model names

* Clean up forward conditions

* Use gradient checkpointing layer

* Remove some copied bamba tests + inherit

align test init

delete more tests

Use common layer init with bamba tests

finish test consolidation

* avoid redundant intermediate std var

* use @can_return_tuple

* Remove unused moe state

* make skipped test names consistent

* Fix docstring order

* Add missing toc

* Always create the shared mlp

* Fix name in docstring

* link preview model in docs

---------

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Co-authored-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
Sukriti Sharma
2025-05-05 22:47:43 -06:00
committed by GitHub
parent fe29b8c487
commit 471958b620
21 changed files with 3150 additions and 544 deletions

View File

@@ -47,6 +47,11 @@ if is_torch_available():
class BambaModelTester:
config_class = BambaConfig
if is_torch_available():
model_class = BambaModel
for_causal_lm_class = BambaForCausalLM
def __init__(
self,
parent,
@@ -118,6 +123,7 @@ class BambaModelTester:
if self.use_labels:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
self._update_layer_configs()
config = self.get_config()
return config, input_ids, input_mask, token_labels
@@ -133,10 +139,12 @@ class BambaModelTester:
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
def get_config(self):
def _update_layer_configs(self):
"""Configures hidden layers and attn layer indices if they are not set."""
# Fix for SDPA tests, force at least 4 layers
if self.num_hidden_layers < 4:
self.num_hidden_layers = 4
if self.attn_layer_indices is None:
d = [x for x in range(2, self.num_hidden_layers) if self.num_hidden_layers % x == 0]
if len(d) == 0:
@@ -144,7 +152,8 @@ class BambaModelTester:
d = d[-1] # get the largest divisor
self.attn_layer_indices = [x + 1 for x in range(0, self.num_hidden_layers, d)]
return BambaConfig(
def get_config(self, **kwargs):
return self.config_class(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
@@ -164,6 +173,7 @@ class BambaModelTester:
mamba_d_conv=self.mamba_d_conv,
mamba_expand=self.mamba_expand,
mamba_chunk_size=self.mamba_chunk_size,
**kwargs,
)
def create_and_check_model(
@@ -173,7 +183,7 @@ class BambaModelTester:
input_mask,
token_labels,
):
model = BambaModel(config=config)
model = self.model_class(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
@@ -187,7 +197,7 @@ class BambaModelTester:
input_mask,
token_labels,
):
model = BambaForCausalLM(config=config)
model = self.for_causal_lm_class(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
@@ -205,7 +215,7 @@ class BambaModelTester:
):
# config.is_decoder = True
# config.add_cross_attention = True
model = BambaForCausalLM(config=config)
model = self.for_causal_lm_class(config=config)
model.to(torch_device)
model.eval()
@@ -258,6 +268,7 @@ class BambaModelTester:
@require_torch
class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
model_tester_class = BambaModelTester
all_model_classes = (BambaModel, BambaForCausalLM) if is_torch_available() else ()
pipeline_model_mapping = (
{
@@ -276,8 +287,8 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
model_split_percents = [0.5, 0.7, 0.8]
def setUp(self):
self.model_tester = BambaModelTester(self)
self.config_tester = ConfigTester(self, config_class=BambaConfig, hidden_size=64)
self.model_tester = self.model_tester_class(self)
self.config_tester = ConfigTester(self, config_class=self.model_tester.config_class, hidden_size=64)
def test_config(self):
self.config_tester.run_common_tests()