From 62098b93481a7ed1b54ee9e33467534d91758375 Mon Sep 17 00:00:00 2001 From: Ikuya Yamada Date: Tue, 2 Aug 2022 00:09:47 +0900 Subject: [PATCH] Adding fine-tuning models to LUKE (#18353) * add LUKE models for downstream tasks * add new LUKE models to docs * fix typos * remove commented lines * exclude None items from tuple return values --- docs/source/en/model_doc/luke.mdx | 20 + src/transformers/__init__.py | 8 + src/transformers/models/auto/modeling_auto.py | 6 + src/transformers/models/luke/__init__.py | 8 + .../models/luke/configuration_luke.py | 4 + src/transformers/models/luke/modeling_luke.py | 653 +++++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 28 + tests/models/luke/test_modeling_luke.py | 233 ++++++- 8 files changed, 922 insertions(+), 38 deletions(-) diff --git a/docs/source/en/model_doc/luke.mdx b/docs/source/en/model_doc/luke.mdx index 6900367bb8..b7483f9194 100644 --- a/docs/source/en/model_doc/luke.mdx +++ b/docs/source/en/model_doc/luke.mdx @@ -152,3 +152,23 @@ This model was contributed by [ikuyamada](https://huggingface.co/ikuyamada) and [[autodoc]] LukeForEntitySpanClassification - forward + +## LukeForSequenceClassification + +[[autodoc]] LukeForSequenceClassification + - forward + +## LukeForMultipleChoice + +[[autodoc]] LukeForMultipleChoice + - forward + +## LukeForTokenClassification + +[[autodoc]] LukeForTokenClassification + - forward + +## LukeForQuestionAnswering + +[[autodoc]] LukeForQuestionAnswering + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 1c0d86e567..75784ce463 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1363,6 +1363,10 @@ else: "LukeForEntityClassification", "LukeForEntityPairClassification", "LukeForEntitySpanClassification", + "LukeForMultipleChoice", + "LukeForQuestionAnswering", + "LukeForSequenceClassification", + "LukeForTokenClassification", "LukeForMaskedLM", "LukeModel", "LukePreTrainedModel", @@ -3953,6 +3957,10 @@ if TYPE_CHECKING: LukeForEntityPairClassification, LukeForEntitySpanClassification, LukeForMaskedLM, + LukeForMultipleChoice, + LukeForQuestionAnswering, + LukeForSequenceClassification, + LukeForTokenClassification, LukeModel, LukePreTrainedModel, ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index ce2c4f9457..a86e8bc56d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -170,6 +170,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( ("ibert", "IBertForMaskedLM"), ("layoutlm", "LayoutLMForMaskedLM"), ("longformer", "LongformerForMaskedLM"), + ("luke", "LukeForMaskedLM"), ("lxmert", "LxmertForPreTraining"), ("megatron-bert", "MegatronBertForPreTraining"), ("mobilebert", "MobileBertForPreTraining"), @@ -230,6 +231,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( ("led", "LEDForConditionalGeneration"), ("longformer", "LongformerForMaskedLM"), ("longt5", "LongT5ForConditionalGeneration"), + ("luke", "LukeForMaskedLM"), ("m2m_100", "M2M100ForConditionalGeneration"), ("marian", "MarianMTModel"), ("megatron-bert", "MegatronBertForCausalLM"), @@ -499,6 +501,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("layoutlmv3", "LayoutLMv3ForSequenceClassification"), ("led", "LEDForSequenceClassification"), ("longformer", "LongformerForSequenceClassification"), + ("luke", "LukeForSequenceClassification"), ("mbart", "MBartForSequenceClassification"), ("megatron-bert", "MegatronBertForSequenceClassification"), ("mobilebert", "MobileBertForSequenceClassification"), @@ -551,6 +554,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"), ("led", "LEDForQuestionAnswering"), ("longformer", "LongformerForQuestionAnswering"), + ("luke", "LukeForQuestionAnswering"), ("lxmert", "LxmertForQuestionAnswering"), ("mbart", "MBartForQuestionAnswering"), ("megatron-bert", "MegatronBertForQuestionAnswering"), @@ -611,6 +615,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("layoutlmv2", "LayoutLMv2ForTokenClassification"), ("layoutlmv3", "LayoutLMv3ForTokenClassification"), ("longformer", "LongformerForTokenClassification"), + ("luke", "LukeForTokenClassification"), ("megatron-bert", "MegatronBertForTokenClassification"), ("mobilebert", "MobileBertForTokenClassification"), ("mpnet", "MPNetForTokenClassification"), @@ -647,6 +652,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( ("funnel", "FunnelForMultipleChoice"), ("ibert", "IBertForMultipleChoice"), ("longformer", "LongformerForMultipleChoice"), + ("luke", "LukeForMultipleChoice"), ("megatron-bert", "MegatronBertForMultipleChoice"), ("mobilebert", "MobileBertForMultipleChoice"), ("mpnet", "MPNetForMultipleChoice"), diff --git a/src/transformers/models/luke/__init__.py b/src/transformers/models/luke/__init__.py index 36ca833aaa..42165923b1 100644 --- a/src/transformers/models/luke/__init__.py +++ b/src/transformers/models/luke/__init__.py @@ -37,6 +37,10 @@ else: "LukeForEntityClassification", "LukeForEntityPairClassification", "LukeForEntitySpanClassification", + "LukeForMultipleChoice", + "LukeForQuestionAnswering", + "LukeForSequenceClassification", + "LukeForTokenClassification", "LukeForMaskedLM", "LukeModel", "LukePreTrainedModel", @@ -59,6 +63,10 @@ if TYPE_CHECKING: LukeForEntityPairClassification, LukeForEntitySpanClassification, LukeForMaskedLM, + LukeForMultipleChoice, + LukeForQuestionAnswering, + LukeForSequenceClassification, + LukeForTokenClassification, LukeModel, LukePreTrainedModel, ) diff --git a/src/transformers/models/luke/configuration_luke.py b/src/transformers/models/luke/configuration_luke.py index c5a7a8f581..8f7438cc3c 100644 --- a/src/transformers/models/luke/configuration_luke.py +++ b/src/transformers/models/luke/configuration_luke.py @@ -74,6 +74,8 @@ class LukeConfig(PretrainedConfig): Whether or not the model should use the entity-aware self-attention mechanism proposed in [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention (Yamada et al.)](https://arxiv.org/abs/2010.01057). + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. Examples: @@ -108,6 +110,7 @@ class LukeConfig(PretrainedConfig): initializer_range=0.02, layer_norm_eps=1e-12, use_entity_aware_attention=True, + classifier_dropout=None, pad_token_id=1, bos_token_id=0, eos_token_id=2, @@ -131,3 +134,4 @@ class LukeConfig(PretrainedConfig): self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.use_entity_aware_attention = use_entity_aware_attention + self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index cca68cd535..6d40dfafe8 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -21,6 +21,7 @@ from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling @@ -28,6 +29,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( ModelOutput, + add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging, @@ -247,6 +249,147 @@ class EntitySpanClassificationOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None +@dataclass +class LukeSequenceClassifierOutput(ModelOutput): + """ + Outputs of sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each + layer plus the initial entity embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class LukeTokenClassifierOutput(ModelOutput): + """ + Base class for outputs of token classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) : + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each + layer plus the initial entity embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class LukeQuestionAnsweringModelOutput(ModelOutput): + """ + Outputs of question answering models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each + layer plus the initial entity embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: torch.FloatTensor = None + end_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class LukeMultipleChoiceModelOutput(ModelOutput): + """ + Outputs of multiple choice models. + + Args: + loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided): + Classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + *num_choices* is the second dimension of the input tensors. (see *input_ids* above). + + Classification scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each + layer plus the initial entity embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + class LukeEmbeddings(nn.Module): """ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. @@ -1240,15 +1383,20 @@ class LukeForMaskedLM(LukePreTrainedModel): loss = loss + mep_loss if not return_dict: - output = (logits, entity_logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions) - if mlm_loss is not None and mep_loss is not None: - return (loss, mlm_loss, mep_loss) + output - elif mlm_loss is not None: - return (loss, mlm_loss) + output - elif mep_loss is not None: - return (loss, mep_loss) + output - else: - return output + return tuple( + v + for v in [ + loss, + mlm_loss, + mep_loss, + logits, + entity_logits, + outputs.hidden_states, + outputs.entity_hidden_states, + outputs.attentions, + ] + if v is not None + ) return LukeMaskedLMOutput( loss=loss, @@ -1360,13 +1508,11 @@ class LukeForEntityClassification(LukePreTrainedModel): loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits)) if not return_dict: - output = ( - logits, - outputs.hidden_states, - outputs.entity_hidden_states, - outputs.attentions, + return tuple( + v + for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions] + if v is not None ) - return ((loss,) + output) if loss is not None else output return EntityClassificationOutput( loss=loss, @@ -1480,13 +1626,11 @@ class LukeForEntityPairClassification(LukePreTrainedModel): loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits)) if not return_dict: - output = ( - logits, - outputs.hidden_states, - outputs.entity_hidden_states, - outputs.attentions, + return tuple( + v + for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions] + if v is not None ) - return ((loss,) + output) if loss is not None else output return EntityPairClassificationOutput( loss=loss, @@ -1620,13 +1764,11 @@ class LukeForEntitySpanClassification(LukePreTrainedModel): loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits)) if not return_dict: - output = ( - logits, - outputs.hidden_states, - outputs.entity_hidden_states, - outputs.attentions, + return tuple( + v + for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions] + if v is not None ) - return ((loss,) + output) if loss is not None else output return EntitySpanClassificationOutput( loss=loss, @@ -1635,3 +1777,460 @@ class LukeForEntitySpanClassification(LukePreTrainedModel): entity_hidden_states=outputs.entity_hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings( + """ + The LUKE Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + LUKE_START_DOCSTRING, +) +class LukeForSequenceClassification(LukePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.luke = LukeModel(config) + self.dropout = nn.Dropout( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=LukeSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + entity_ids: Optional[torch.LongTensor] = None, + entity_attention_mask: Optional[torch.FloatTensor] = None, + entity_token_type_ids: Optional[torch.LongTensor] = None, + entity_position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, LukeSequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.luke( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + pooled_output = outputs.pooler_output + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + return tuple( + v + for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions] + if v is not None + ) + + return LukeSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The LUKE Model with a token classification head on top (a linear layer on top of the hidden-states output). To + solve Named-Entity Recognition (NER) task using LUKE, `LukeForEntitySpanClassification` is more suitable than this + class. + """, + LUKE_START_DOCSTRING, +) +class LukeForTokenClassification(LukePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.luke = LukeModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=LukeTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + entity_ids: Optional[torch.LongTensor] = None, + entity_attention_mask: Optional[torch.FloatTensor] = None, + entity_token_type_ids: Optional[torch.LongTensor] = None, + entity_position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, LukeTokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.luke( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + sequence_output = outputs.last_hidden_state + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + return tuple( + v + for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions] + if v is not None + ) + + return LukeTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The LUKE Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LUKE_START_DOCSTRING, +) +class LukeForQuestionAnswering(LukePreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + + self.luke = LukeModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=LukeQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.FloatTensor] = None, + entity_ids: Optional[torch.LongTensor] = None, + entity_attention_mask: Optional[torch.FloatTensor] = None, + entity_token_type_ids: Optional[torch.LongTensor] = None, + entity_position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, LukeQuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.luke( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + sequence_output = outputs.last_hidden_state + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + return tuple( + v + for v in [ + total_loss, + start_logits, + end_logits, + outputs.hidden_states, + outputs.entity_hidden_states, + outputs.attentions, + ] + if v is not None + ) + + return LukeQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The LUKE Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + LUKE_START_DOCSTRING, +) +class LukeForMultipleChoice(LukePreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.luke = LukeModel(config) + self.dropout = nn.Dropout( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=LukeMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + entity_ids: Optional[torch.LongTensor] = None, + entity_attention_mask: Optional[torch.FloatTensor] = None, + entity_token_type_ids: Optional[torch.LongTensor] = None, + entity_position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, LukeMultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + entity_ids = entity_ids.view(-1, entity_ids.size(-1)) if entity_ids is not None else None + entity_attention_mask = ( + entity_attention_mask.view(-1, entity_attention_mask.size(-1)) + if entity_attention_mask is not None + else None + ) + entity_token_type_ids = ( + entity_token_type_ids.view(-1, entity_token_type_ids.size(-1)) + if entity_token_type_ids is not None + else None + ) + entity_position_ids = ( + entity_position_ids.view(-1, entity_position_ids.size(-2), entity_position_ids.size(-1)) + if entity_position_ids is not None + else None + ) + + outputs = self.luke( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + pooled_output = outputs.pooler_output + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + return tuple( + v + for v in [ + loss, + reshaped_logits, + outputs.hidden_states, + outputs.entity_hidden_states, + outputs.attentions, + ] + if v is not None + ) + + return LukeMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index c08610921f..c1dfc6b6b7 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2736,6 +2736,34 @@ class LukeForMaskedLM(metaclass=DummyObject): requires_backends(self, ["torch"]) +class LukeForMultipleChoice(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LukeForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LukeForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LukeForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class LukeModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/luke/test_modeling_luke.py b/tests/models/luke/test_modeling_luke.py index 264b7f8955..789988d5ca 100644 --- a/tests/models/luke/test_modeling_luke.py +++ b/tests/models/luke/test_modeling_luke.py @@ -30,6 +30,10 @@ if is_torch_available(): LukeForEntityPairClassification, LukeForEntitySpanClassification, LukeForMaskedLM, + LukeForMultipleChoice, + LukeForQuestionAnswering, + LukeForSequenceClassification, + LukeForTokenClassification, LukeModel, LukeTokenizer, ) @@ -66,6 +70,8 @@ class LukeModelTester: type_vocab_size=16, type_sequence_label_size=2, initializer_range=0.02, + num_labels=3, + num_choices=4, num_entity_classification_labels=9, num_entity_pair_classification_labels=6, num_entity_span_classification_labels=4, @@ -99,6 +105,8 @@ class LukeModelTester: self.type_vocab_size = type_vocab_size self.type_sequence_label_size = type_sequence_label_size self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices self.num_entity_classification_labels = num_entity_classification_labels self.num_entity_pair_classification_labels = num_entity_pair_classification_labels self.num_entity_span_classification_labels = num_entity_span_classification_labels @@ -139,7 +147,8 @@ class LukeModelTester: ) sequence_labels = None - labels = None + token_labels = None + choice_labels = None entity_labels = None entity_classification_labels = None entity_pair_classification_labels = None @@ -147,7 +156,9 @@ class LukeModelTester: if self.use_labels: sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + entity_labels = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size) entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels) @@ -170,7 +181,8 @@ class LukeModelTester: entity_token_type_ids, entity_position_ids, sequence_labels, - labels, + token_labels, + choice_labels, entity_labels, entity_classification_labels, entity_pair_classification_labels, @@ -207,7 +219,8 @@ class LukeModelTester: entity_token_type_ids, entity_position_ids, sequence_labels, - labels, + token_labels, + choice_labels, entity_labels, entity_classification_labels, entity_pair_classification_labels, @@ -247,7 +260,8 @@ class LukeModelTester: entity_token_type_ids, entity_position_ids, sequence_labels, - labels, + token_labels, + choice_labels, entity_labels, entity_classification_labels, entity_pair_classification_labels, @@ -266,7 +280,7 @@ class LukeModelTester: entity_attention_mask=entity_attention_mask, entity_token_type_ids=entity_token_type_ids, entity_position_ids=entity_position_ids, - labels=labels, + labels=token_labels, entity_labels=entity_labels, ) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) @@ -288,7 +302,8 @@ class LukeModelTester: entity_token_type_ids, entity_position_ids, sequence_labels, - labels, + token_labels, + choice_labels, entity_labels, entity_classification_labels, entity_pair_classification_labels, @@ -322,7 +337,8 @@ class LukeModelTester: entity_token_type_ids, entity_position_ids, sequence_labels, - labels, + token_labels, + choice_labels, entity_labels, entity_classification_labels, entity_pair_classification_labels, @@ -356,7 +372,8 @@ class LukeModelTester: entity_token_type_ids, entity_position_ids, sequence_labels, - labels, + token_labels, + choice_labels, entity_labels, entity_classification_labels, entity_pair_classification_labels, @@ -386,6 +403,156 @@ class LukeModelTester: result.logits.shape, (self.batch_size, self.entity_length, self.num_entity_span_classification_labels) ) + def create_and_check_for_question_answering( + self, + config, + input_ids, + attention_mask, + token_type_ids, + entity_ids, + entity_attention_mask, + entity_token_type_ids, + entity_position_ids, + sequence_labels, + token_labels, + choice_labels, + entity_labels, + entity_classification_labels, + entity_pair_classification_labels, + entity_span_classification_labels, + ): + model = LukeForQuestionAnswering(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + start_positions=sequence_labels, + end_positions=sequence_labels, + ) + self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length)) + self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length)) + + def create_and_check_for_sequence_classification( + self, + config, + input_ids, + attention_mask, + token_type_ids, + entity_ids, + entity_attention_mask, + entity_token_type_ids, + entity_position_ids, + sequence_labels, + token_labels, + choice_labels, + entity_labels, + entity_classification_labels, + entity_pair_classification_labels, + entity_span_classification_labels, + ): + config.num_labels = self.num_labels + model = LukeForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + labels=sequence_labels, + ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def create_and_check_for_token_classification( + self, + config, + input_ids, + attention_mask, + token_type_ids, + entity_ids, + entity_attention_mask, + entity_token_type_ids, + entity_position_ids, + sequence_labels, + token_labels, + choice_labels, + entity_labels, + entity_classification_labels, + entity_pair_classification_labels, + entity_span_classification_labels, + ): + config.num_labels = self.num_labels + model = LukeForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + labels=token_labels, + ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) + + def create_and_check_for_multiple_choice( + self, + config, + input_ids, + attention_mask, + token_type_ids, + entity_ids, + entity_attention_mask, + entity_token_type_ids, + entity_position_ids, + sequence_labels, + token_labels, + choice_labels, + entity_labels, + entity_classification_labels, + entity_pair_classification_labels, + entity_span_classification_labels, + ): + config.num_choices = self.num_choices + model = LukeForMultipleChoice(config=config) + model.to(torch_device) + model.eval() + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_attention_mask = attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_entity_ids = entity_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_entity_token_type_ids = ( + entity_token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + ) + multiple_choice_entity_attention_mask = ( + entity_attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + ) + multiple_choice_entity_position_ids = ( + entity_position_ids.unsqueeze(1).expand(-1, self.num_choices, -1, -1).contiguous() + ) + result = model( + multiple_choice_inputs_ids, + attention_mask=multiple_choice_attention_mask, + token_type_ids=multiple_choice_token_type_ids, + entity_ids=multiple_choice_entity_ids, + entity_attention_mask=multiple_choice_entity_attention_mask, + entity_token_type_ids=multiple_choice_entity_token_type_ids, + entity_position_ids=multiple_choice_entity_position_ids, + labels=choice_labels, + ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -398,7 +565,8 @@ class LukeModelTester: entity_token_type_ids, entity_position_ids, sequence_labels, - labels, + token_labels, + choice_labels, entity_labels, entity_classification_labels, entity_pair_classification_labels, @@ -426,6 +594,10 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase): LukeForEntityClassification, LukeForEntityPairClassification, LukeForEntitySpanClassification, + LukeForQuestionAnswering, + LukeForSequenceClassification, + LukeForTokenClassification, + LukeForMultipleChoice, ) if is_torch_available() else () @@ -436,7 +608,19 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase): test_head_masking = True def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + entity_inputs_dict = {k: v for k, v in inputs_dict.items() if k.startswith("entity")} + inputs_dict = {k: v for k, v in inputs_dict.items() if not k.startswith("entity")} + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + if model_class == LukeForMultipleChoice: + entity_inputs_dict = { + k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous() + if v.ndim == 2 + else v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1, -1).contiguous() + for k, v in entity_inputs_dict.items() + } + inputs_dict.update(entity_inputs_dict) + if model_class == LukeForEntitySpanClassification: inputs_dict["entity_start_positions"] = torch.zeros( (self.model_tester.batch_size, self.model_tester.entity_length), dtype=torch.long, device=torch_device @@ -446,7 +630,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase): ) if return_labels: - if model_class in (LukeForEntityClassification, LukeForEntityPairClassification): + if model_class in ( + LukeForEntityClassification, + LukeForEntityPairClassification, + LukeForSequenceClassification, + LukeForMultipleChoice, + ): inputs_dict["labels"] = torch.zeros( self.model_tester.batch_size, dtype=torch.long, device=torch_device ) @@ -456,6 +645,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase): dtype=torch.long, device=torch_device, ) + elif model_class == LukeForTokenClassification: + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), + dtype=torch.long, + device=torch_device, + ) elif model_class == LukeForMaskedLM: inputs_dict["labels"] = torch.zeros( (self.model_tester.batch_size, self.model_tester.seq_length), @@ -496,6 +691,22 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = (*config_and_inputs[:4], *((None,) * len(config_and_inputs[4:]))) self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) + def test_for_entity_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)