Granitemoe (#33207)

* first commit

* drop tokenizer

* drop tokenizer

* drop tokenizer

* drop convert

* granite

* drop tokenization test

* mup

* fix

* reformat

* reformat

* reformat

* fix docs

* stop checking for checkpoint

* update support

* attention multiplier

* update model

* tiny drop

* saibo drop

* skip test

* fix test

* fix test

* drop

* drop useless imports

* update docs

* drop flash function

* copied from

* drop pretraining tp

* drop pretraining tp

* drop pretraining tp

* drop unused import

* drop code path

* change name

* softmax scale

* head dim

* drop legacy cache

* rename params

* cleanup

* fix copies

* comments

* add back legacy cache

* multipliers

* multipliers

* multipliers

* text fix

* fix copies

* merge

* multipliers

* attention multiplier

* drop unused imports

* add granitemoe

* add decoration

* remove moe from sequenceclassification

* fix test

* fix

* fix

* fix

* move rope?

* merge

* drop bias

* drop bias

* Update src/transformers/models/granite/configuration_granite.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* Update src/transformers/models/granite/modeling_granite.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* fix

* fix

* fix

* drop

* drop

* fix

* fix

* cleanup

* cleanup

* fix

* fix granite tests

* fp32 test

* fix

* drop jitter

* fix

* rename

* rename

* fix config

* add gen test

---------

Co-authored-by: Yikang Shen <yikang.shn@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Mayank Mishra
2024-09-20 19:43:50 -04:00
committed by GitHub
parent 49a0bef4c1
commit e472e077c2
16 changed files with 2393 additions and 58 deletions

View File

@@ -323,61 +323,6 @@ class GraniteModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
# def test_granite_sequence_classification_model(self):
# config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
# config.num_labels = 3
# input_ids = input_dict["input_ids"]
# attention_mask = input_ids.ne(1).to(torch_device)
# sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
# model = GraniteForSequenceClassification(config)
# model.to(torch_device)
# model.eval()
# result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
# self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# def test_granite_sequence_classification_model_for_single_label(self):
# config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
# config.num_labels = 3
# config.problem_type = "single_label_classification"
# input_ids = input_dict["input_ids"]
# attention_mask = input_ids.ne(1).to(torch_device)
# sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
# model = GraniteForSequenceClassification(config)
# model.to(torch_device)
# model.eval()
# result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
# self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# def test_granite_sequence_classification_model_for_multi_label(self):
# config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
# config.num_labels = 3
# config.problem_type = "multi_label_classification"
# input_ids = input_dict["input_ids"]
# attention_mask = input_ids.ne(1).to(torch_device)
# sequence_labels = ids_tensor(
# [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
# ).to(torch.float)
# model = GraniteForSequenceClassification(config)
# model.to(torch_device)
# model.eval()
# result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
# self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# def test_granite_token_classification_model(self):
# config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
# config.num_labels = 3
# input_ids = input_dict["input_ids"]
# attention_mask = input_ids.ne(1).to(torch_device)
# token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
# model = GraniteForTokenClassification(config=config)
# model.to(torch_device)
# model.eval()
# result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
# self.assertEqual(
# result.logits.shape,
# (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
# )
@unittest.skip("Granite buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self):
pass
@@ -581,12 +526,13 @@ class GraniteIntegrationTest(unittest.TestCase):
# Expected mean on dim = -1
# fmt: off
EXPECTED_MEAN = torch.tensor([[-1.8799, -3.1269, -2.8297, -2.3755, -2.7364, -2.2389, -2.5914, -2.4154]])
EXPECTED_MEAN = torch.tensor([[-1.9798, -3.1626, -2.8062, -2.3777, -2.7091, -2.2338, -2.5924, -2.3974]])
self.assertTrue(torch.allclose(EXPECTED_MEAN.to(torch_device), out.logits.mean(-1), atol=1e-2, rtol=1e-2))
# slicing logits[0, 0, 0:15]
EXPECTED_SLICE = torch.tensor([[4.8125, -2.0156, -2.0156, -2.0000, -2.0000, -2.8438, -2.0156, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000]])
EXPECTED_SLICE = torch.tensor([[4.8750, -2.1875, -2.1875, -2.1875, -2.1875, -2.8438, -2.1875, -2.1875,
-2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875]])
# fmt: on
self.assertTrue(
@@ -610,6 +556,6 @@ class GraniteIntegrationTest(unittest.TestCase):
# fmt: off
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[0.0000, 0.0000, -3.4374, -2.1636, -2.6245, -3.0029, -3.8229, -3.1158]])
EXPECTED_MEAN = torch.tensor([[-2.0984, -3.1294, -2.8153, -2.3568, -2.7337, -2.2624, -2.6016, -2.4022]])
self.assertTrue(torch.allclose(EXPECTED_MEAN.to(torch_device), out.logits.mean(-1), atol=1e-2, rtol=1e-2))