added type hints for BART model (#16270)

* added type hints for BART model

* make fixup, adding imports to copied files

* Adding some missing types to cookiecutter

* Adding some missing types to cookiecutter

* Adding some missing types to cookiecutter

Co-authored-by: matt <rocketknight1@gmail.com>
This commit is contained in:
Robot Jelly
2022-03-21 20:48:01 +05:30
committed by GitHub
parent 460f36d352
commit d50f62f2de
19 changed files with 482 additions and 461 deletions

View File

@@ -17,7 +17,7 @@ import copy
import math import math
import random import random
import warnings import warnings
from typing import Optional, Tuple from typing import List, Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@@ -297,11 +297,11 @@ class BartEncoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.FloatTensor,
attention_mask: torch.Tensor, attention_mask: torch.FloatTensor,
layer_head_mask: torch.Tensor, layer_head_mask: torch.FloatTensor,
output_attentions: bool = False, output_attentions: Optional[bool] = False,
): ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -384,7 +384,7 @@ class BartDecoderLayer(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
): ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -478,7 +478,7 @@ class BartClassificationHead(nn.Module):
self.dropout = nn.Dropout(p=pooler_dropout) self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes) self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, hidden_states: torch.Tensor): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = torch.tanh(hidden_states) hidden_states = torch.tanh(hidden_states)
@@ -728,14 +728,14 @@ class BartEncoder(BartPretrainedModel):
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, BaseModelOutput]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -917,19 +917,19 @@ class BartDecoder(BartPretrainedModel):
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1172,22 +1172,22 @@ class BartModel(BartPretrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask=None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, Seq2SeqModelOutput]:
# different to other models, Bart automatically creates decoder_input_ids from # different to other models, Bart automatically creates decoder_input_ids from
# input_ids if no decoder_input_ids are provided # input_ids if no decoder_input_ids are provided
@@ -1306,23 +1306,23 @@ class BartForConditionalGeneration(BartPretrainedModel):
@add_end_docstrings(BART_GENERATION_EXAMPLE) @add_end_docstrings(BART_GENERATION_EXAMPLE)
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask=None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, Seq2SeqLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1454,22 +1454,22 @@ class BartForSequenceClassification(BartPretrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask=None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1580,23 +1580,23 @@ class BartForQuestionAnswering(BartPretrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.Tensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask=None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
start_positions=None, start_positions: Optional[torch.LongTensor] = None,
end_positions=None, end_positions: Optional[torch.LongTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
r""" r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss. Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -1721,20 +1721,20 @@ class BartForCausalLM(BartPretrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):

View File

@@ -18,6 +18,7 @@
import random import random
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf import tensorflow as tf
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
@@ -39,6 +40,7 @@ from ...modeling_tf_outputs import (
from ...modeling_tf_utils import ( from ...modeling_tf_utils import (
DUMMY_INPUTS, DUMMY_INPUTS,
TFCausalLanguageModelingLoss, TFCausalLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel, TFPreTrainedModel,
TFSharedEmbeddings, TFSharedEmbeddings,
TFWrappedEmbeddings, TFWrappedEmbeddings,
@@ -170,7 +172,7 @@ class TFBartAttention(tf.keras.layers.Layer):
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
training=False, training: Optional[bool] = False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
@@ -297,7 +299,13 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False): def call(
self,
hidden_states: tf.Tensor,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]],
layer_head_mask: Optional[tf.Tensor],
training: Optional[bool] = False,
) -> tf.Tensor:
""" """
Args: Args:
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -365,14 +373,14 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
hidden_states, hidden_states: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
cross_attn_layer_head_mask: Optional[tf.Tensor] = None, cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
training=False, training: Optional[bool] = False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
""" """
Args: Args:
@@ -663,16 +671,16 @@ class TFBartEncoder(tf.keras.layers.Layer):
@unpack_inputs @unpack_inputs
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
training=False, training: Optional[bool] = False,
**kwargs, **kwargs,
): ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
""" """
Args: Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
@@ -813,21 +821,21 @@ class TFBartDecoder(tf.keras.layers.Layer):
@unpack_inputs @unpack_inputs
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
training=False, training: Optional[bool] = False,
**kwargs, **kwargs,
): ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
r""" r"""
Args: Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
@@ -1030,24 +1038,24 @@ class TFBartMainLayer(tf.keras.layers.Layer):
@unpack_inputs @unpack_inputs
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids=None, decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask=None, decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
training=False, training: Optional[bool] = False,
**kwargs **kwargs
): ) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]:
if decoder_input_ids is None and decoder_inputs_embeds is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
use_cache = False use_cache = False
@@ -1143,24 +1151,24 @@ class TFBartModel(TFBartPretrainedModel):
@unpack_inputs @unpack_inputs
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids=None, decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask=None, decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
training=False, training: Optional[bool] = False,
**kwargs **kwargs
): ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
outputs = self.model( outputs = self.model(
input_ids=input_ids, input_ids=input_ids,
@@ -1248,25 +1256,25 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
@unpack_inputs @unpack_inputs
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids=None, decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask=None, decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
labels=None, labels: Optional[tf.Tensor] = None,
training=False, training: Optional[bool] = False,
**kwargs, **kwargs,
): ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
r""" r"""
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,

View File

@@ -18,7 +18,7 @@
import copy import copy
import math import math
import random import random
from typing import Optional, Tuple from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@@ -1573,7 +1573,7 @@ class BigBirdPegasusClassificationHead(nn.Module):
self.dropout = nn.Dropout(p=pooler_dropout) self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes) self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, hidden_states: torch.Tensor): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = torch.tanh(hidden_states) hidden_states = torch.tanh(hidden_states)
@@ -2367,22 +2367,22 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask=None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, Seq2SeqModelOutput]:
# different to other models, BigBirdPegasus automatically creates decoder_input_ids from # different to other models, BigBirdPegasus automatically creates decoder_input_ids from
# input_ids if no decoder_input_ids are provided # input_ids if no decoder_input_ids are provided
@@ -2503,23 +2503,23 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
@add_end_docstrings(BIGBIRD_PEGASUS_GENERATION_EXAMPLE) @add_end_docstrings(BIGBIRD_PEGASUS_GENERATION_EXAMPLE)
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask=None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, Seq2SeqLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -2652,22 +2652,22 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask=None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -2779,23 +2779,23 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.Tensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask=None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
start_positions=None, start_positions: Optional[torch.LongTensor] = None,
end_positions=None, end_positions: Optional[torch.LongTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
r""" r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss. Labels for position (index) of the start of the labelled span for computing the token classification loss.

View File

@@ -20,7 +20,7 @@ import math
import os import os
import random import random
import warnings import warnings
from typing import Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@@ -1440,20 +1440,20 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):

View File

@@ -173,7 +173,7 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
training=False, training: Optional[bool] = False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""

View File

@@ -18,7 +18,7 @@
import copy import copy
import math import math
import random import random
from typing import Optional, Tuple from typing import List, Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@@ -288,11 +288,11 @@ class BlenderbotSmallEncoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.FloatTensor,
attention_mask: torch.Tensor, attention_mask: torch.FloatTensor,
layer_head_mask: torch.Tensor, layer_head_mask: torch.FloatTensor,
output_attentions: bool = False, output_attentions: Optional[bool] = False,
): ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -376,7 +376,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
): ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -1411,20 +1411,20 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):

