Add BloomForSequenceClassification and BloomForTokenClassification classes (#17639)

* add new bloom classes

* (feat) add bloom classification tests; make style

* style: change import in test

* add some typehints to bloom classes

* merge main into branch

* fix: input checking in bloom seq classification

* fix tests

* change model class tests

* fix few tests

- more tests should pass
- one test left

* make token classifier return hidden states

* style: make BLOOM typehints consistent

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Hailey Schoelkopf
2022-06-14 11:10:12 -04:00
committed by GitHub
parent bd43151af4
commit edb672ac5e
8 changed files with 328 additions and 10 deletions

View File

@@ -45,3 +45,13 @@ Several smaller versions of the models have been trained on the same dataset. BL
[[autodoc]] BloomForCausalLM [[autodoc]] BloomForCausalLM
- forward - forward
## BloomForSequenceClassification
[[autodoc]] BloomForSequenceClassification
- forward
## BloomForTokenClassification
[[autodoc]] BloomForTokenClassification
- forward

View File

@@ -870,6 +870,8 @@ else:
"BloomForCausalLM", "BloomForCausalLM",
"BloomModel", "BloomModel",
"BloomPreTrainedModel", "BloomPreTrainedModel",
"BloomForSequenceClassification",
"BloomForTokenClassification",
] ]
) )
_import_structure["models.blenderbot"].extend( _import_structure["models.blenderbot"].extend(
@@ -3417,6 +3419,8 @@ if TYPE_CHECKING:
from .models.bloom import ( from .models.bloom import (
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST, BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
BloomForCausalLM, BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
BloomModel, BloomModel,
BloomPreTrainedModel, BloomPreTrainedModel,
) )

View File

@@ -453,6 +453,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("bert", "BertForSequenceClassification"), ("bert", "BertForSequenceClassification"),
("big_bird", "BigBirdForSequenceClassification"), ("big_bird", "BigBirdForSequenceClassification"),
("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"), ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
("bloom", "BloomForSequenceClassification"),
("camembert", "CamembertForSequenceClassification"), ("camembert", "CamembertForSequenceClassification"),
("canine", "CanineForSequenceClassification"), ("canine", "CanineForSequenceClassification"),
("convbert", "ConvBertForSequenceClassification"), ("convbert", "ConvBertForSequenceClassification"),
@@ -563,6 +564,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("albert", "AlbertForTokenClassification"), ("albert", "AlbertForTokenClassification"),
("bert", "BertForTokenClassification"), ("bert", "BertForTokenClassification"),
("big_bird", "BigBirdForTokenClassification"), ("big_bird", "BigBirdForTokenClassification"),
("bloom", "BloomForTokenClassification"),
("camembert", "CamembertForTokenClassification"), ("camembert", "CamembertForTokenClassification"),
("canine", "CanineForTokenClassification"), ("canine", "CanineForTokenClassification"),
("convbert", "ConvBertForTokenClassification"), ("convbert", "ConvBertForTokenClassification"),

View File

@@ -46,6 +46,8 @@ else:
"BloomForCausalLM", "BloomForCausalLM",
"BloomModel", "BloomModel",
"BloomPreTrainedModel", "BloomPreTrainedModel",
"BloomForSequenceClassification",
"BloomForTokenClassification",
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -68,6 +70,8 @@ if TYPE_CHECKING:
from .modeling_bloom import ( from .modeling_bloom import (
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST, BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
BloomForCausalLM, BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
BloomModel, BloomModel,
BloomPreTrainedModel, BloomPreTrainedModel,
) )

View File

@@ -15,15 +15,20 @@
"""PyTorch BLOOM model.""" """PyTorch BLOOM model."""
import math import math
from typing import Tuple from typing import Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import logging from ...utils import logging
from .configuration_bloom import BloomConfig from .configuration_bloom import BloomConfig
@@ -42,7 +47,7 @@ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"bigscience/bloom-1b3", "bigscience/bloom-1b3",
"bigscience/bloom-2b5", "bigscience/bloom-2b5",
"bigscience/bloom-6b3", "bigscience/bloom-6b3",
"bigscience/bloom-176b", "bigscience/bloom",
] ]
@@ -726,7 +731,7 @@ class BloomModel(BloomPreTrainedModel):
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
): ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -902,7 +907,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
): ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -959,3 +964,223 @@ class BloomForCausalLM(BloomPreTrainedModel):
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past for layer_past in past
) )
@add_start_docstrings(
"""
The Bloom Model transformer with a sequence classification head on top (linear layer).
[`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-1) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
BLOOM_START_DOCSTRING,
)
class BloomForSequenceClassification(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = BloomModel(config)
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
"""
Bloom Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
BLOOM_START_DOCSTRING,
)
class BloomForTokenClassification(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = BloomModel(config)
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
classifier_dropout = config.classifier_dropout
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + transformer_outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

View File

@@ -558,6 +558,18 @@ class Trainer:
) )
self.use_apex = True self.use_apex = True
# FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
if (
is_sagemaker_mp_enabled()
and self.use_cuda_amp
and args.max_grad_norm is not None
and args.max_grad_norm > 0
):
raise ValueError(
"SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
"along 'max_grad_norm': 0 in your hyperparameters."
)
# Label smoothing # Label smoothing
if self.args.label_smoothing_factor != 0: if self.args.label_smoothing_factor != 0:
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor) self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)

