From af69360bf96946e6d4d321d96ec7dea991c5cbbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?APAVOU=20Cl=C3=A9ment?= <45719060+clementapa@users.noreply.github.com> Date: Mon, 10 Oct 2022 15:30:59 +0200 Subject: [PATCH] Add `OPTForQuestionAnswering` (#19402) * Add `OPTForQuestionAnswering` - added `OPTForQuestionAnswering` class based on `BloomForQuestionAnswering` - added `OPTForQuestionAnswering` in common tests - all common tests pass - make fixup done * added docstrings for OPTForQuestionAnswering * Fix docstrings for OPTForQuestionAnswering --- docs/source/en/model_doc/opt.mdx | 5 + src/transformers/__init__.py | 2 + src/transformers/models/auto/modeling_auto.py | 1 + src/transformers/models/opt/__init__.py | 2 + src/transformers/models/opt/modeling_opt.py | 121 +++++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 7 + tests/models/opt/test_modeling_opt.py | 14 +- 7 files changed, 149 insertions(+), 3 deletions(-) diff --git a/docs/source/en/model_doc/opt.mdx b/docs/source/en/model_doc/opt.mdx index 4ab9436b04..612689678f 100644 --- a/docs/source/en/model_doc/opt.mdx +++ b/docs/source/en/model_doc/opt.mdx @@ -59,6 +59,11 @@ The original code can be found [here](https://github.com/facebookresearch/metase [[autodoc]] OPTForSequenceClassification - forward +## OPTForQuestionAnswering + +[[autodoc]] OPTForQuestionAnswering + - forward + ## FlaxOPTModel [[autodoc]] FlaxOPTModel diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 370a347c9f..0fd95663a5 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1661,6 +1661,7 @@ else: "OPTModel", "OPTPreTrainedModel", "OPTForSequenceClassification", + "OPTForQuestionAnswering", ] ) _import_structure["models.owlvit"].extend( @@ -4408,6 +4409,7 @@ if TYPE_CHECKING: from .models.opt import ( OPT_PRETRAINED_MODEL_ARCHIVE_LIST, OPTForCausalLM, + OPTForQuestionAnswering, OPTForSequenceClassification, OPTModel, OPTPreTrainedModel, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 237c98c5bb..edd61e1da9 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -611,6 +611,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ("mvp", "MvpForQuestionAnswering"), ("nezha", "NezhaForQuestionAnswering"), ("nystromformer", "NystromformerForQuestionAnswering"), + ("opt", "OPTForQuestionAnswering"), ("qdqbert", "QDQBertForQuestionAnswering"), ("reformer", "ReformerForQuestionAnswering"), ("rembert", "RemBertForQuestionAnswering"), diff --git a/src/transformers/models/opt/__init__.py b/src/transformers/models/opt/__init__.py index 4e55086409..c5a4533c03 100644 --- a/src/transformers/models/opt/__init__.py +++ b/src/transformers/models/opt/__init__.py @@ -41,6 +41,7 @@ else: "OPTModel", "OPTPreTrainedModel", "OPTForSequenceClassification", + "OPTForQuestionAnswering", ] try: @@ -76,6 +77,7 @@ if TYPE_CHECKING: from .modeling_opt import ( OPT_PRETRAINED_MODEL_ARCHIVE_LIST, OPTForCausalLM, + OPTForQuestionAnswering, OPTForSequenceClassification, OPTModel, OPTPreTrainedModel, diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 9ede3cabb8..acea546a20 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -22,7 +22,12 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) from ...modeling_utils import PreTrainedModel from ...utils import ( add_code_sample_docstrings, @@ -48,6 +53,11 @@ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc" _SEQ_CLASS_EXPECTED_LOSS = 1.71 _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" +# QuestionAnswering docstring +_QA_EXPECTED_OUTPUT = "'a nice puppet'" +_QA_EXPECTED_LOSS = 7.41 +_QA_TARGET_START_INDEX = 14 +_QA_TARGET_END_INDEX = 15 OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ "facebook/opt-125m", @@ -1109,3 +1119,112 @@ class OPTForSequenceClassification(OPTPreTrainedModel): def set_input_embeddings(self, value): self.model.decoder.embed_tokens = value + + +@add_start_docstrings( + """ + The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD + (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + OPT_START_DOCSTRING, +) +class OPTForQuestionAnswering(OPTPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config: OPTConfig): + super().__init__(config) + self.model = OPTModel(config) + self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + qa_target_start_index=_QA_TARGET_START_INDEX, + qa_target_end_index=_QA_TARGET_END_INDEX, + expected_output=_QA_EXPECTED_OUTPUT, + expected_loss=_QA_EXPECTED_LOSS, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + transformer_outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 72db36cab9..1546700c5c 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3714,6 +3714,13 @@ class OPTForCausalLM(metaclass=DummyObject): requires_backends(self, ["torch"]) +class OPTForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class OPTForSequenceClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index bdf3716b59..6b9311b09b 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -32,7 +32,13 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor if is_torch_available(): import torch - from transformers import GPT2Tokenizer, OPTForCausalLM, OPTForSequenceClassification, OPTModel + from transformers import ( + GPT2Tokenizer, + OPTForCausalLM, + OPTForQuestionAnswering, + OPTForSequenceClassification, + OPTModel, + ) def prepare_opt_inputs_dict( @@ -178,7 +184,11 @@ class OPTModelTester: @require_torch class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): - all_model_classes = (OPTModel, OPTForCausalLM, OPTForSequenceClassification) if is_torch_available() else () + all_model_classes = ( + (OPTModel, OPTForCausalLM, OPTForSequenceClassification, OPTForQuestionAnswering) + if is_torch_available() + else () + ) all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else () is_encoder_decoder = False fx_compatible = True