Add BloomForSequenceClassification and BloomForTokenClassification classes (#17639)

* add new bloom classes

* (feat) add bloom classification tests; make style

* style: change import in test

* add some typehints to bloom classes

* merge main into branch

* fix: input checking in bloom seq classification

* fix tests

* change model class tests

* fix few tests

- more tests should pass
- one test left

* make token classifier return hidden states

* style: make BLOOM typehints consistent

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Hailey Schoelkopf
2022-06-14 11:10:12 -04:00
committed by GitHub
parent bd43151af4
commit edb672ac5e
8 changed files with 328 additions and 10 deletions

View File

@@ -28,7 +28,14 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attenti
if is_torch_available():
import torch
from transformers import BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST, BloomForCausalLM, BloomModel, BloomTokenizerFast
from transformers import (
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
BloomModel,
BloomTokenizerFast,
)
@require_torch
@@ -96,9 +103,13 @@ class BloomModelTester:
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
sequence_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config(gradient_checkpointing=gradient_checkpointing)
return (config, input_ids, input_mask)
return (config, input_ids, input_mask, sequence_labels)
def get_config(self, gradient_checkpointing=False, slow_but_exact=True):
return BloomConfig(
@@ -116,6 +127,7 @@ class BloomModelTester:
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
num_labels=self.num_labels,
gradient_checkpointing=gradient_checkpointing,
slow_but_exact=slow_but_exact,
dtype="float32",
@@ -245,6 +257,23 @@ class BloomModelTester:
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_sequence_classification_model(self, config, input_ids, input_mask, *args):
config.num_labels = self.num_labels
model = BloomForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_token_classification_model(self, config, input_ids, input_mask, *args):
model = BloomForTokenClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, *args, gradient_checkpointing=False
):
@@ -269,7 +298,7 @@ class BloomModelTester:
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask = config_and_inputs
config, input_ids, input_mask, sequence_labels = config_and_inputs
inputs_dict = {"input_ids": input_ids}
@@ -279,7 +308,17 @@ class BloomModelTester:
@require_torch
class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (BloomModel, BloomForCausalLM) if is_torch_available() else ()
all_model_classes = (
(
BloomModel,
BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (BloomForCausalLM,) if is_torch_available() else ()
fx_compatible = False
test_missing_keys = False
@@ -313,6 +352,14 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
def test_bloom_sequence_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_sequence_classification_model(*config_and_inputs)
def test_bloom_token_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_token_classification_model(*config_and_inputs)
def test_bloom_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)