Add TF Funnel Transformer (#7029)
* Add TF Funnel Transformer * Proper dummy input * Formatting * Update src/transformers/modeling_tf_funnel.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Address review comments * One review comment forgotten Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -592,6 +592,17 @@ if is_tf_available():
|
||||
TFFlaubertModel,
|
||||
TFFlaubertWithLMHeadModel,
|
||||
)
|
||||
from .modeling_tf_funnel import (
|
||||
TF_FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFFunnelBaseModel,
|
||||
TFFunnelForMaskedLM,
|
||||
TFFunnelForMultipleChoice,
|
||||
TFFunnelForPreTraining,
|
||||
TFFunnelForQuestionAnswering,
|
||||
TFFunnelForSequenceClassification,
|
||||
TFFunnelForTokenClassification,
|
||||
TFFunnelModel,
|
||||
)
|
||||
from .modeling_tf_gpt2 import (
|
||||
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFGPT2DoubleHeadsModel,
|
||||
|
||||
@@ -133,7 +133,9 @@ CONFIG_NAME = "config.json"
|
||||
MODEL_CARD_NAME = "modelcard.json"
|
||||
|
||||
|
||||
MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]]
|
||||
MULTIPLE_CHOICE_DUMMY_INPUTS = [
|
||||
[[0, 1, 0, 1], [1, 0, 0, 1]]
|
||||
] * 2 # Needs to have 0s and 1s only since XLM uses it for langs too.
|
||||
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
||||
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
|
||||
|
||||
|
||||
@@ -425,9 +425,9 @@ def _relative_shift_gather(positional_attn, context_len, shift):
|
||||
# max_rel_len = 2 * context_len + shift -1 is the numbers of possible relative positions i-j
|
||||
|
||||
# What's next is the same as doing the following gather, which might be clearer code but less efficient.
|
||||
# idxs = context_len + torch.arange(0, context_len).unsqueeze(0) - torch.arange(0, context_len).unsqueeze(1)
|
||||
# idxs = context_len + torch.arange(0, context_len).unsqueeze(0) - torch.arange(0, seq_len).unsqueeze(1)
|
||||
# # matrix of context_len + i-j
|
||||
# return positional_attn.gather(3, idxs.expand([bs, n_head, context_len, context_len]))
|
||||
# return positional_attn.gather(3, idxs.expand([batch_size, n_head, context_len, context_len]))
|
||||
|
||||
positional_attn = torch.reshape(positional_attn, [batch_size, n_head, max_rel_len, seq_len])
|
||||
positional_attn = positional_attn[:, :, shift:, :]
|
||||
@@ -526,9 +526,9 @@ class FunnelRelMultiheadAttention(nn.Module):
|
||||
token_type_attn *= cls_mask
|
||||
return token_type_attn
|
||||
|
||||
def forward(self, query, key, value, attention_inputs, head_mask=None, output_attentions=False):
|
||||
# q has shape batch_size x seq_len x d_model
|
||||
# k and v have shapes batch_size x context_len x d_model
|
||||
def forward(self, query, key, value, attention_inputs, output_attentions=False):
|
||||
# query has shape batch_size x seq_len x d_model
|
||||
# key and value have shapes batch_size x context_len x d_model
|
||||
position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
|
||||
|
||||
batch_size, seq_len, _ = query.shape
|
||||
@@ -598,8 +598,8 @@ class FunnelLayer(nn.Module):
|
||||
self.attention = FunnelRelMultiheadAttention(config, block_index)
|
||||
self.ffn = FunnelPositionwiseFFN(config)
|
||||
|
||||
def forward(self, q, k, v, attention_inputs, output_attentions=False):
|
||||
attn = self.attention(q, k, v, attention_inputs, output_attentions=output_attentions)
|
||||
def forward(self, query, key, value, attention_inputs, output_attentions=False):
|
||||
attn = self.attention(query, key, value, attention_inputs, output_attentions=output_attentions)
|
||||
output = self.ffn(attn[0])
|
||||
return (output, attn[1]) if output_attentions else (output,)
|
||||
|
||||
@@ -792,7 +792,7 @@ class FunnelClassificationHead(nn.Module):
|
||||
|
||||
def forward(self, hidden):
|
||||
hidden = self.linear_hidden(hidden)
|
||||
hidden = F.tanh(hidden)
|
||||
hidden = torch.tanh(hidden)
|
||||
hidden = self.dropout(hidden)
|
||||
return self.linear_out(hidden)
|
||||
|
||||
@@ -954,7 +954,7 @@ class FunnelBaseModel(FunnelPreTrainedModel):
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare base Funnel Transformer Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
"The bare Funnel Transformer Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
FUNNEL_START_DOCSTRING,
|
||||
)
|
||||
class FunnelModel(FunnelPreTrainedModel):
|
||||
@@ -1099,10 +1099,10 @@ class FunnelForPreTraining(FunnelPreTrainedModel):
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = FunnelTokenizer.from_pretrained('funnel-transformer/small')
|
||||
>>> model = FunnelForPreTraining.from_pretrained('funnel-transformer/small')
|
||||
>>> model = FunnelForPreTraining.from_pretrained('funnel-transformer/small', return_dict=True)
|
||||
|
||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
>>> logits = model(input_ids).logits
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors= "pt")
|
||||
>>> logits = model(**inputs).logits
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from .configuration_auto import (
|
||||
DistilBertConfig,
|
||||
ElectraConfig,
|
||||
FlaubertConfig,
|
||||
FunnelConfig,
|
||||
GPT2Config,
|
||||
LongformerConfig,
|
||||
MobileBertConfig,
|
||||
@@ -92,6 +93,15 @@ from .modeling_tf_flaubert import (
|
||||
TFFlaubertModel,
|
||||
TFFlaubertWithLMHeadModel,
|
||||
)
|
||||
from .modeling_tf_funnel import (
|
||||
TFFunnelForMaskedLM,
|
||||
TFFunnelForMultipleChoice,
|
||||
TFFunnelForPreTraining,
|
||||
TFFunnelForQuestionAnswering,
|
||||
TFFunnelForSequenceClassification,
|
||||
TFFunnelForTokenClassification,
|
||||
TFFunnelModel,
|
||||
)
|
||||
from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
|
||||
from .modeling_tf_longformer import TFLongformerForMaskedLM, TFLongformerForQuestionAnswering, TFLongformerModel
|
||||
from .modeling_tf_mobilebert import (
|
||||
@@ -163,6 +173,7 @@ TF_MODEL_MAPPING = OrderedDict(
|
||||
(XLMConfig, TFXLMModel),
|
||||
(CTRLConfig, TFCTRLModel),
|
||||
(ElectraConfig, TFElectraModel),
|
||||
(FunnelConfig, TFFunnelModel),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -184,6 +195,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||
(XLMConfig, TFXLMWithLMHeadModel),
|
||||
(CTRLConfig, TFCTRLLMHeadModel),
|
||||
(ElectraConfig, TFElectraForPreTraining),
|
||||
(FunnelConfig, TFFunnelForPreTraining),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -206,6 +218,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
||||
(XLMConfig, TFXLMWithLMHeadModel),
|
||||
(CTRLConfig, TFCTRLLMHeadModel),
|
||||
(ElectraConfig, TFElectraForMaskedLM),
|
||||
(FunnelConfig, TFFunnelForMaskedLM),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -237,6 +250,7 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
(FlaubertConfig, TFFlaubertWithLMHeadModel),
|
||||
(XLMConfig, TFXLMWithLMHeadModel),
|
||||
(ElectraConfig, TFElectraForMaskedLM),
|
||||
(FunnelConfig, TFFunnelForMaskedLM),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -255,6 +269,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(FlaubertConfig, TFFlaubertForSequenceClassification),
|
||||
(XLMConfig, TFXLMForSequenceClassification),
|
||||
(ElectraConfig, TFElectraForSequenceClassification),
|
||||
(FunnelConfig, TFFunnelForSequenceClassification),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -272,6 +287,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||
(FlaubertConfig, TFFlaubertForQuestionAnsweringSimple),
|
||||
(XLMConfig, TFXLMForQuestionAnsweringSimple),
|
||||
(ElectraConfig, TFElectraForQuestionAnswering),
|
||||
(FunnelConfig, TFFunnelForQuestionAnswering),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -288,6 +304,7 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(MobileBertConfig, TFMobileBertForTokenClassification),
|
||||
(XLNetConfig, TFXLNetForTokenClassification),
|
||||
(ElectraConfig, TFElectraForTokenClassification),
|
||||
(FunnelConfig, TFFunnelForTokenClassification),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -304,6 +321,7 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
||||
(FlaubertConfig, TFFlaubertForMultipleChoice),
|
||||
(AlbertConfig, TFAlbertForMultipleChoice),
|
||||
(ElectraConfig, TFElectraForMultipleChoice),
|
||||
(FunnelConfig, TFFunnelForMultipleChoice),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
1663
src/transformers/modeling_tf_funnel.py
Normal file
1663
src/transformers/modeling_tf_funnel.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user