diff --git a/docs/source/en/model_doc/opt.mdx b/docs/source/en/model_doc/opt.mdx index 04344df56d..4ab9436b04 100644 --- a/docs/source/en/model_doc/opt.mdx +++ b/docs/source/en/model_doc/opt.mdx @@ -54,6 +54,11 @@ The original code can be found [here](https://github.com/facebookresearch/metase [[autodoc]] TFOPTForCausalLM - call +## OPTForSequenceClassification + +[[autodoc]] OPTForSequenceClassification + - forward + ## FlaxOPTModel [[autodoc]] FlaxOPTModel diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index be15f1f2aa..5d6227c04a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1504,6 +1504,7 @@ else: "OPTForCausalLM", "OPTModel", "OPTPreTrainedModel", + "OPTForSequenceClassification", ] ) _import_structure["models.pegasus"].extend( @@ -4026,7 +4027,13 @@ if TYPE_CHECKING: OpenAIGPTPreTrainedModel, load_tf_weights_in_openai_gpt, ) - from .models.opt import OPT_PRETRAINED_MODEL_ARCHIVE_LIST, OPTForCausalLM, OPTModel, OPTPreTrainedModel + from .models.opt import ( + OPT_PRETRAINED_MODEL_ARCHIVE_LIST, + OPTForCausalLM, + OPTForSequenceClassification, + OPTModel, + OPTPreTrainedModel, + ) from .models.pegasus import ( PegasusForCausalLM, PegasusForConditionalGeneration, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 3cd574dea1..1c9cf86fbf 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -503,6 +503,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("nezha", "NezhaForSequenceClassification"), ("nystromformer", "NystromformerForSequenceClassification"), ("openai-gpt", "OpenAIGPTForSequenceClassification"), + ("opt", "OPTForSequenceClassification"), ("perceiver", "PerceiverForSequenceClassification"), ("plbart", "PLBartForSequenceClassification"), ("qdqbert", "QDQBertForSequenceClassification"), diff --git a/src/transformers/models/opt/__init__.py 
b/src/transformers/models/opt/__init__.py index e35d07d1b0..4e55086409 100644 --- a/src/transformers/models/opt/__init__.py +++ b/src/transformers/models/opt/__init__.py @@ -40,6 +40,7 @@ else: "OPTForCausalLM", "OPTModel", "OPTPreTrainedModel", + "OPTForSequenceClassification", ] try: @@ -72,7 +73,13 @@ if TYPE_CHECKING: except OptionalDependencyNotAvailable: pass else: - from .modeling_opt import OPT_PRETRAINED_MODEL_ARCHIVE_LIST, OPTForCausalLM, OPTModel, OPTPreTrainedModel + from .modeling_opt import ( + OPT_PRETRAINED_MODEL_ARCHIVE_LIST, + OPTForCausalLM, + OPTForSequenceClassification, + OPTModel, + OPTPreTrainedModel, + ) try: if not is_tf_available(): diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index dc43d1cc08..2284867593 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -19,10 +19,10 @@ from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( add_code_sample_docstrings, @@ -383,6 +383,7 @@ OPT_START_DOCSTRING = r""" OPT_START_DOCSTRING, ) class OPTPreTrainedModel(PreTrainedModel): + config_class = OPTConfig base_model_prefix = "model" supports_gradient_checkpointing = True @@ -729,7 +730,6 @@ class OPTModel(OPTPreTrainedModel): def __init__(self, config: OPTConfig): super().__init__(config) self.decoder = OPTDecoder(config) - # Initialize weights and apply final processing self.post_init() @@ -976,3 +976,133 @@ class OPTForCausalLM(OPTPreTrainedModel): for layer_past in past: 
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
+
+
+@add_start_docstrings(
+    """
+    The OPT Model transformer with a sequence classification head on top (linear layer).
+
+    [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it needs to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row
+    (for left- as well as right-padded batches). If no `pad_token_id` is defined, it simply takes the last value in
+    each row of the batch. Since it cannot guess the padding tokens when `inputs_embeds` are passed instead of
+    `input_ids`, it does the same (take the last value in each row of the batch).
+    """,
+    OPT_START_DOCSTRING,
+)
+class OPTForSequenceClassification(OPTPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+    def __init__(self, config: OPTConfig):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = OPTModel(config)
+        self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        output_type=SequenceClassifierOutputWithPast,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output="'LABEL_0'",
+        expected_loss=5.28,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.model(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
+
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        elif input_ids is not None:
+            # Take the rightmost non-pad position of each row: counting non-pad tokens is only correct for
+            # right-padded rows. The indices are built on the logits' device because they index `logits`.
+            non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).to(logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            sequence_lengths = (token_indices * non_pad_mask).argmax(-1)
+        else:
+            sequence_lengths = -1
+            logger.warning(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    def get_input_embeddings(self):
+        return self.model.decoder.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.decoder.embed_tokens = value
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 3f75b7085f..b1e97d3acf 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -3438,6 +3438,13 @@ class OPTForCausalLM(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class 
OPTForSequenceClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class OPTModel(metaclass=DummyObject):
     _backends = ["torch"]
 
diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py
index 86fa8c2c4b..bdf3716b59 100644
--- a/tests/models/opt/test_modeling_opt.py
+++ b/tests/models/opt/test_modeling_opt.py
@@ -32,7 +32,7 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor
 if is_torch_available():
     import torch
 
-    from transformers import GPT2Tokenizer, OPTForCausalLM, OPTModel
+    from transformers import GPT2Tokenizer, OPTForCausalLM, OPTForSequenceClassification, OPTModel
 
 
 def prepare_opt_inputs_dict(
@@ -74,7 +74,9 @@ class OPTModelTester:
         pad_token_id=1,
         bos_token_id=0,
         embed_dim=16,
+        num_labels=3,
         word_embed_proj_dim=16,
+        type_sequence_label_size=2,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -94,11 +96,12 @@ class OPTModelTester:
         self.pad_token_id = pad_token_id
         self.bos_token_id = bos_token_id
         self.embed_dim = embed_dim
+        self.num_labels = num_labels
+        self.type_sequence_label_size = type_sequence_label_size
         self.word_embed_proj_dim = word_embed_proj_dim
         self.is_encoder_decoder = False
 
     def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
             3,
         )
@@ -175,7 +178,7 @@ class OPTModelTester:
 
 @require_torch
 class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
-    all_model_classes = (OPTModel, OPTForCausalLM) if is_torch_available() else ()
+    all_model_classes = (OPTModel, OPTForCausalLM, OPTForSequenceClassification) if is_torch_available() else ()
     all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
     is_encoder_decoder = False
     fx_compatible = True
@@ -242,6 +245,36 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         model.generate(input_ids, attention_mask=attention_mask)
         model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
 
+    def test_opt_sequence_classification_model(self):
+        """Single-label head: the pooled logits must have shape `(batch_size, num_labels)`."""
+        config, input_dict = self.model_tester.prepare_config_and_inputs()
+        config.num_labels = self.model_tester.num_labels
+        input_ids = input_dict["input_ids"]
+        # Use the tester's pad id and label count so the test follows any change to the tester defaults.
+        attention_mask = input_ids.ne(self.model_tester.pad_token_id).to(torch_device)
+        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
+        model = OPTForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+
+    def test_opt_sequence_classification_model_for_multi_label(self):
+        """Multi-label head: float labels of shape `(batch_size, num_labels)`, logits of the same shape."""
+        config, input_dict = self.model_tester.prepare_config_and_inputs()
+        config.num_labels = self.model_tester.num_labels
+        config.problem_type = "multi_label_classification"
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(self.model_tester.pad_token_id).to(torch_device)
+        sequence_labels = ids_tensor(
+            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
+        ).to(torch.float)
+        model = OPTForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+
 
 def assert_tensors_close(a, b, atol=1e-12, prefix=""):
     """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""