added type hints for BART model (#16270)
* added type hints for BART model * make fixup, adding imports to copied files * Adding some missing types to cookiecutter * Adding some missing types to cookiecutter * Adding some missing types to cookiecutter Co-authored-by: matt <rocketknight1@gmail.com>
This commit is contained in:
@@ -17,7 +17,7 @@ import copy
|
|||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
@@ -297,11 +297,11 @@ class BartEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.FloatTensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.FloatTensor,
|
||||||
layer_head_mask: torch.Tensor,
|
layer_head_mask: torch.FloatTensor,
|
||||||
output_attentions: bool = False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
@@ -384,7 +384,7 @@ class BartDecoderLayer(nn.Module):
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
):
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
@@ -478,7 +478,7 @@ class BartClassificationHead(nn.Module):
|
|||||||
self.dropout = nn.Dropout(p=pooler_dropout)
|
self.dropout = nn.Dropout(p=pooler_dropout)
|
||||||
self.out_proj = nn.Linear(inner_dim, num_classes)
|
self.out_proj = nn.Linear(inner_dim, num_classes)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = torch.tanh(hidden_states)
|
hidden_states = torch.tanh(hidden_states)
|
||||||
@@ -728,14 +728,14 @@ class BartEncoder(BartPretrainedModel):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, BaseModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
@@ -917,19 +917,19 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
@@ -1172,22 +1172,22 @@ class BartModel(BartPretrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, Seq2SeqModelOutput]:
|
||||||
|
|
||||||
# different to other models, Bart automatically creates decoder_input_ids from
|
# different to other models, Bart automatically creates decoder_input_ids from
|
||||||
# input_ids if no decoder_input_ids are provided
|
# input_ids if no decoder_input_ids are provided
|
||||||
@@ -1306,23 +1306,23 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
|||||||
@add_end_docstrings(BART_GENERATION_EXAMPLE)
|
@add_end_docstrings(BART_GENERATION_EXAMPLE)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, Seq2SeqLMOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
@@ -1454,22 +1454,22 @@ class BartForSequenceClassification(BartPretrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||||
@@ -1580,23 +1580,23 @@ class BartForQuestionAnswering(BartPretrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.Tensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||||
start_positions=None,
|
start_positions: Optional[torch.LongTensor] = None,
|
||||||
end_positions=None,
|
end_positions: Optional[torch.LongTensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||||
@@ -1721,20 +1721,20 @@ class BartForCausalLM(BartPretrainedModel):
|
|||||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
@@ -39,6 +40,7 @@ from ...modeling_tf_outputs import (
|
|||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
TFCausalLanguageModelingLoss,
|
TFCausalLanguageModelingLoss,
|
||||||
|
TFModelInputType,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
TFWrappedEmbeddings,
|
TFWrappedEmbeddings,
|
||||||
@@ -170,7 +172,7 @@ class TFBartAttention(tf.keras.layers.Layer):
|
|||||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||||
attention_mask: Optional[tf.Tensor] = None,
|
attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
@@ -297,7 +299,13 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
|
|||||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||||
|
|
||||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
def call(
|
||||||
|
self,
|
||||||
|
hidden_states: tf.Tensor,
|
||||||
|
attention_mask: Optional[Union[np.ndarray, tf.Tensor]],
|
||||||
|
layer_head_mask: Optional[tf.Tensor],
|
||||||
|
training: Optional[bool] = False,
|
||||||
|
) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
@@ -365,14 +373,14 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: Optional[tf.Tensor] = None,
|
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -663,16 +671,16 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: Optional[TFModelInputType] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
||||||
@@ -813,21 +821,21 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: Optional[TFModelInputType] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
||||||
@@ -1030,24 +1038,24 @@ class TFBartMainLayer(tf.keras.layers.Layer):
|
|||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: Optional[TFModelInputType] = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]:
|
||||||
|
|
||||||
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||||
use_cache = False
|
use_cache = False
|
||||||
@@ -1143,24 +1151,24 @@ class TFBartModel(TFBartPretrainedModel):
|
|||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: Optional[TFModelInputType] = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
||||||
|
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -1248,25 +1256,25 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
|||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: Optional[TFModelInputType] = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
labels=None,
|
labels: Optional[tf.Tensor] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -1573,7 +1573,7 @@ class BigBirdPegasusClassificationHead(nn.Module):
|
|||||||
self.dropout = nn.Dropout(p=pooler_dropout)
|
self.dropout = nn.Dropout(p=pooler_dropout)
|
||||||
self.out_proj = nn.Linear(inner_dim, num_classes)
|
self.out_proj = nn.Linear(inner_dim, num_classes)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = torch.tanh(hidden_states)
|
hidden_states = torch.tanh(hidden_states)
|
||||||
@@ -2367,22 +2367,22 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, Seq2SeqModelOutput]:
|
||||||
|
|
||||||
# different to other models, BigBirdPegasus automatically creates decoder_input_ids from
|
# different to other models, BigBirdPegasus automatically creates decoder_input_ids from
|
||||||
# input_ids if no decoder_input_ids are provided
|
# input_ids if no decoder_input_ids are provided
|
||||||
@@ -2503,23 +2503,23 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
|
|||||||
@add_end_docstrings(BIGBIRD_PEGASUS_GENERATION_EXAMPLE)
|
@add_end_docstrings(BIGBIRD_PEGASUS_GENERATION_EXAMPLE)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, Seq2SeqLMOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
@@ -2652,22 +2652,22 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||||
@@ -2779,23 +2779,23 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.Tensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||||
start_positions=None,
|
start_positions: Optional[torch.LongTensor] = None,
|
||||||
end_positions=None,
|
end_positions: Optional[torch.LongTensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
@@ -1440,20 +1440,20 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
|
|||||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
|||||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||||
attention_mask: Optional[tf.Tensor] = None,
|
attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
@@ -288,11 +288,11 @@ class BlenderbotSmallEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.FloatTensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.FloatTensor,
|
||||||
layer_head_mask: torch.Tensor,
|
layer_head_mask: torch.FloatTensor,
|
||||||
output_attentions: bool = False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
@@ -376,7 +376,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
):
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
@@ -1411,20 +1411,20 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
|
|||||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
@@ -172,7 +173,7 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
|||||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||||
attention_mask: Optional[tf.Tensor] = None,
|
attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
@@ -300,7 +301,13 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
|
|||||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||||
|
|
||||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
def call(
|
||||||
|
self,
|
||||||
|
hidden_states: tf.Tensor,
|
||||||
|
attention_mask: Optional[Union[np.ndarray, tf.Tensor]],
|
||||||
|
layer_head_mask: Optional[tf.Tensor],
|
||||||
|
training: Optional[bool] = False,
|
||||||
|
) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
@@ -369,14 +376,14 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: Optional[tf.Tensor] = None,
|
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -754,7 +754,7 @@ class TFHubertAttention(tf.keras.layers.Layer):
|
|||||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||||
attention_mask: Optional[tf.Tensor] = None,
|
attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -305,11 +305,11 @@ class MarianEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.FloatTensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.FloatTensor,
|
||||||
layer_head_mask: torch.Tensor,
|
layer_head_mask: torch.FloatTensor,
|
||||||
output_attentions: bool = False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
@@ -393,7 +393,7 @@ class MarianDecoderLayer(nn.Module):
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
):
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
@@ -1573,20 +1573,20 @@ class MarianForCausalLM(MarianPreTrainedModel):
|
|||||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
|||||||
@@ -212,7 +212,7 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
|||||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||||
attention_mask: Optional[tf.Tensor] = None,
|
attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
@@ -340,7 +340,13 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
|
|||||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||||
|
|
||||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
def call(
|
||||||
|
self,
|
||||||
|
hidden_states: tf.Tensor,
|
||||||
|
attention_mask: Optional[Union[np.ndarray, tf.Tensor]],
|
||||||
|
layer_head_mask: Optional[tf.Tensor],
|
||||||
|
training: Optional[bool] = False,
|
||||||
|
) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
@@ -409,14 +415,14 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: tf.Tensor,
|
||||||
attention_mask: Optional[tf.Tensor] = None,
|
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
@@ -485,7 +485,7 @@ class MBartClassificationHead(nn.Module):
|
|||||||
self.dropout = nn.Dropout(p=pooler_dropout)
|
self.dropout = nn.Dropout(p=pooler_dropout)
|
||||||
self.out_proj = nn.Linear(inner_dim, num_classes)
|
self.out_proj = nn.Linear(inner_dim, num_classes)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = torch.tanh(hidden_states)
|
hidden_states = torch.tanh(hidden_states)
|
||||||
@@ -1445,22 +1445,22 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
|
|||||||
# Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
|
# Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||||
@@ -1572,23 +1572,23 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
|
|||||||
# Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward
|
# Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.Tensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||||
start_positions=None,
|
start_positions: Optional[torch.LongTensor] = None,
|
||||||
end_positions=None,
|
end_positions: Optional[torch.LongTensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||||
@@ -1715,20 +1715,20 @@ class MBartForCausalLM(MBartPreTrainedModel):
|
|||||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
|||||||
@@ -172,7 +172,7 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
|||||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||||
attention_mask: Optional[tf.Tensor] = None,
|
attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -1543,20 +1543,20 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
|
|||||||
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus, facebook/bart-base->google/pegasus-large
|
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus, facebook/bart-base->google/pegasus-large
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
|||||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||||
attention_mask: Optional[tf.Tensor] = None,
|
attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
@@ -302,11 +302,11 @@ class PLBartEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.FloatTensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.FloatTensor,
|
||||||
layer_head_mask: torch.Tensor,
|
layer_head_mask: torch.FloatTensor,
|
||||||
output_attentions: bool = False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||||
@@ -390,7 +390,7 @@ class PLBartDecoderLayer(nn.Module):
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
):
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
@@ -485,7 +485,7 @@ class PLBartClassificationHead(nn.Module):
|
|||||||
self.dropout = nn.Dropout(p=pooler_dropout)
|
self.dropout = nn.Dropout(p=pooler_dropout)
|
||||||
self.out_proj = nn.Linear(inner_dim, num_classes)
|
self.out_proj = nn.Linear(inner_dim, num_classes)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = torch.tanh(hidden_states)
|
hidden_states = torch.tanh(hidden_states)
|
||||||
@@ -699,14 +699,14 @@ class PLBartEncoder(PLBartPreTrainedModel):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, BaseModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
@@ -889,19 +889,19 @@ class PLBartDecoder(PLBartPreTrainedModel):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
@@ -1416,22 +1416,22 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel):
|
|||||||
# Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
|
# Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||||
@@ -1562,20 +1562,20 @@ class PLBartForCausalLM(PLBartPreTrainedModel):
|
|||||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
|||||||
@@ -274,7 +274,7 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
|
|||||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||||
attention_mask: Optional[tf.Tensor] = None,
|
attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
|||||||
@@ -783,7 +783,7 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
|
|||||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||||
attention_mask: Optional[tf.Tensor] = None,
|
attention_mask: Optional[tf.Tensor] = None,
|
||||||
layer_head_mask: Optional[tf.Tensor] = None,
|
layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
training=False,
|
training: Optional[bool] = False,
|
||||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional, List
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|||||||
@@ -1571,7 +1571,7 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
|
|||||||
import math
|
import math
|
||||||
import copy
|
import copy
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple, List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|||||||
Reference in New Issue
Block a user