Adding fine-tuning models to LUKE (#18353)
* add LUKE models for downstream tasks * add new LUKE models to docs * fix typos * remove commented lines * exclude None items from tuple return values
This commit is contained in:
@@ -152,3 +152,23 @@ This model was contributed by [ikuyamada](https://huggingface.co/ikuyamada) and
|
||||
|
||||
[[autodoc]] LukeForEntitySpanClassification
|
||||
- forward
|
||||
|
||||
## LukeForSequenceClassification
|
||||
|
||||
[[autodoc]] LukeForSequenceClassification
|
||||
- forward
|
||||
|
||||
## LukeForMultipleChoice
|
||||
|
||||
[[autodoc]] LukeForMultipleChoice
|
||||
- forward
|
||||
|
||||
## LukeForTokenClassification
|
||||
|
||||
[[autodoc]] LukeForTokenClassification
|
||||
- forward
|
||||
|
||||
## LukeForQuestionAnswering
|
||||
|
||||
[[autodoc]] LukeForQuestionAnswering
|
||||
- forward
|
||||
|
||||
@@ -1363,6 +1363,10 @@ else:
|
||||
"LukeForEntityClassification",
|
||||
"LukeForEntityPairClassification",
|
||||
"LukeForEntitySpanClassification",
|
||||
"LukeForMultipleChoice",
|
||||
"LukeForQuestionAnswering",
|
||||
"LukeForSequenceClassification",
|
||||
"LukeForTokenClassification",
|
||||
"LukeForMaskedLM",
|
||||
"LukeModel",
|
||||
"LukePreTrainedModel",
|
||||
@@ -3953,6 +3957,10 @@ if TYPE_CHECKING:
|
||||
LukeForEntityPairClassification,
|
||||
LukeForEntitySpanClassification,
|
||||
LukeForMaskedLM,
|
||||
LukeForMultipleChoice,
|
||||
LukeForQuestionAnswering,
|
||||
LukeForSequenceClassification,
|
||||
LukeForTokenClassification,
|
||||
LukeModel,
|
||||
LukePreTrainedModel,
|
||||
)
|
||||
|
||||
@@ -170,6 +170,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
||||
("ibert", "IBertForMaskedLM"),
|
||||
("layoutlm", "LayoutLMForMaskedLM"),
|
||||
("longformer", "LongformerForMaskedLM"),
|
||||
("luke", "LukeForMaskedLM"),
|
||||
("lxmert", "LxmertForPreTraining"),
|
||||
("megatron-bert", "MegatronBertForPreTraining"),
|
||||
("mobilebert", "MobileBertForPreTraining"),
|
||||
@@ -230,6 +231,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||
("led", "LEDForConditionalGeneration"),
|
||||
("longformer", "LongformerForMaskedLM"),
|
||||
("longt5", "LongT5ForConditionalGeneration"),
|
||||
("luke", "LukeForMaskedLM"),
|
||||
("m2m_100", "M2M100ForConditionalGeneration"),
|
||||
("marian", "MarianMTModel"),
|
||||
("megatron-bert", "MegatronBertForCausalLM"),
|
||||
@@ -499,6 +501,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
|
||||
("led", "LEDForSequenceClassification"),
|
||||
("longformer", "LongformerForSequenceClassification"),
|
||||
("luke", "LukeForSequenceClassification"),
|
||||
("mbart", "MBartForSequenceClassification"),
|
||||
("megatron-bert", "MegatronBertForSequenceClassification"),
|
||||
("mobilebert", "MobileBertForSequenceClassification"),
|
||||
@@ -551,6 +554,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
|
||||
("led", "LEDForQuestionAnswering"),
|
||||
("longformer", "LongformerForQuestionAnswering"),
|
||||
("luke", "LukeForQuestionAnswering"),
|
||||
("lxmert", "LxmertForQuestionAnswering"),
|
||||
("mbart", "MBartForQuestionAnswering"),
|
||||
("megatron-bert", "MegatronBertForQuestionAnswering"),
|
||||
@@ -611,6 +615,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
|
||||
("layoutlmv3", "LayoutLMv3ForTokenClassification"),
|
||||
("longformer", "LongformerForTokenClassification"),
|
||||
("luke", "LukeForTokenClassification"),
|
||||
("megatron-bert", "MegatronBertForTokenClassification"),
|
||||
("mobilebert", "MobileBertForTokenClassification"),
|
||||
("mpnet", "MPNetForTokenClassification"),
|
||||
@@ -647,6 +652,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
|
||||
("funnel", "FunnelForMultipleChoice"),
|
||||
("ibert", "IBertForMultipleChoice"),
|
||||
("longformer", "LongformerForMultipleChoice"),
|
||||
("luke", "LukeForMultipleChoice"),
|
||||
("megatron-bert", "MegatronBertForMultipleChoice"),
|
||||
("mobilebert", "MobileBertForMultipleChoice"),
|
||||
("mpnet", "MPNetForMultipleChoice"),
|
||||
|
||||
@@ -37,6 +37,10 @@ else:
|
||||
"LukeForEntityClassification",
|
||||
"LukeForEntityPairClassification",
|
||||
"LukeForEntitySpanClassification",
|
||||
"LukeForMultipleChoice",
|
||||
"LukeForQuestionAnswering",
|
||||
"LukeForSequenceClassification",
|
||||
"LukeForTokenClassification",
|
||||
"LukeForMaskedLM",
|
||||
"LukeModel",
|
||||
"LukePreTrainedModel",
|
||||
@@ -59,6 +63,10 @@ if TYPE_CHECKING:
|
||||
LukeForEntityPairClassification,
|
||||
LukeForEntitySpanClassification,
|
||||
LukeForMaskedLM,
|
||||
LukeForMultipleChoice,
|
||||
LukeForQuestionAnswering,
|
||||
LukeForSequenceClassification,
|
||||
LukeForTokenClassification,
|
||||
LukeModel,
|
||||
LukePreTrainedModel,
|
||||
)
|
||||
|
||||
@@ -74,6 +74,8 @@ class LukeConfig(PretrainedConfig):
|
||||
Whether or not the model should use the entity-aware self-attention mechanism proposed in [LUKE: Deep
|
||||
Contextualized Entity Representations with Entity-aware Self-attention (Yamada et
|
||||
al.)](https://arxiv.org/abs/2010.01057).
|
||||
classifier_dropout (`float`, *optional*):
|
||||
The dropout ratio for the classification head.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -108,6 +110,7 @@ class LukeConfig(PretrainedConfig):
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
use_entity_aware_attention=True,
|
||||
classifier_dropout=None,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
@@ -131,3 +134,4 @@ class LukeConfig(PretrainedConfig):
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.use_entity_aware_attention = use_entity_aware_attention
|
||||
self.classifier_dropout = classifier_dropout
|
||||
|
||||
@@ -21,6 +21,7 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN, gelu
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
@@ -28,6 +29,7 @@ from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
@@ -247,6 +249,147 @@ class EntitySpanClassificationOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LukeSequenceClassifierOutput(ModelOutput):
|
||||
"""
|
||||
Outputs of sentence classification models.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
|
||||
layer plus the initial entity embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LukeTokenClassifierOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of token classification models.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) :
|
||||
Classification loss.
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
|
||||
Classification scores (before SoftMax).
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
|
||||
layer plus the initial entity embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LukeQuestionAnsweringModelOutput(ModelOutput):
|
||||
"""
|
||||
Outputs of question answering models.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
||||
start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
||||
Span-start scores (before SoftMax).
|
||||
end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
||||
Span-end scores (before SoftMax).
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
|
||||
layer plus the initial entity embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
start_logits: torch.FloatTensor = None
|
||||
end_logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LukeMultipleChoiceModelOutput(ModelOutput):
|
||||
"""
|
||||
Outputs of multiple choice models.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided):
|
||||
Classification loss.
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
|
||||
*num_choices* is the second dimension of the input tensors. (see *input_ids* above).
|
||||
|
||||
Classification scores (before SoftMax).
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
|
||||
layer plus the initial entity embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
class LukeEmbeddings(nn.Module):
|
||||
"""
|
||||
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
|
||||
@@ -1240,15 +1383,20 @@ class LukeForMaskedLM(LukePreTrainedModel):
|
||||
loss = loss + mep_loss
|
||||
|
||||
if not return_dict:
|
||||
output = (logits, entity_logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions)
|
||||
if mlm_loss is not None and mep_loss is not None:
|
||||
return (loss, mlm_loss, mep_loss) + output
|
||||
elif mlm_loss is not None:
|
||||
return (loss, mlm_loss) + output
|
||||
elif mep_loss is not None:
|
||||
return (loss, mep_loss) + output
|
||||
else:
|
||||
return output
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
loss,
|
||||
mlm_loss,
|
||||
mep_loss,
|
||||
logits,
|
||||
entity_logits,
|
||||
outputs.hidden_states,
|
||||
outputs.entity_hidden_states,
|
||||
outputs.attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return LukeMaskedLMOutput(
|
||||
loss=loss,
|
||||
@@ -1360,13 +1508,11 @@ class LukeForEntityClassification(LukePreTrainedModel):
|
||||
loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
|
||||
|
||||
if not return_dict:
|
||||
output = (
|
||||
logits,
|
||||
outputs.hidden_states,
|
||||
outputs.entity_hidden_states,
|
||||
outputs.attentions,
|
||||
return tuple(
|
||||
v
|
||||
for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
|
||||
if v is not None
|
||||
)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return EntityClassificationOutput(
|
||||
loss=loss,
|
||||
@@ -1480,13 +1626,11 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
|
||||
loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
|
||||
|
||||
if not return_dict:
|
||||
output = (
|
||||
logits,
|
||||
outputs.hidden_states,
|
||||
outputs.entity_hidden_states,
|
||||
outputs.attentions,
|
||||
return tuple(
|
||||
v
|
||||
for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
|
||||
if v is not None
|
||||
)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return EntityPairClassificationOutput(
|
||||
loss=loss,
|
||||
@@ -1620,13 +1764,11 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
|
||||
loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
|
||||
|
||||
if not return_dict:
|
||||
output = (
|
||||
logits,
|
||||
outputs.hidden_states,
|
||||
outputs.entity_hidden_states,
|
||||
outputs.attentions,
|
||||
return tuple(
|
||||
v
|
||||
for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
|
||||
if v is not None
|
||||
)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return EntitySpanClassificationOutput(
|
||||
loss=loss,
|
||||
@@ -1635,3 +1777,460 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
|
||||
entity_hidden_states=outputs.entity_hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The LUKE Model transformer with a sequence classification/regression head on top (a linear layer on top of the
|
||||
pooled output) e.g. for GLUE tasks.
|
||||
""",
|
||||
LUKE_START_DOCSTRING,
|
||||
)
|
||||
class LukeForSequenceClassification(LukePreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.luke = LukeModel(config)
|
||||
self.dropout = nn.Dropout(
|
||||
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
||||
)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=LukeSequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
entity_ids: Optional[torch.LongTensor] = None,
|
||||
entity_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
entity_token_type_ids: Optional[torch.LongTensor] = None,
|
||||
entity_position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, LukeSequenceClassifierOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.luke(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
entity_ids=entity_ids,
|
||||
entity_attention_mask=entity_attention_mask,
|
||||
entity_token_type_ids=entity_token_type_ids,
|
||||
entity_position_ids=entity_position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
pooled_output = outputs.pooler_output
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return LukeSequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
entity_hidden_states=outputs.entity_hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The LUKE Model with a token classification head on top (a linear layer on top of the hidden-states output). To
|
||||
solve Named-Entity Recognition (NER) task using LUKE, `LukeForEntitySpanClassification` is more suitable than this
|
||||
class.
|
||||
""",
|
||||
LUKE_START_DOCSTRING,
|
||||
)
|
||||
class LukeForTokenClassification(LukePreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.luke = LukeModel(config, add_pooling_layer=False)
|
||||
self.dropout = nn.Dropout(
|
||||
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
||||
)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=LukeTokenClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
entity_ids: Optional[torch.LongTensor] = None,
|
||||
entity_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
entity_token_type_ids: Optional[torch.LongTensor] = None,
|
||||
entity_position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, LukeTokenClassifierOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
|
||||
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
|
||||
`input_ids` above)
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.luke(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
entity_ids=entity_ids,
|
||||
entity_attention_mask=entity_attention_mask,
|
||||
entity_token_type_ids=entity_token_type_ids,
|
||||
entity_position_ids=entity_position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs.last_hidden_state
|
||||
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return LukeTokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
entity_hidden_states=outputs.entity_hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The LUKE Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
||||
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
||||
""",
|
||||
LUKE_START_DOCSTRING,
|
||||
)
|
||||
class LukeForQuestionAnswering(LukePreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.luke = LukeModel(config, add_pooling_layer=False)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=LukeQuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.FloatTensor] = None,
|
||||
entity_ids: Optional[torch.LongTensor] = None,
|
||||
entity_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
entity_token_type_ids: Optional[torch.LongTensor] = None,
|
||||
entity_position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
start_positions: Optional[torch.LongTensor] = None,
|
||||
end_positions: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, LukeQuestionAnsweringModelOutput]:
|
||||
r"""
|
||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||
are not taken into account for computing the loss.
|
||||
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||
are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.luke(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
entity_ids=entity_ids,
|
||||
entity_attention_mask=entity_attention_mask,
|
||||
entity_token_type_ids=entity_token_type_ids,
|
||||
entity_position_ids=entity_position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs.last_hidden_state
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
total_loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, split add a dimension
|
||||
if len(start_positions.size()) > 1:
|
||||
start_positions = start_positions.squeeze(-1)
|
||||
if len(end_positions.size()) > 1:
|
||||
end_positions = end_positions.squeeze(-1)
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions.clamp_(0, ignored_index)
|
||||
end_positions.clamp_(0, ignored_index)
|
||||
|
||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
total_loss,
|
||||
start_logits,
|
||||
end_logits,
|
||||
outputs.hidden_states,
|
||||
outputs.entity_hidden_states,
|
||||
outputs.attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return LukeQuestionAnsweringModelOutput(
|
||||
loss=total_loss,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
entity_hidden_states=outputs.entity_hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The LUKE Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
|
||||
softmax) e.g. for RocStories/SWAG tasks.
|
||||
""",
|
||||
LUKE_START_DOCSTRING,
|
||||
)
|
||||
class LukeForMultipleChoice(LukePreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.luke = LukeModel(config)
|
||||
self.dropout = nn.Dropout(
|
||||
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
||||
)
|
||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=LukeMultipleChoiceModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
entity_ids: Optional[torch.LongTensor] = None,
|
||||
entity_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
entity_token_type_ids: Optional[torch.LongTensor] = None,
|
||||
entity_position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, LukeMultipleChoiceModelOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
|
||||
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
|
||||
`input_ids` above)
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
||||
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
||||
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||
inputs_embeds = (
|
||||
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
||||
if inputs_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
entity_ids = entity_ids.view(-1, entity_ids.size(-1)) if entity_ids is not None else None
|
||||
entity_attention_mask = (
|
||||
entity_attention_mask.view(-1, entity_attention_mask.size(-1))
|
||||
if entity_attention_mask is not None
|
||||
else None
|
||||
)
|
||||
entity_token_type_ids = (
|
||||
entity_token_type_ids.view(-1, entity_token_type_ids.size(-1))
|
||||
if entity_token_type_ids is not None
|
||||
else None
|
||||
)
|
||||
entity_position_ids = (
|
||||
entity_position_ids.view(-1, entity_position_ids.size(-2), entity_position_ids.size(-1))
|
||||
if entity_position_ids is not None
|
||||
else None
|
||||
)
|
||||
|
||||
outputs = self.luke(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
entity_ids=entity_ids,
|
||||
entity_attention_mask=entity_attention_mask,
|
||||
entity_token_type_ids=entity_token_type_ids,
|
||||
entity_position_ids=entity_position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
pooled_output = outputs.pooler_output
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = logits.view(-1, num_choices)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(reshaped_logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
loss,
|
||||
reshaped_logits,
|
||||
outputs.hidden_states,
|
||||
outputs.entity_hidden_states,
|
||||
outputs.attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return LukeMultipleChoiceModelOutput(
|
||||
loss=loss,
|
||||
logits=reshaped_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
entity_hidden_states=outputs.entity_hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
@@ -2736,6 +2736,34 @@ class LukeForMaskedLM(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LukeForMultipleChoice(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LukeForQuestionAnswering(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LukeForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LukeForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LukeModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -30,6 +30,10 @@ if is_torch_available():
|
||||
LukeForEntityPairClassification,
|
||||
LukeForEntitySpanClassification,
|
||||
LukeForMaskedLM,
|
||||
LukeForMultipleChoice,
|
||||
LukeForQuestionAnswering,
|
||||
LukeForSequenceClassification,
|
||||
LukeForTokenClassification,
|
||||
LukeModel,
|
||||
LukeTokenizer,
|
||||
)
|
||||
@@ -66,6 +70,8 @@ class LukeModelTester:
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
num_entity_classification_labels=9,
|
||||
num_entity_pair_classification_labels=6,
|
||||
num_entity_span_classification_labels=4,
|
||||
@@ -99,6 +105,8 @@ class LukeModelTester:
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.num_entity_classification_labels = num_entity_classification_labels
|
||||
self.num_entity_pair_classification_labels = num_entity_pair_classification_labels
|
||||
self.num_entity_span_classification_labels = num_entity_span_classification_labels
|
||||
@@ -139,7 +147,8 @@ class LukeModelTester:
|
||||
)
|
||||
|
||||
sequence_labels = None
|
||||
labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
entity_labels = None
|
||||
entity_classification_labels = None
|
||||
entity_pair_classification_labels = None
|
||||
@@ -147,7 +156,9 @@ class LukeModelTester:
|
||||
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
entity_labels = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size)
|
||||
|
||||
entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels)
|
||||
@@ -170,7 +181,8 @@ class LukeModelTester:
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
@@ -207,7 +219,8 @@ class LukeModelTester:
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
@@ -247,7 +260,8 @@ class LukeModelTester:
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
@@ -266,7 +280,7 @@ class LukeModelTester:
|
||||
entity_attention_mask=entity_attention_mask,
|
||||
entity_token_type_ids=entity_token_type_ids,
|
||||
entity_position_ids=entity_position_ids,
|
||||
labels=labels,
|
||||
labels=token_labels,
|
||||
entity_labels=entity_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
@@ -288,7 +302,8 @@ class LukeModelTester:
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
@@ -322,7 +337,8 @@ class LukeModelTester:
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
@@ -356,7 +372,8 @@ class LukeModelTester:
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
@@ -386,6 +403,156 @@ class LukeModelTester:
|
||||
result.logits.shape, (self.batch_size, self.entity_length, self.num_entity_span_classification_labels)
|
||||
)
|
||||
|
||||
def create_and_check_for_question_answering(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
entity_ids,
|
||||
entity_attention_mask,
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
entity_span_classification_labels,
|
||||
):
|
||||
model = LukeForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
entity_ids=entity_ids,
|
||||
entity_attention_mask=entity_attention_mask,
|
||||
entity_token_type_ids=entity_token_type_ids,
|
||||
entity_position_ids=entity_position_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def create_and_check_for_sequence_classification(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
entity_ids,
|
||||
entity_attention_mask,
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
entity_span_classification_labels,
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = LukeForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
entity_ids=entity_ids,
|
||||
entity_attention_mask=entity_attention_mask,
|
||||
entity_token_type_ids=entity_token_type_ids,
|
||||
entity_position_ids=entity_position_ids,
|
||||
labels=sequence_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_for_token_classification(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
entity_ids,
|
||||
entity_attention_mask,
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
entity_span_classification_labels,
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = LukeForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
entity_ids=entity_ids,
|
||||
entity_attention_mask=entity_attention_mask,
|
||||
entity_token_type_ids=entity_token_type_ids,
|
||||
entity_position_ids=entity_position_ids,
|
||||
labels=token_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_for_multiple_choice(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
entity_ids,
|
||||
entity_attention_mask,
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
entity_span_classification_labels,
|
||||
):
|
||||
config.num_choices = self.num_choices
|
||||
model = LukeForMultipleChoice(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_attention_mask = attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_entity_ids = entity_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_entity_token_type_ids = (
|
||||
entity_token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
)
|
||||
multiple_choice_entity_attention_mask = (
|
||||
entity_attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
)
|
||||
multiple_choice_entity_position_ids = (
|
||||
entity_position_ids.unsqueeze(1).expand(-1, self.num_choices, -1, -1).contiguous()
|
||||
)
|
||||
result = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_attention_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
entity_ids=multiple_choice_entity_ids,
|
||||
entity_attention_mask=multiple_choice_entity_attention_mask,
|
||||
entity_token_type_ids=multiple_choice_entity_token_type_ids,
|
||||
entity_position_ids=multiple_choice_entity_position_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -398,7 +565,8 @@ class LukeModelTester:
|
||||
entity_token_type_ids,
|
||||
entity_position_ids,
|
||||
sequence_labels,
|
||||
labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
entity_labels,
|
||||
entity_classification_labels,
|
||||
entity_pair_classification_labels,
|
||||
@@ -426,6 +594,10 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
LukeForEntityClassification,
|
||||
LukeForEntityPairClassification,
|
||||
LukeForEntitySpanClassification,
|
||||
LukeForQuestionAnswering,
|
||||
LukeForSequenceClassification,
|
||||
LukeForTokenClassification,
|
||||
LukeForMultipleChoice,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
@@ -436,7 +608,19 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_head_masking = True
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
entity_inputs_dict = {k: v for k, v in inputs_dict.items() if k.startswith("entity")}
|
||||
inputs_dict = {k: v for k, v in inputs_dict.items() if not k.startswith("entity")}
|
||||
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
if model_class == LukeForMultipleChoice:
|
||||
entity_inputs_dict = {
|
||||
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
||||
if v.ndim == 2
|
||||
else v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1, -1).contiguous()
|
||||
for k, v in entity_inputs_dict.items()
|
||||
}
|
||||
inputs_dict.update(entity_inputs_dict)
|
||||
|
||||
if model_class == LukeForEntitySpanClassification:
|
||||
inputs_dict["entity_start_positions"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.entity_length), dtype=torch.long, device=torch_device
|
||||
@@ -446,7 +630,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
if return_labels:
|
||||
if model_class in (LukeForEntityClassification, LukeForEntityPairClassification):
|
||||
if model_class in (
|
||||
LukeForEntityClassification,
|
||||
LukeForEntityPairClassification,
|
||||
LukeForSequenceClassification,
|
||||
LukeForMultipleChoice,
|
||||
):
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
@@ -456,6 +645,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
elif model_class == LukeForTokenClassification:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
elif model_class == LukeForMaskedLM:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length),
|
||||
@@ -496,6 +691,22 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = (*config_and_inputs[:4], *((None,) * len(config_and_inputs[4:])))
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
def test_for_entity_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user