Add BloomForSequenceClassification and BloomForTokenClassification classes (#17639)
* add new bloom classes
* (feat) add bloom classification tests; make style
* style: change import in test
* add some typehints to bloom classes
* merge main into branch
* fix: input checking in bloom seq classification
* fix tests
* change model class tests
* fix few tests - more tests should pass - one test left
* make token classifier return hidden states
* style: make BLOOM typehints consistent

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
bd43151af4
commit
edb672ac5e
@@ -28,7 +28,14 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attenti
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST, BloomForCausalLM, BloomModel, BloomTokenizerFast
|
||||
from transformers import (
|
||||
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
BloomForCausalLM,
|
||||
BloomForSequenceClassification,
|
||||
BloomForTokenClassification,
|
||||
BloomModel,
|
||||
BloomTokenizerFast,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -96,9 +103,13 @@ class BloomModelTester:
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
sequence_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
|
||||
config = self.get_config(gradient_checkpointing=gradient_checkpointing)
|
||||
|
||||
return (config, input_ids, input_mask)
|
||||
return (config, input_ids, input_mask, sequence_labels)
|
||||
|
||||
def get_config(self, gradient_checkpointing=False, slow_but_exact=True):
|
||||
return BloomConfig(
|
||||
@@ -116,6 +127,7 @@ class BloomModelTester:
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
num_labels=self.num_labels,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
slow_but_exact=slow_but_exact,
|
||||
dtype="float32",
|
||||
@@ -245,6 +257,23 @@ class BloomModelTester:
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_sequence_classification_model(self, config, input_ids, input_mask, *args):
    """Check the logits shape produced by `BloomForSequenceClassification`.

    Sets `config.num_labels` to the tester's `num_labels`, builds the model
    from `config`, moves it to `torch_device`, switches it to eval mode and
    runs a forward pass with `input_ids` and `input_mask` as attention mask.
    The resulting logits must have shape `(batch_size, num_labels)`.

    Any extra positional arguments are accepted and ignored.
    """
    config.num_labels = self.num_labels

    classifier = BloomForSequenceClassification(config)
    classifier.to(torch_device)
    classifier.eval()

    outputs = classifier(input_ids, attention_mask=input_mask)

    # One row of class scores per example in the batch.
    expected_shape = (self.batch_size, self.num_labels)
    self.parent.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
def create_and_check_token_classification_model(self, config, input_ids, input_mask, *args):
    """Check the logits shape produced by `BloomForTokenClassification`.

    Builds the model from `config`, moves it to `torch_device`, switches it
    to eval mode and runs a forward pass with `input_ids` and `input_mask`
    as attention mask. The resulting logits must have shape
    `(batch_size, seq_length, num_labels)`.

    Any extra positional arguments are accepted and ignored.
    """
    tagger = BloomForTokenClassification(config)
    tagger.to(torch_device)
    tagger.eval()

    outputs = tagger(input_ids, attention_mask=input_mask)

    # One row of class scores per token position of every example.
    expected_shape = (self.batch_size, self.seq_length, self.num_labels)
    self.parent.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
def create_and_check_forward_and_backwards(
|
||||
self, config, input_ids, input_mask, *args, gradient_checkpointing=False
|
||||
):
|
||||
@@ -269,7 +298,7 @@ class BloomModelTester:
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
||||
config, input_ids, input_mask = config_and_inputs
|
||||
config, input_ids, input_mask, sequence_labels = config_and_inputs
|
||||
|
||||
inputs_dict = {"input_ids": input_ids}
|
||||
|
||||
@@ -279,7 +308,17 @@ class BloomModelTester:
|
||||
@require_torch
|
||||
class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (BloomModel, BloomForCausalLM) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(
|
||||
BloomModel,
|
||||
BloomForCausalLM,
|
||||
BloomForSequenceClassification,
|
||||
BloomForTokenClassification,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
all_generative_model_classes = (BloomForCausalLM,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_missing_keys = False
|
||||
@@ -313,6 +352,14 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
|
||||
|
||||
def test_bloom_sequence_classification_model(self):
    """Run the model tester's sequence-classification check on freshly prepared inputs."""
    prepared = self.model_tester.prepare_config_and_inputs()
    check = self.model_tester.create_and_check_sequence_classification_model
    check(*prepared)
|
||||
|
||||
def test_bloom_token_classification_model(self):
    """Run the model tester's token-classification check on freshly prepared inputs."""
    prepared = self.model_tester.prepare_config_and_inputs()
    check = self.model_tester.create_and_check_token_classification_model
    check(*prepared)
|
||||
|
||||
def test_bloom_gradient_checkpointing(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
|
||||
|
||||
Reference in New Issue
Block a user