From 06910f5a764ff247c57ffd0d449eafb025e2bac1 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Tue, 27 Jun 2023 16:07:06 +0200 Subject: [PATCH] [`T5`] Add T5ForQuestionAnswering and MT5ForQuestionAnswering (#24481) * Adding T5ForQuestionAnswering * Changed weight initialization that results in better initial loss when fine-tuning * Update to class variables * Running make fixup * Running make fix-copies * Remove model_parallel * Adding MT5ForQuestionAnswering * Adding docs * Fix wrong doc * Update src/transformers/models/mt5/modeling_mt5.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/models/t5/modeling_t5.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * File formatting * Undoing change --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- docs/source/en/model_doc/mt5.md | 4 + docs/source/en/model_doc/t5.md | 5 + docs/source/en/tasks/question_answering.md | 2 +- src/transformers/__init__.py | 12 +- src/transformers/models/auto/modeling_auto.py | 2 + src/transformers/models/mt5/__init__.py | 10 +- src/transformers/models/mt5/modeling_mt5.py | 204 +++++++++++++++++- src/transformers/models/t5/__init__.py | 2 + src/transformers/models/t5/modeling_t5.py | 198 ++++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 14 ++ tests/models/t5/test_modeling_t5.py | 4 +- 11 files changed, 450 insertions(+), 7 deletions(-) diff --git a/docs/source/en/model_doc/mt5.md b/docs/source/en/model_doc/mt5.md index 1e769f5ac1..ea3da76681 100644 --- a/docs/source/en/model_doc/mt5.md +++ b/docs/source/en/model_doc/mt5.md @@ -95,6 +95,10 @@ See [`T5TokenizerFast`] for all details. [[autodoc]] MT5EncoderModel +## MT5ForQuestionAnswering + +[[autodoc]] MT5ForQuestionAnswering + ## TFMT5Model [[autodoc]] TFMT5Model diff --git a/docs/source/en/model_doc/t5.md b/docs/source/en/model_doc/t5.md index b5fa747a81..29068a3a42 100644 --- a/docs/source/en/model_doc/t5.md +++ b/docs/source/en/model_doc/t5.md @@ -399,6 +399,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] T5EncoderModel - forward +## T5ForQuestionAnswering + +[[autodoc]] T5ForQuestionAnswering + - forward + ## TFT5Model [[autodoc]] TFT5Model diff --git a/docs/source/en/tasks/question_answering.md b/docs/source/en/tasks/question_answering.md index cb268e520e..586b0a5d74 100644 --- a/docs/source/en/tasks/question_answering.md +++ b/docs/source/en/tasks/question_answering.md @@ -35,7 +35,7 @@ The task illustrated in this tutorial is supported by the following model archit -[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [LXMERT](../model_doc/lxmert), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OPT](../model_doc/opt), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Splinter](../model_doc/splinter), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso) +[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [LXMERT](../model_doc/lxmert), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OPT](../model_doc/opt), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Splinter](../model_doc/splinter), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index d3ea005466..202d9a6101 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2132,7 +2132,7 @@ else: ] ) _import_structure["models.mt5"].extend( - ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model", "MT5PreTrainedModel"] + ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5ForQuestionAnswering", "MT5Model", "MT5PreTrainedModel"] ) _import_structure["models.mvp"].extend( [ @@ -2573,6 +2573,7 @@ else: "T5_PRETRAINED_MODEL_ARCHIVE_LIST", "T5EncoderModel", "T5ForConditionalGeneration", + "T5ForQuestionAnswering", "T5Model", "T5PreTrainedModel", "load_tf_weights_in_t5", @@ -5701,7 +5702,13 @@ if TYPE_CHECKING: MPNetModel, MPNetPreTrainedModel, ) - from .models.mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model, MT5PreTrainedModel + from .models.mt5 import ( + MT5EncoderModel, + MT5ForConditionalGeneration, + MT5ForQuestionAnswering, + MT5Model, + MT5PreTrainedModel, + ) from .models.mvp import ( MVP_PRETRAINED_MODEL_ARCHIVE_LIST, MvpForCausalLM, @@ -6064,6 +6071,7 @@ if TYPE_CHECKING: T5_PRETRAINED_MODEL_ARCHIVE_LIST, T5EncoderModel, T5ForConditionalGeneration, + T5ForQuestionAnswering, T5Model, T5PreTrainedModel, load_tf_weights_in_t5, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 4259f50374..8c5e5e8c15 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -768,6 +768,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ("megatron-bert", "MegatronBertForQuestionAnswering"), ("mobilebert", "MobileBertForQuestionAnswering"), ("mpnet", "MPNetForQuestionAnswering"), + ("mt5", "MT5ForQuestionAnswering"), ("mvp", "MvpForQuestionAnswering"), ("nezha", "NezhaForQuestionAnswering"), ("nystromformer", "NystromformerForQuestionAnswering"), @@ -781,6 +782,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ("roformer", "RoFormerForQuestionAnswering"), ("splinter", "SplinterForQuestionAnswering"), ("squeezebert", "SqueezeBertForQuestionAnswering"), + ("t5", "T5ForQuestionAnswering"), ("xlm", "XLMForQuestionAnsweringSimple"), ("xlm-roberta", "XLMRobertaForQuestionAnswering"), ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"), diff --git a/src/transformers/models/mt5/__init__.py b/src/transformers/models/mt5/__init__.py index de966bcb19..d86823021d 100644 --- a/src/transformers/models/mt5/__init__.py +++ b/src/transformers/models/mt5/__init__.py @@ -50,6 +50,7 @@ else: _import_structure["modeling_mt5"] = [ "MT5EncoderModel", "MT5ForConditionalGeneration", + "MT5ForQuestionAnswering", "MT5Model", "MT5PreTrainedModel", "MT5Stack", @@ -81,7 +82,14 @@ if TYPE_CHECKING: except OptionalDependencyNotAvailable: pass else: - from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model, MT5PreTrainedModel, MT5Stack + from .modeling_mt5 import ( + MT5EncoderModel, + MT5ForConditionalGeneration, + MT5ForQuestionAnswering, + MT5Model, + MT5PreTrainedModel, + MT5Stack, + ) try: if not is_tf_available(): diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index a3cfce8ffc..dfeb3c10a9 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -31,6 +31,7 @@ from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -772,12 +773,15 @@ class MT5PreTrainedModel(PreTrainedModel): factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, MT5LayerNorm): module.weight.data.fill_(factor * 1.0) - elif isinstance(module, (MT5Model, MT5ForConditionalGeneration, MT5EncoderModel)): + elif isinstance(module, (MT5Model, MT5ForConditionalGeneration, MT5EncoderModel, MT5ForQuestionAnswering)): # Mesh TensorFlow embeddings initialization # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "qa_outputs"): + module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.qa_outputs.bias.data.zero_() elif isinstance(module, MT5DenseActDense): # Mesh TensorFlow FF initialization # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 @@ -2015,3 +2019,201 @@ class MT5EncoderModel(MT5PreTrainedModel): ) return encoder_outputs + + +@add_start_docstrings( + """ + MT5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers + on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + MT5_START_DOCSTRING, +) +class MT5ForQuestionAnswering(MT5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder.embed_tokens.weight", + r"decoder.embed_tokens.weight", + r"lm_head.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.__init__ with T5->MT5 + def __init__(self, config: MT5Config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = MT5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = MT5Stack(decoder_config, self.shared) + + self.num_labels = config.num_labels + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_input_embeddings + def get_input_embeddings(self): + return self.shared + + # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.set_input_embeddings + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder + def get_encoder(self): + return self.encoder + + # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_decoder + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(MT5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.forward + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + if start_positions is not None and end_positions is not None: + use_cache = False + + # Copied from models.bart.modeling_bart.BartModel.forward + # different to other models, T5 automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + decoder_input_ids = self._shift_right(input_ids) + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=None, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) diff --git a/src/transformers/models/t5/__init__.py b/src/transformers/models/t5/__init__.py index 396160d773..5a02877504 100644 --- a/src/transformers/models/t5/__init__.py +++ b/src/transformers/models/t5/__init__.py @@ -56,6 +56,7 @@ else: "T5Model", "T5PreTrainedModel", "load_tf_weights_in_t5", + "T5ForQuestionAnswering", ] try: @@ -115,6 +116,7 @@ if TYPE_CHECKING: T5_PRETRAINED_MODEL_ARCHIVE_LIST, T5EncoderModel, T5ForConditionalGeneration, + T5ForQuestionAnswering, T5Model, T5PreTrainedModel, load_tf_weights_in_t5, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 050309fa9a..f1f7b17c7b 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -32,6 +32,7 @@ from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer @@ -801,12 +802,15 @@ class T5PreTrainedModel(PreTrainedModel): factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, T5LayerNorm): module.weight.data.fill_(factor * 1.0) - elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)): + elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering)): # Mesh TensorFlow embeddings initialization # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "qa_outputs"): + module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.qa_outputs.bias.data.zero_() elif isinstance(module, T5DenseActDense): # Mesh TensorFlow FF initialization # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 @@ -1949,3 +1953,195 @@ class T5EncoderModel(T5PreTrainedModel): ) return encoder_outputs + + +@add_start_docstrings( + """ + T5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers + on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + T5_START_DOCSTRING, +) +class T5ForQuestionAnswering(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder.embed_tokens.weight", + r"decoder.embed_tokens.weight", + r"lm_head.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.num_labels = config.num_labels + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + if start_positions is not None and end_positions is not None: + use_cache = False + + # Copied from models.bart.modeling_bart.BartModel.forward + # different to other models, T5 automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + decoder_input_ids = self._shift_right(input_ids) + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=None, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index f07b2fdd87..4513032607 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4915,6 +4915,13 @@ class MT5ForConditionalGeneration(metaclass=DummyObject): requires_backends(self, ["torch"]) +class MT5ForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MT5Model(metaclass=DummyObject): _backends = ["torch"] @@ -6742,6 +6749,13 @@ class T5ForConditionalGeneration(metaclass=DummyObject): requires_backends(self, ["torch"]) +class T5ForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class T5Model(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index 3bc4d03a9c..b0ff82dd62 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -43,6 +43,7 @@ if is_torch_available(): ByT5Tokenizer, T5EncoderModel, T5ForConditionalGeneration, + T5ForQuestionAnswering, T5Model, T5Tokenizer, ) @@ -520,7 +521,7 @@ class T5ModelTester: @require_torch class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () + all_model_classes = (T5Model, T5ForConditionalGeneration, T5ForQuestionAnswering) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () pipeline_model_mapping = ( { @@ -529,6 +530,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, "summarization": T5ForConditionalGeneration, "text2text-generation": T5ForConditionalGeneration, "translation": T5ForConditionalGeneration, + "question-answering": T5ForQuestionAnswering, } if is_torch_available() else {}