Tf longformer for sequence classification (#8231)
* working on LongformerForSequenceClassification * add TFLongformerForMultipleChoice * add TFLongformerForTokenClassification * use add_start_docstrings_to_model_forward * test TFLongformerForSequenceClassification * test TFLongformerForMultipleChoice * test TFLongformerForTokenClassification * remove test from repo * add test and doc for TFLongformerForSequenceClassification, TFLongformerForTokenClassification, TFLongformerForMultipleChoice * add requested classes to modeling_tf_auto.py update dummy_tf_objects fix tests fix bugs in requested classes * pass all tests except test_inputs_embeds * sync with master * pass all tests except test_inputs_embeds * pass all tests * pass all tests * work on test_inputs_embeds * fix style and quality * make multi choice work * fix TFLongformerForTokenClassification signature * fix TFLongformerForMultipleChoice, TFLongformerForSequenceClassification signature * fix mult choice * fix mc hint * fix input embeds * fix input embeds * refactor input embeds * fix copy issue * apply sylvains changes and clean more Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -99,21 +99,41 @@ Longformer specific outputs
|
||||
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerBaseModelOutputWithPooling
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerMultipleChoiceModelOutput
|
||||
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerMaskedLMOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerQuestionAnsweringModelOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerSequenceClassifierOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerMultipleChoiceModelOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerTokenClassifierOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutputWithPooling
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerMaskedLMOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerQuestionAnsweringModelOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerSequenceClassifierOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerMultipleChoiceModelOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerTokenClassifierOutput
|
||||
:members:
|
||||
|
||||
LongformerModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@@ -177,3 +197,24 @@ TFLongformerForQuestionAnswering
|
||||
.. autoclass:: transformers.TFLongformerForQuestionAnswering
|
||||
:members: call
|
||||
|
||||
|
||||
TFLongformerForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFLongformerForSequenceClassification
|
||||
:members: call
|
||||
|
||||
|
||||
TFLongformerForTokenClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFLongformerForTokenClassification
|
||||
:members: call
|
||||
|
||||
|
||||
TFLongformerForMultipleChoice
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFLongformerForMultipleChoice
|
||||
:members: call
|
||||
|
||||
|
||||
@@ -766,7 +766,10 @@ if is_tf_available():
|
||||
from .models.longformer import (
|
||||
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFLongformerForMaskedLM,
|
||||
TFLongformerForMultipleChoice,
|
||||
TFLongformerForQuestionAnswering,
|
||||
TFLongformerForSequenceClassification,
|
||||
TFLongformerForTokenClassification,
|
||||
TFLongformerModel,
|
||||
TFLongformerSelfAttention,
|
||||
)
|
||||
|
||||
@@ -92,7 +92,10 @@ from ..funnel.modeling_tf_funnel import (
|
||||
from ..gpt2.modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
|
||||
from ..longformer.modeling_tf_longformer import (
|
||||
TFLongformerForMaskedLM,
|
||||
TFLongformerForMultipleChoice,
|
||||
TFLongformerForQuestionAnswering,
|
||||
TFLongformerForSequenceClassification,
|
||||
TFLongformerForTokenClassification,
|
||||
TFLongformerModel,
|
||||
)
|
||||
from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
|
||||
@@ -314,6 +317,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(AlbertConfig, TFAlbertForSequenceClassification),
|
||||
(CamembertConfig, TFCamembertForSequenceClassification),
|
||||
(XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
|
||||
(LongformerConfig, TFLongformerForSequenceClassification),
|
||||
(RobertaConfig, TFRobertaForSequenceClassification),
|
||||
(BertConfig, TFBertForSequenceClassification),
|
||||
(XLNetConfig, TFXLNetForSequenceClassification),
|
||||
@@ -353,6 +357,7 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(FlaubertConfig, TFFlaubertForTokenClassification),
|
||||
(XLMConfig, TFXLMForTokenClassification),
|
||||
(XLMRobertaConfig, TFXLMRobertaForTokenClassification),
|
||||
(LongformerConfig, TFLongformerForTokenClassification),
|
||||
(RobertaConfig, TFRobertaForTokenClassification),
|
||||
(BertConfig, TFBertForTokenClassification),
|
||||
(MobileBertConfig, TFMobileBertForTokenClassification),
|
||||
@@ -368,6 +373,7 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
||||
(CamembertConfig, TFCamembertForMultipleChoice),
|
||||
(XLMConfig, TFXLMForMultipleChoice),
|
||||
(XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
|
||||
(LongformerConfig, TFLongformerForMultipleChoice),
|
||||
(RobertaConfig, TFRobertaForMultipleChoice),
|
||||
(BertConfig, TFBertForMultipleChoice),
|
||||
(DistilBertConfig, TFDistilBertForMultipleChoice),
|
||||
|
||||
@@ -26,7 +26,10 @@ if is_tf_available():
|
||||
from .modeling_tf_longformer import (
|
||||
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFLongformerForMaskedLM,
|
||||
TFLongformerForMultipleChoice,
|
||||
TFLongformerForQuestionAnswering,
|
||||
TFLongformerForSequenceClassification,
|
||||
TFLongformerForTokenClassification,
|
||||
TFLongformerModel,
|
||||
TFLongformerSelfAttention,
|
||||
)
|
||||
|
||||
@@ -31,7 +31,6 @@ from ...file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
|
||||
from ...modeling_utils import (
|
||||
PreTrainedModel,
|
||||
apply_chunking_to_forward,
|
||||
@@ -151,17 +150,15 @@ class LongformerBaseModelOutputWithPooling(ModelOutput):
|
||||
|
||||
|
||||
@dataclass
|
||||
class LongformerMultipleChoiceModelOutput(ModelOutput):
|
||||
class LongformerMaskedLMOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of multiple choice Longformer models.
|
||||
Base class for masked language models outputs.
|
||||
|
||||
Args:
|
||||
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Classification loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
||||
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
|
||||
|
||||
Classification scores (before SoftMax).
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Masked language modeling (MLM) loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
@@ -249,6 +246,149 @@ class LongformerQuestionAnsweringModelOutput(ModelOutput):
|
||||
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LongformerSequenceClassifierOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of sentence classification models.
|
||||
|
||||
Args:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
|
||||
mask.
|
||||
|
||||
Local attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token in the sequence to every token with
|
||||
global attention (first ``x`` values) and to every token in the attention window (remaining
|
||||
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
|
||||
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
|
||||
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
|
||||
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
|
||||
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
|
||||
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
|
||||
attention weights. If a token has global attention, the attention weights to all other tokens in
|
||||
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
|
||||
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
|
||||
|
||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LongformerMultipleChoiceModelOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of multiple choice Longformer models.
|
||||
|
||||
Args:
|
||||
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Classification loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
||||
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
|
||||
|
||||
Classification scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
|
||||
mask.
|
||||
|
||||
Local attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token in the sequence to every token with
|
||||
global attention (first ``x`` values) and to every token in the attention window (remaining
|
||||
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
|
||||
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
|
||||
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
|
||||
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
|
||||
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
|
||||
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
|
||||
attention weights. If a token has global attention, the attention weights to all other tokens in
|
||||
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
|
||||
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
|
||||
|
||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LongformerTokenClassifierOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of token classification models.
|
||||
|
||||
Args:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
|
||||
Classification loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
|
||||
Classification scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
|
||||
mask.
|
||||
|
||||
Local attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token in the sequence to every token with
|
||||
global attention (first ``x`` values) and to every token in the attention window (remaining
|
||||
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
|
||||
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
|
||||
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
|
||||
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
|
||||
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
|
||||
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
|
||||
attention weights. If a token has global attention, the attention weights to all other tokens in
|
||||
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
|
||||
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
|
||||
|
||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
def _get_question_end_index(input_ids, sep_token_id):
|
||||
"""
|
||||
Computes the index of the first occurance of `sep_token_id`.
|
||||
@@ -1495,7 +1635,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
|
||||
return self.lm_head.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@replace_return_docstrings(output_type=LongformerMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -1561,7 +1701,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||||
|
||||
return MaskedLMOutput(
|
||||
return LongformerMaskedLMOutput(
|
||||
loss=masked_lm_loss,
|
||||
logits=prediction_scores,
|
||||
hidden_states=outputs.hidden_states,
|
||||
@@ -1593,7 +1733,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="allenai/longformer-base-4096",
|
||||
output_type=SequenceClassifierOutput,
|
||||
output_type=LongformerSequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
@@ -1651,7 +1791,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
return LongformerSequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
@@ -1837,7 +1977,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="allenai/longformer-base-4096",
|
||||
output_type=TokenClassifierOutput,
|
||||
output_type=LongformerTokenClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
@@ -1895,7 +2035,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
return LongformerTokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
|
||||
@@ -19,19 +19,21 @@ from typing import Optional, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers.activations_tf import get_tf_activation
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...file_utils import (
|
||||
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
||||
ModelOutput,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
)
|
||||
from ...modeling_tf_outputs import TFMaskedLMOutput, TFQuestionAnsweringModelOutput
|
||||
from ...modeling_tf_utils import (
|
||||
TFMaskedLanguageModelingLoss,
|
||||
TFMultipleChoiceLoss,
|
||||
TFPreTrainedModel,
|
||||
TFQuestionAnsweringLoss,
|
||||
TFSequenceClassificationLoss,
|
||||
TFTokenClassificationLoss,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
@@ -147,6 +149,52 @@ class TFLongformerBaseModelOutputWithPooling(ModelOutput):
|
||||
global_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TFLongformerMaskedLMOutput(ModelOutput):
|
||||
"""
|
||||
Base class for masked language models outputs.
|
||||
|
||||
Args:
|
||||
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Masked language modeling (MLM) loss.
|
||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x +
|
||||
attention_window + 1)`, where ``x`` is the number of tokens with global attention mask.
|
||||
|
||||
Local attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token in the sequence to every token with
|
||||
global attention (first ``x`` values) and to every token in the attention window (remaining
|
||||
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
|
||||
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
|
||||
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
|
||||
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
|
||||
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
|
||||
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
|
||||
attention weights. If a token has global attention, the attention weights to all other tokens in
|
||||
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
|
||||
global_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x)`,
|
||||
where ``x`` is the number of tokens with global attention mask.
|
||||
|
||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
"""
|
||||
|
||||
loss: Optional[tf.Tensor] = None
|
||||
logits: tf.Tensor = None
|
||||
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
global_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TFLongformerQuestionAnsweringModelOutput(ModelOutput):
|
||||
"""
|
||||
@@ -196,6 +244,146 @@ class TFLongformerQuestionAnsweringModelOutput(ModelOutput):
|
||||
global_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TFLongformerSequenceClassifierOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of sentence classification models.
|
||||
|
||||
Args:
|
||||
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x +
|
||||
attention_window + 1)`, where ``x`` is the number of tokens with global attention mask.
|
||||
|
||||
Local attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token in the sequence to every token with
|
||||
global attention (first ``x`` values) and to every token in the attention window (remaining
|
||||
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
|
||||
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
|
||||
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
|
||||
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
|
||||
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
|
||||
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
|
||||
attention weights. If a token has global attention, the attention weights to all other tokens in
|
||||
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
|
||||
global_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x)`,
|
||||
where ``x`` is the number of tokens with global attention mask.
|
||||
|
||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
"""
|
||||
|
||||
loss: Optional[tf.Tensor] = None
|
||||
logits: tf.Tensor = None
|
||||
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
global_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TFLongformerMultipleChoiceModelOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of multiple choice models.
|
||||
|
||||
Args:
|
||||
loss (:obj:`tf.Tensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Classification loss.
|
||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
|
||||
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
|
||||
|
||||
Classification scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x +
|
||||
attention_window + 1)`, where ``x`` is the number of tokens with global attention mask.
|
||||
|
||||
Local attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token in the sequence to every token with
|
||||
global attention (first ``x`` values) and to every token in the attention window (remaining
|
||||
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
|
||||
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
|
||||
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
|
||||
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
|
||||
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
|
||||
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
|
||||
attention weights. If a token has global attention, the attention weights to all other tokens in
|
||||
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
|
||||
global_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x)`,
|
||||
where ``x`` is the number of tokens with global attention mask.
|
||||
|
||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
"""
|
||||
|
||||
loss: Optional[tf.Tensor] = None
|
||||
logits: tf.Tensor = None
|
||||
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
global_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TFLongformerTokenClassifierOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of token classification models.
|
||||
|
||||
Args:
|
||||
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
|
||||
Classification loss.
|
||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
|
||||
Classification scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x +
|
||||
attention_window + 1)`, where ``x`` is the number of tokens with global attention mask.
|
||||
|
||||
Local attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token in the sequence to every token with
|
||||
global attention (first ``x`` values) and to every token in the attention window (remaining
|
||||
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
|
||||
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
|
||||
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
|
||||
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
|
||||
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
|
||||
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
|
||||
attention weights. If a token has global attention, the attention weights to all other tokens in
|
||||
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
|
||||
global_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x)`,
|
||||
where ``x`` is the number of tokens with global attention mask.
|
||||
|
||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
"""
|
||||
|
||||
loss: Optional[tf.Tensor] = None
|
||||
logits: tf.Tensor = None
|
||||
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
global_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
||||
|
||||
def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_sep_token=True):
|
||||
"""
|
||||
Computes global attention mask by putting attention on all tokens before `sep_token_id` if `before_sep_token is
|
||||
@@ -249,18 +437,17 @@ class TFLongformerLMHead(tf.keras.layers.Layer):
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, features):
|
||||
x = self.dense(features)
|
||||
x = self.act(x)
|
||||
x = self.layer_norm(x)
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
# project back to size of vocabulary with bias
|
||||
x = self.decoder(x, mode="linear") + self.bias
|
||||
hidden_states = self.decoder(hidden_states, mode="linear") + self.bias
|
||||
|
||||
return x
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaEmbeddings
|
||||
class TFLongformerEmbeddings(tf.keras.layers.Layer):
|
||||
"""
|
||||
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
|
||||
@@ -304,17 +491,23 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def create_position_ids_from_input_ids(self, x):
|
||||
def create_position_ids_from_input_ids(self, input_ids):
|
||||
"""
|
||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
|
||||
symbols are ignored. This is modified from fairseq's `utils.make_positions`.
|
||||
|
||||
Args:
|
||||
x: tf.Tensor
|
||||
input_ids: tf.Tensor
|
||||
|
||||
Returns: tf.Tensor
|
||||
"""
|
||||
mask = tf.cast(tf.math.not_equal(x, self.padding_idx), dtype=tf.int32)
|
||||
input_ids_shape = shape_list(input_ids)
|
||||
|
||||
# multiple choice has 3 dimensions
|
||||
if len(input_ids_shape) == 3:
|
||||
input_ids = tf.reshape(input_ids, (input_ids_shape[0] * input_ids_shape[1], input_ids_shape[2]))
|
||||
|
||||
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=tf.int32)
|
||||
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
|
||||
|
||||
return incremental_indices + self.padding_idx
|
||||
@@ -1783,7 +1976,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="allenai/longformer-base-4096",
|
||||
output_type=TFMaskedLMOutput,
|
||||
output_type=TFLongformerMaskedLMOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
@@ -1837,11 +2030,12 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
|
||||
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFMaskedLMOutput(
|
||||
return TFLongformerMaskedLMOutput(
|
||||
loss=loss,
|
||||
logits=prediction_scores,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
global_attentions=outputs.global_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -1871,7 +2065,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="allenai/longformer-large-4096-finetuned-triviaqa",
|
||||
output_type=TFQuestionAnsweringModelOutput,
|
||||
output_type=TFLongformerQuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
@@ -1969,3 +2163,357 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
|
||||
attentions=outputs.attentions,
|
||||
global_attentions=outputs.global_attentions,
|
||||
)
|
||||
|
||||
|
||||
class TFLongformerClassificationHead(tf.keras.layers.Layer):
|
||||
"""Head for sentence-level classification tasks."""
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
activation="tanh",
|
||||
name="dense",
|
||||
)
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.out_proj = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
|
||||
)
|
||||
|
||||
def call(self, hidden_states, training=False):
|
||||
hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
output = self.out_proj(hidden_states)
|
||||
return output
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Longformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
|
||||
pooled output) e.g. for GLUE tasks.
|
||||
""",
|
||||
LONGFORMER_START_DOCSTRING,
|
||||
)
|
||||
class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):
|
||||
|
||||
authorized_missing_keys = [r"pooler"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.longformer = TFLongformerMainLayer(config, name="longformer")
|
||||
self.classifier = TFLongformerClassificationHead(config, name="classifier")
|
||||
|
||||
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="allenai/longformer-base-4096",
|
||||
output_type=TFLongformerSequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
global_attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
global_attention_mask = inputs[4] if len(inputs) > 4 else global_attention_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
labels = inputs.get("labels", labels)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
if global_attention_mask is None and input_ids is not None:
|
||||
logger.info("Initializing global attention on CLS token...")
|
||||
# global attention on cls token
|
||||
global_attention_mask = tf.zeros_like(input_ids)
|
||||
global_attention_mask = tf.tensor_scatter_nd_update(
|
||||
global_attention_mask,
|
||||
[[i, 0] for i in range(input_ids.shape[0])],
|
||||
[1 for _ in range(input_ids.shape[0])],
|
||||
)
|
||||
|
||||
outputs = self.longformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFLongformerSequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
global_attentions=outputs.global_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Longformer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
|
||||
a softmax) e.g. for RocStories/SWAG tasks.
|
||||
""",
|
||||
LONGFORMER_START_DOCSTRING,
|
||||
)
|
||||
class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoiceLoss):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.longformer = TFLongformerMainLayer(config, name="longformer")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(
|
||||
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
)
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
input_ids = tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)
|
||||
# make sure global layers are initialized
|
||||
global_attention_mask = tf.constant([[[0, 0, 0, 1], [0, 0, 0, 1]]] * 2)
|
||||
return {"input_ids": input_ids, "global_attention_mask": global_attention_mask}
|
||||
|
||||
@add_start_docstrings_to_model_forward(
|
||||
LONGFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
|
||||
)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="allenai/longformer-base-4096",
|
||||
output_type=TFLongformerMultipleChoiceModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
global_attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
|
||||
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
|
||||
:obj:`input_ids` above)
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
global_attention_mask = inputs[4] if len(inputs) > 4 else global_attention_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
labels = inputs.get("labels", labels)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
else:
|
||||
num_choices = shape_list(inputs_embeds)[1]
|
||||
seq_length = shape_list(inputs_embeds)[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
flat_global_attention_mask = (
|
||||
tf.reshape(global_attention_mask, (-1, global_attention_mask.shape[-1]))
|
||||
if global_attention_mask is not None
|
||||
else None
|
||||
)
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
outputs = self.longformer(
|
||||
flat_input_ids,
|
||||
position_ids=flat_position_ids,
|
||||
token_type_ids=flat_token_type_ids,
|
||||
attention_mask=flat_attention_mask,
|
||||
global_attention_mask=flat_global_attention_mask,
|
||||
inputs_embeds=flat_inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFLongformerMultipleChoiceModelOutput(
|
||||
loss=loss,
|
||||
logits=reshaped_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
global_attentions=outputs.global_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Longformer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
|
||||
for Named-Entity-Recognition (NER) tasks.
|
||||
""",
|
||||
LONGFORMER_START_DOCSTRING,
|
||||
)
|
||||
class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
|
||||
|
||||
authorized_missing_keys = [r"pooler"]
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.longformer = TFLongformerMainLayer(config=config, name="longformer")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
)
|
||||
|
||||
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="allenai/longformer-base-4096",
|
||||
output_type=TFLongformerTokenClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
global_attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
|
||||
1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.longformer(
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.classifier(sequence_output)
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFLongformerTokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
global_attentions=outputs.global_attentions,
|
||||
)
|
||||
|
||||
@@ -751,15 +751,15 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, features):
|
||||
x = self.dense(features)
|
||||
x = self.act(x)
|
||||
x = self.layer_norm(x)
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
# project back to size of vocabulary with bias
|
||||
x = self.decoder(x, mode="linear") + self.bias
|
||||
hidden_states = self.decoder(hidden_states, mode="linear") + self.bias
|
||||
|
||||
return x
|
||||
return hidden_states
|
||||
|
||||
|
||||
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
|
||||
|
||||
@@ -812,6 +812,15 @@ class TFLongformerForMaskedLM:
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFLongformerForMultipleChoice:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFLongformerForQuestionAnswering:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
@@ -821,6 +830,24 @@ class TFLongformerForQuestionAnswering:
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFLongformerForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFLongformerForTokenClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFLongformerModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@@ -129,7 +129,7 @@ class LongformerModelTester:
|
||||
output_without_mask = model(input_ids)["last_hidden_state"]
|
||||
self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4))
|
||||
|
||||
def create_and_check_longformer_model(
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = LongformerModel(config=config)
|
||||
@@ -141,7 +141,7 @@ class LongformerModelTester:
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def create_and_check_longformer_model_with_global_attention_mask(
|
||||
def create_and_check_model_with_global_attention_mask(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = LongformerModel(config=config)
|
||||
@@ -163,7 +163,7 @@ class LongformerModelTester:
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def create_and_check_longformer_for_masked_lm(
|
||||
def create_and_check_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = LongformerForMaskedLM(config=config)
|
||||
@@ -172,7 +172,7 @@ class LongformerModelTester:
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_longformer_for_question_answering(
|
||||
def create_and_check_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = LongformerForQuestionAnswering(config=config)
|
||||
@@ -189,7 +189,7 @@ class LongformerModelTester:
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def create_and_check_longformer_for_sequence_classification(
|
||||
def create_and_check_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
@@ -199,7 +199,7 @@ class LongformerModelTester:
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_longformer_for_token_classification(
|
||||
def create_and_check_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
@@ -209,7 +209,7 @@ class LongformerModelTester:
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_longformer_for_multiple_choice(
|
||||
def create_and_check_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_choices = self.num_choices
|
||||
@@ -296,37 +296,37 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_longformer_model(self):
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_longformer_model_attention_mask_determinism(self):
|
||||
def test_model_attention_mask_determinism(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
|
||||
|
||||
def test_longformer_model_global_attention_mask(self):
|
||||
def test_model_global_attention_mask(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
|
||||
self.model_tester.create_and_check_model_with_global_attention_mask(*config_and_inputs)
|
||||
|
||||
def test_longformer_for_masked_lm(self):
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_longformer_for_question_answering(self):
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
|
||||
self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
|
||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_for_sequence_classification(*config_and_inputs)
|
||||
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_for_token_classification(*config_and_inputs)
|
||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs)
|
||||
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -691,7 +691,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
) # long input
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
loss, prediction_scores = model(input_ids, labels=input_ids)
|
||||
loss, prediction_scores = model(input_ids, labels=input_ids).to_tuple()
|
||||
|
||||
expected_loss = torch.tensor(0.0074, device=torch_device)
|
||||
expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device)
|
||||
|
||||
@@ -29,7 +29,10 @@ if is_tf_available():
|
||||
from transformers import (
|
||||
LongformerConfig,
|
||||
TFLongformerForMaskedLM,
|
||||
TFLongformerForMultipleChoice,
|
||||
TFLongformerForQuestionAnswering,
|
||||
TFLongformerForSequenceClassification,
|
||||
TFLongformerForTokenClassification,
|
||||
TFLongformerModel,
|
||||
TFLongformerSelfAttention,
|
||||
)
|
||||
@@ -130,7 +133,7 @@ class TFLongformerModelTester:
|
||||
output_without_mask = model(input_ids)[0]
|
||||
tf.debugging.assert_near(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], rtol=1e-4)
|
||||
|
||||
def create_and_check_longformer_model(
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.return_dict = True
|
||||
@@ -144,7 +147,7 @@ class TFLongformerModelTester:
|
||||
)
|
||||
self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_longformer_model_with_global_attention_mask(
|
||||
def create_and_check_model_with_global_attention_mask(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.return_dict = True
|
||||
@@ -172,7 +175,7 @@ class TFLongformerModelTester:
|
||||
)
|
||||
self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_longformer_for_masked_lm(
|
||||
def create_and_check_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.return_dict = True
|
||||
@@ -180,7 +183,7 @@ class TFLongformerModelTester:
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(shape_list(result.logits), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
|
||||
def create_and_check_longformer_for_question_answering(
|
||||
def create_and_check_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.return_dict = True
|
||||
@@ -196,6 +199,41 @@ class TFLongformerModelTester:
|
||||
self.parent.assertListEqual(shape_list(result.start_logits), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(shape_list(result.end_logits), [self.batch_size, self.seq_length])
|
||||
|
||||
def create_and_check_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = TFLongformerForSequenceClassification(config=config)
|
||||
output = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
|
||||
).logits
|
||||
self.parent.assertListEqual(shape_list(output), [self.batch_size, self.num_labels])
|
||||
|
||||
def create_and_check_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = TFLongformerForTokenClassification(config=config)
|
||||
output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels).logits
|
||||
self.parent.assertListEqual(shape_list(output), [self.batch_size, self.seq_length, self.num_labels])
|
||||
|
||||
def create_and_check_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_choices = self.num_choices
|
||||
model = TFLongformerForMultipleChoice(config=config)
|
||||
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
|
||||
output = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
global_attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels,
|
||||
).logits
|
||||
self.parent.assertListEqual(list(output.shape), [self.batch_size, self.num_choices])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -252,6 +290,9 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
TFLongformerModel,
|
||||
TFLongformerForMaskedLM,
|
||||
TFLongformerForQuestionAnswering,
|
||||
TFLongformerForSequenceClassification,
|
||||
TFLongformerForMultipleChoice,
|
||||
TFLongformerForTokenClassification,
|
||||
)
|
||||
if is_tf_available()
|
||||
else ()
|
||||
@@ -264,25 +305,37 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_longformer_model_attention_mask_determinism(self):
|
||||
def test_model_attention_mask_determinism(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
|
||||
|
||||
def test_longformer_model(self):
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_longformer_model_global_attention_mask(self):
|
||||
def test_model_global_attention_mask(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
|
||||
self.model_tester.create_and_check_model_with_global_attention_mask(*config_and_inputs)
|
||||
|
||||
def test_longformer_for_masked_lm(self):
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_longformer_for_question_answering(self):
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
|
||||
self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
|
||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_attentions_output(self):
|
||||
|
||||
Reference in New Issue
Block a user