View File

@@ -18,6 +18,7 @@
import random import random
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf import tensorflow as tf
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
@@ -172,7 +173,7 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
training=False, training: Optional[bool] = False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
@@ -300,7 +301,13 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False): def call(
self,
hidden_states: tf.Tensor,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]],
layer_head_mask: Optional[tf.Tensor],
training: Optional[bool] = False,
) -> tf.Tensor:
""" """
Args: Args:
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -369,14 +376,14 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
hidden_states, hidden_states: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
cross_attn_layer_head_mask: Optional[tf.Tensor] = None, cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
training=False, training: Optional[bool] = False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
""" """
Args: Args:

View File

@@ -754,7 +754,7 @@ class TFHubertAttention(tf.keras.layers.Layer):
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
training=False, training: Optional[bool] = False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""

View File

@@ -18,7 +18,7 @@
import copy import copy
import math import math
import random import random
from typing import Optional, Tuple from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@@ -305,11 +305,11 @@ class MarianEncoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.FloatTensor,
attention_mask: torch.Tensor, attention_mask: torch.FloatTensor,
layer_head_mask: torch.Tensor, layer_head_mask: torch.FloatTensor,
output_attentions: bool = False, output_attentions: Optional[bool] = False,
): ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -393,7 +393,7 @@ class MarianDecoderLayer(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
): ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -1573,20 +1573,20 @@ class MarianForCausalLM(MarianPreTrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):

