Add GraniteMoeHybrid support for 4.0 (#37658)
* initial config and MLA layer Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * first pass at decoder Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * completion of layers Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * modeling class Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * adding hybrid class to imports Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * fix imports granitemoehybrid Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * fix granitehybrid imports Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * fix granitehybrid import Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * fix generated modeling file Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * add some comments Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * minor fixes in layers Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * add sharedMLP layer Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * correct layer names Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * fixes in mamba config Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * fix mamba config Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * change name of MLP layer Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * fix seq mizer layers Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * correct mamba config Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * fixes in param names Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * enable hybrid model Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * update config Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * fix config granite hybrid Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * fix attention layer Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * cleanup to re-use mamba code Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * keep layer types Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * attention bias cleanup Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * update mamba layer name Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * first pass at tests Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * first pass at tests Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * use granite attention Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * fix: self attn weights Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * pass at making pos_emb optional Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * initialize self_attn only as needed Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * overwrite forward to create HybridMambaCache Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> * Log invalid layer types * Add attention outputs test * Only emit attentions/logits if not None * Fix config test hidden size divisibility * mark granitmoehybrid as stateful * Initialize mamba convolutional layers * Formatting fixes * config docstring, removed some unused attrs * Fix missing arg in models test * Fix create and check decoder model test * support logits to keep in granitemoe * regen to pass logits_to_keep * Allow None or rope * Fix gradient checkpointing * Add granitemoehybrid as special cache for generate check * Remove unused MLA refs * Fix mamba layer mask * Remove logits to keep from config * Minor docstring nits * Update licenses * Enable cache by default * map layer types to layer block type * First pass at granite moe hybrid docs * Ignore granite moe hybrid in valid checkpoint check * Align attention interfaces * regenerate modular granitemoeshared attention interface * Align granite moe hybrid attn interface * run formatting * Handle mamba initialization * avoid conditional attr defs * Move hybrid layer validation to config * Add placeholder integration tests * Docs nits / Update model names * Clean up forward conditions * Use gradient checkpointing layer * Remove some copied bamba tests + inherit align test init delete more tests Use common layer init with bamba tests finish test consolidation * avoid redundant intermediate std var * use @can_return_tuple * Remove unused moe state * make skipped test names consistent * Fix docstring order * Add missing toc * Always create the shared mlp * Fix name in docstring * link preview model in docs --------- Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com> Co-authored-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
@@ -47,6 +47,11 @@ if is_torch_available():
|
||||
|
||||
|
||||
class BambaModelTester:
|
||||
config_class = BambaConfig
|
||||
if is_torch_available():
|
||||
model_class = BambaModel
|
||||
for_causal_lm_class = BambaForCausalLM
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
@@ -118,6 +123,7 @@ class BambaModelTester:
|
||||
if self.use_labels:
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
|
||||
self._update_layer_configs()
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, input_mask, token_labels
|
||||
@@ -133,10 +139,12 @@ class BambaModelTester:
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def get_config(self):
|
||||
def _update_layer_configs(self):
|
||||
"""Configures hidden layers and attn layer indices if they are not set."""
|
||||
# Fix for SDPA tests, force at least 4 layers
|
||||
if self.num_hidden_layers < 4:
|
||||
self.num_hidden_layers = 4
|
||||
|
||||
if self.attn_layer_indices is None:
|
||||
d = [x for x in range(2, self.num_hidden_layers) if self.num_hidden_layers % x == 0]
|
||||
if len(d) == 0:
|
||||
@@ -144,7 +152,8 @@ class BambaModelTester:
|
||||
d = d[-1] # get the largest divisor
|
||||
self.attn_layer_indices = [x + 1 for x in range(0, self.num_hidden_layers, d)]
|
||||
|
||||
return BambaConfig(
|
||||
def get_config(self, **kwargs):
|
||||
return self.config_class(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
@@ -164,6 +173,7 @@ class BambaModelTester:
|
||||
mamba_d_conv=self.mamba_d_conv,
|
||||
mamba_expand=self.mamba_expand,
|
||||
mamba_chunk_size=self.mamba_chunk_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
@@ -173,7 +183,7 @@ class BambaModelTester:
|
||||
input_mask,
|
||||
token_labels,
|
||||
):
|
||||
model = BambaModel(config=config)
|
||||
model = self.model_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
@@ -187,7 +197,7 @@ class BambaModelTester:
|
||||
input_mask,
|
||||
token_labels,
|
||||
):
|
||||
model = BambaForCausalLM(config=config)
|
||||
model = self.for_causal_lm_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
@@ -205,7 +215,7 @@ class BambaModelTester:
|
||||
):
|
||||
# config.is_decoder = True
|
||||
# config.add_cross_attention = True
|
||||
model = BambaForCausalLM(config=config)
|
||||
model = self.for_causal_lm_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
@@ -258,6 +268,7 @@ class BambaModelTester:
|
||||
|
||||
@require_torch
|
||||
class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
model_tester_class = BambaModelTester
|
||||
all_model_classes = (BambaModel, BambaForCausalLM) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
@@ -276,8 +287,8 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
model_split_percents = [0.5, 0.7, 0.8]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BambaModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BambaConfig, hidden_size=64)
|
||||
self.model_tester = self.model_tester_class(self)
|
||||
self.config_tester = ConfigTester(self, config_class=self.model_tester.config_class, hidden_size=64)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
0
tests/models/granitemoehybrid/__init__.py
Normal file
0
tests/models/granitemoehybrid/__init__.py
Normal file
164
tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py
Normal file
164
tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch GraniteMoeHybrid model."""
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
GraniteMoeHybridConfig,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...models.bamba.test_modeling_bamba import BambaModelTest, BambaModelTester
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
GraniteMoeHybridForCausalLM,
|
||||
GraniteMoeHybridModel,
|
||||
)
|
||||
|
||||
|
||||
class GraniteMoeHybridModelTester(BambaModelTester):
|
||||
config_class = GraniteMoeHybridConfig
|
||||
if is_torch_available():
|
||||
model_class = GraniteMoeHybridModel
|
||||
for_causal_lm_class = GraniteMoeHybridForCausalLM
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
use_cache=False,
|
||||
shared_intermediate_size=174,
|
||||
layer_types=None,
|
||||
):
|
||||
super().__init__(parent)
|
||||
self.shared_intermediate_size = shared_intermediate_size
|
||||
self.layer_types = layer_types
|
||||
self.use_cache = use_cache
|
||||
|
||||
def _update_layer_configs(self):
|
||||
super()._update_layer_configs()
|
||||
# GraniteMoeHybrid uses layer_types instead of attn_layer_indices
|
||||
self.layer_types = ["mamba"] * self.num_hidden_layers
|
||||
for idx in self.attn_layer_indices:
|
||||
self.layer_types[idx] = "attention"
|
||||
|
||||
def get_config(self):
|
||||
return super().get_config(
|
||||
shared_intermediate_size=self.shared_intermediate_size,
|
||||
layer_types=self.layer_types,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class GraniteMoeHybridModelTest(BambaModelTest, GenerationTesterMixin, unittest.TestCase):
|
||||
model_tester_class = GraniteMoeHybridModelTester
|
||||
all_model_classes = (
|
||||
(
|
||||
GraniteMoeHybridModel,
|
||||
GraniteMoeHybridForCausalLM,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": GraniteMoeHybridModel,
|
||||
"text-generation": GraniteMoeHybridForCausalLM,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
||||
def test_config_requires_mamba_or_attention_layers(self):
|
||||
"""Ensure we can't create a config with disallowed layers."""
|
||||
with pytest.raises(ValueError):
|
||||
GraniteMoeHybridConfig(layer_types=["not allowed!"])
|
||||
|
||||
|
||||
# TODO (@alex-jw-brooks) - update this once the model(s) are out
|
||||
@unittest.skip(reason="GraniteMoeHybrid models are not yet released")
|
||||
@require_torch_gpu
|
||||
class GraniteMoeHybridIntegrationTest(unittest.TestCase):
|
||||
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||
# Depending on the hardware we get different logits / generations
|
||||
cuda_compute_capability_major_version = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if is_torch_available() and torch.cuda.is_available():
|
||||
# 8 is for A100 / A10 and 7 for T4
|
||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||
|
||||
@slow
|
||||
def test_model_logits(self):
|
||||
input_ids = [31390, 631, 4162, 30, 322, 25342, 432, 1875, 43826, 10066, 688, 225]
|
||||
|
||||
model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-tiny", device_map="auto")
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(torch.tensor([input_ids]).to(torch_device))
|
||||
|
||||
# fmt: off
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([
|
||||
[-2.9711, -2.2554, -1.0814, -1.6123, -0.8780, -1.0685, -0.6368, -1.9732, -3.3548, -2.6895, -2.3062, -2.6338]
|
||||
])
|
||||
|
||||
torch.testing.assert_close(EXPECTED_MEAN.to(torch_device), out.logits.float().mean(-1), rtol=1e-2, atol=1e-2)
|
||||
|
||||
# slicing logits[0, 0, 0:15]
|
||||
EXPECTED_SLICE = torch.tensor([
|
||||
[4.0662, 5.9547, 3.5803, 3.1306, 4.3211, 3.8902, 4.6438, 8.5434, 7.5865, 5.1623, 5.2240, 9.2982, 5.9094, 6.8834, 5.7551],
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
EXPECTED_SLICE.to(torch_device),
|
||||
out.logits[0, 0, :15].float(),
|
||||
atol=1e-3,
|
||||
rtol=1e-3,
|
||||
)
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_model_generation(self):
|
||||
EXPECTED_TEXT_COMPLETION = (
|
||||
"Simply put, the theory of relativity states that 1) time is relative, and 2) space is relative. The first"
|
||||
)
|
||||
prompt = "Simply put, the theory of relativity states that "
|
||||
tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-4.0-tiny")
|
||||
model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-tiny", device_map="auto")
|
||||
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
# greedy generation outputs
|
||||
generated_ids = model.generate(**model_inputs, max_new_tokens=16, do_sample=False)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
Reference in New Issue
Block a user