View File

@@ -986,6 +986,20 @@ class BloomForCausalLM(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class BloomForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BloomForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BloomModel(metaclass=DummyObject): class BloomModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]

View File

@@ -28,7 +28,14 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attenti
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers import BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST, BloomForCausalLM, BloomModel, BloomTokenizerFast from transformers import (
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
BloomModel,
BloomTokenizerFast,
)
@require_torch @require_torch
@@ -96,9 +103,13 @@ class BloomModelTester:
if self.use_input_mask: if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length]) input_mask = random_attention_mask([self.batch_size, self.seq_length])
sequence_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config(gradient_checkpointing=gradient_checkpointing) config = self.get_config(gradient_checkpointing=gradient_checkpointing)
return (config, input_ids, input_mask) return (config, input_ids, input_mask, sequence_labels)
def get_config(self, gradient_checkpointing=False, slow_but_exact=True): def get_config(self, gradient_checkpointing=False, slow_but_exact=True):
return BloomConfig( return BloomConfig(
@@ -116,6 +127,7 @@ class BloomModelTester:
bos_token_id=self.bos_token_id, bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id, eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id, pad_token_id=self.pad_token_id,
num_labels=self.num_labels,
gradient_checkpointing=gradient_checkpointing, gradient_checkpointing=gradient_checkpointing,
slow_but_exact=slow_but_exact, slow_but_exact=slow_but_exact,
dtype="float32", dtype="float32",
@@ -245,6 +257,23 @@ class BloomModelTester:
self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_sequence_classification_model(self, config, input_ids, input_mask, *args):
config.num_labels = self.num_labels
model = BloomForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_token_classification_model(self, config, input_ids, input_mask, *args):
model = BloomForTokenClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_forward_and_backwards( def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, *args, gradient_checkpointing=False self, config, input_ids, input_mask, *args, gradient_checkpointing=False
): ):
@@ -269,7 +298,7 @@ class BloomModelTester:
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask = config_and_inputs config, input_ids, input_mask, sequence_labels = config_and_inputs
inputs_dict = {"input_ids": input_ids} inputs_dict = {"input_ids": input_ids}
@@ -279,7 +308,17 @@ class BloomModelTester:
@require_torch @require_torch
class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (BloomModel, BloomForCausalLM) if is_torch_available() else () all_model_classes = (
(
BloomModel,
BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (BloomForCausalLM,) if is_torch_available() else () all_generative_model_classes = (BloomForCausalLM,) if is_torch_available() else ()
fx_compatible = False fx_compatible = False
test_missing_keys = False test_missing_keys = False
@@ -313,6 +352,14 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_lm_head_model(*config_and_inputs) self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
def test_bloom_sequence_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_sequence_classification_model(*config_and_inputs)
def test_bloom_token_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_token_classification_model(*config_and_inputs)
def test_bloom_gradient_checkpointing(self): def test_bloom_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True) self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)