View File

@@ -212,7 +212,7 @@ class TFMarianAttention(tf.keras.layers.Layer):
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
training=False, training: Optional[bool] = False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
@@ -340,7 +340,13 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False): def call(
self,
hidden_states: tf.Tensor,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]],
layer_head_mask: Optional[tf.Tensor],
training: Optional[bool] = False,
) -> tf.Tensor:
""" """
Args: Args:
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -409,14 +415,14 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
hidden_states, hidden_states: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_hidden_states: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_attention_mask: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
cross_attn_layer_head_mask: Optional[tf.Tensor] = None, cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None, past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
training=False, training: Optional[bool] = False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
""" """
Args: Args:

View File

@@ -16,7 +16,7 @@
import copy import copy
import math import math
import random import random
from typing import Optional, Tuple from typing import List, Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@@ -485,7 +485,7 @@ class MBartClassificationHead(nn.Module):
self.dropout = nn.Dropout(p=pooler_dropout) self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes) self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, hidden_states: torch.Tensor): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = torch.tanh(hidden_states) hidden_states = torch.tanh(hidden_states)
@@ -1445,22 +1445,22 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask=None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1572,23 +1572,23 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward # Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.Tensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask=None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
start_positions=None, start_positions: Optional[torch.LongTensor] = None,
end_positions=None, end_positions: Optional[torch.LongTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
r""" r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss. Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -1715,20 +1715,20 @@ class MBartForCausalLM(MBartPreTrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):

View File

@@ -172,7 +172,7 @@ class TFMBartAttention(tf.keras.layers.Layer):
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
training=False, training: Optional[bool] = False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""

View File

@@ -17,7 +17,7 @@
import copy import copy
import math import math
import random import random
from typing import Optional, Tuple from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@@ -1543,20 +1543,20 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus, facebook/bart-base->google/pegasus-large # Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus, facebook/bart-base->google/pegasus-large
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):

View File

@@ -213,7 +213,7 @@ class TFPegasusAttention(tf.keras.layers.Layer):
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
training=False, training: Optional[bool] = False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""

View File

@@ -16,7 +16,7 @@
import copy import copy
import math import math
import random import random
from typing import Optional, Tuple from typing import List, Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@@ -302,11 +302,11 @@ class PLBartEncoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.FloatTensor,
attention_mask: torch.Tensor, attention_mask: torch.FloatTensor,
layer_head_mask: torch.Tensor, layer_head_mask: torch.FloatTensor,
output_attentions: bool = False, output_attentions: Optional[bool] = False,
): ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -390,7 +390,7 @@ class PLBartDecoderLayer(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
): ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -485,7 +485,7 @@ class PLBartClassificationHead(nn.Module):
self.dropout = nn.Dropout(p=pooler_dropout) self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes) self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, hidden_states: torch.Tensor): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = torch.tanh(hidden_states) hidden_states = torch.tanh(hidden_states)
@@ -699,14 +699,14 @@ class PLBartEncoder(PLBartPreTrainedModel):
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, BaseModelOutput]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -889,19 +889,19 @@ class PLBartDecoder(PLBartPreTrainedModel):
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1416,22 +1416,22 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask=None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1562,20 +1562,20 @@ class PLBartForCausalLM(PLBartPreTrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):

View File

@@ -274,7 +274,7 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
training=False, training: Optional[bool] = False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""

View File

@@ -783,7 +783,7 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
training=False, training: Optional[bool] = False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""

View File

@@ -25,7 +25,7 @@ import math
import os import os
import sys import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional, List
import datasets import datasets
from datasets import load_dataset from datasets import load_dataset

View File

@@ -1571,7 +1571,7 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
import math import math
import copy import copy
import random import random
from typing import Optional, Tuple from typing import Optional, Tuple, List, Union
import torch import torch
from torch import nn from torch import nn