Add type annotations for Rembert/Splinter and copies (#16338)

* undo black autoformat

* minor fix to rembert forward with default

* make fix-copies, make quality

* Adding types to template model

* Removing List from the template types

* Remove `Optional` from a couple of types that don't accept `None`

Co-authored-by: matt <rocketknight1@gmail.com>
This commit is contained in:
Jacob Dineen
2022-03-22 13:07:48 -07:00
committed by GitHub
parent c30798ec9d
commit ec3aace0ae
26 changed files with 541 additions and 527 deletions

View File

@@ -257,14 +257,14 @@ class BertSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
@@ -357,7 +357,7 @@ class BertSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -391,14 +391,14 @@ class BertAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
@@ -422,7 +422,7 @@ class BertIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -435,7 +435,7 @@ class BertOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -459,14 +459,14 @@ class BertLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
@@ -536,17 +536,17 @@ class BertEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -630,7 +630,7 @@ class BertPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
@@ -649,7 +649,7 @@ class BertPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
@@ -681,7 +681,7 @@ class BertOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = BertLMPredictionHead(config) self.predictions = BertLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores

View File

@@ -1311,7 +1311,7 @@ class BigBirdSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -1412,7 +1412,7 @@ class BigBirdIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -1426,7 +1426,7 @@ class BigBirdOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -1684,7 +1684,7 @@ class BigBirdPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
@@ -1718,7 +1718,7 @@ class BigBirdOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = BigBirdLMPredictionHead(config) self.predictions = BigBirdLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores

View File

@@ -193,14 +193,14 @@ class Data2VecTextSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
@@ -294,7 +294,7 @@ class Data2VecTextSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -329,14 +329,14 @@ class Data2VecTextAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
@@ -361,7 +361,7 @@ class Data2VecTextIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -375,7 +375,7 @@ class Data2VecTextOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -400,14 +400,14 @@ class Data2VecTextLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
@@ -478,17 +478,17 @@ class Data2VecTextEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -573,7 +573,7 @@ class Data2VecTextPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]

View File

@@ -313,7 +313,7 @@ class DebertaIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states

View File

@@ -301,7 +301,7 @@ class DebertaV2Intermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states

View File

@@ -250,14 +250,14 @@ class ElectraSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
@@ -351,7 +351,7 @@ class ElectraSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -386,14 +386,14 @@ class ElectraAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
@@ -418,7 +418,7 @@ class ElectraIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -432,7 +432,7 @@ class ElectraOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -457,14 +457,14 @@ class ElectraLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
@@ -535,17 +535,17 @@ class ElectraEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

View File

@@ -231,7 +231,7 @@ class FNetIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -245,7 +245,7 @@ class FNetOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -323,7 +323,7 @@ class FNetPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
@@ -343,7 +343,7 @@ class FNetPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)

View File

@@ -166,14 +166,14 @@ class LayoutLMSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
@@ -267,7 +267,7 @@ class LayoutLMSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -302,14 +302,14 @@ class LayoutLMAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
@@ -334,7 +334,7 @@ class LayoutLMIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -348,7 +348,7 @@ class LayoutLMOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -373,14 +373,14 @@ class LayoutLMLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
@@ -451,17 +451,17 @@ class LayoutLMEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -546,7 +546,7 @@ class LayoutLMPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
@@ -566,7 +566,7 @@ class LayoutLMPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
@@ -600,7 +600,7 @@ class LayoutLMOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = LayoutLMLMPredictionHead(config) self.predictions = LayoutLMLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores

View File

@@ -249,7 +249,7 @@ class LayoutLMv2Intermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -263,7 +263,7 @@ class LayoutLMv2Output(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)

View File

@@ -1104,7 +1104,7 @@ class LongformerSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -1170,7 +1170,7 @@ class LongformerIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -1184,7 +1184,7 @@ class LongformerOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -1338,7 +1338,7 @@ class LongformerPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]

View File

@@ -480,7 +480,7 @@ class LukeSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -544,7 +544,7 @@ class LukeIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -558,7 +558,7 @@ class LukeOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -708,7 +708,7 @@ class LukePooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]

View File

@@ -228,14 +228,14 @@ class MegatronBertSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
@@ -396,7 +396,7 @@ class MegatronBertIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -615,7 +615,7 @@ class MegatronBertPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
@@ -635,7 +635,7 @@ class MegatronBertPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
@@ -669,7 +669,7 @@ class MegatronBertOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = MegatronBertLMPredictionHead(config) self.predictions = MegatronBertLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores

View File

@@ -259,7 +259,7 @@ class MPNetIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -273,7 +273,7 @@ class MPNetOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -408,7 +408,7 @@ class MPNetPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]

View File

@@ -252,7 +252,7 @@ class NystromformerSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -301,7 +301,7 @@ class NystromformerIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -315,7 +315,7 @@ class NystromformerOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -417,7 +417,7 @@ class NystromformerPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
@@ -451,7 +451,7 @@ class NystromformerOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = NystromformerLMPredictionHead(config) self.predictions = NystromformerLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores

View File

@@ -641,7 +641,7 @@ class QDQBertPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
@@ -661,7 +661,7 @@ class QDQBertPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)

View File

@@ -17,7 +17,7 @@
import math import math
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple, Union
import torch import torch
from packaging import version from packaging import version
@@ -265,14 +265,14 @@ class RealmSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
@@ -366,7 +366,7 @@ class RealmSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -401,14 +401,14 @@ class RealmAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
@@ -433,7 +433,7 @@ class RealmIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -447,7 +447,7 @@ class RealmOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -472,14 +472,14 @@ class RealmLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
@@ -550,17 +550,17 @@ class RealmEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -645,7 +645,7 @@ class RealmPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]

View File

@@ -17,6 +17,7 @@
import math import math
import os import os
from typing import Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@@ -165,8 +166,13 @@ class RemBertEmbeddings(nn.Module):
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward( def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 self,
): input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None: if input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
else: else:
@@ -199,7 +205,7 @@ class RemBertPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
@@ -236,14 +242,14 @@ class RemBertSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Tuple[Tuple[torch.FloatTensor]] = None,
output_attentions=False, output_attentions: bool = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
@@ -321,7 +327,7 @@ class RemBertSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -357,14 +363,14 @@ class RemBertAttention(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertAttention.forward # Copied from transformers.models.bert.modeling_bert.BertAttention.forward
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
@@ -389,7 +395,7 @@ class RemBertIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -403,7 +409,7 @@ class RemBertOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -428,14 +434,14 @@ class RemBertLayer(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertLayer.forward # Copied from transformers.models.bert.modeling_bert.BertLayer.forward
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
@@ -508,17 +514,18 @@ class RemBertEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: bool = False,
output_hidden_states=False, output_hidden_states: bool = False,
return_dict=True, return_dict: bool = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
hidden_states = self.embedding_hidden_mapping_in(hidden_states) hidden_states = self.embedding_hidden_mapping_in(hidden_states)
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
@@ -608,7 +615,7 @@ class RemBertPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
@@ -623,7 +630,7 @@ class RemBertLMPredictionHead(nn.Module):
self.activation = ACT2FN[config.hidden_act] self.activation = ACT2FN[config.hidden_act]
self.LayerNorm = nn.LayerNorm(config.output_embedding_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.output_embedding_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states) hidden_states = self.activation(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
@@ -637,7 +644,7 @@ class RemBertOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = RemBertLMPredictionHead(config) self.predictions = RemBertLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores
@@ -788,20 +795,20 @@ class RemBertModel(RemBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.LongTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
r""" r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -941,19 +948,19 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.LongTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, MaskedLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1039,21 +1046,21 @@ class RemBertForCausalLM(RemBertPreTrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.LongTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r""" r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -1186,17 +1193,17 @@ class RemBertForSequenceClassification(RemBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.FloatTensor = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, SequenceClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1283,17 +1290,17 @@ class RemBertForMultipleChoice(RemBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.FloatTensor = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, MultipleChoiceModelOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -1376,17 +1383,17 @@ class RemBertForTokenClassification(RemBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.FloatTensor = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, TokenClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1455,18 +1462,18 @@ class RemBertForQuestionAnswering(RemBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.FloatTensor = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions=None, start_positions: Optional[torch.LongTensor] = None,
end_positions=None, end_positions: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r""" r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss. Labels for position (index) of the start of the labelled span for computing the token classification loss.

View File

@@ -193,14 +193,14 @@ class RobertaSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
@@ -294,7 +294,7 @@ class RobertaSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -329,14 +329,14 @@ class RobertaAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
@@ -361,7 +361,7 @@ class RobertaIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -375,7 +375,7 @@ class RobertaOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -400,14 +400,14 @@ class RobertaLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
@@ -478,17 +478,17 @@ class RobertaEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -573,7 +573,7 @@ class RobertaPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]

View File

@@ -359,7 +359,7 @@ class RoFormerSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -429,7 +429,7 @@ class RoFormerIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -443,7 +443,7 @@ class RoFormerOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -687,7 +687,7 @@ class RoFormerOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = RoFormerLMPredictionHead(config) self.predictions = RoFormerLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores

View File

@@ -924,7 +924,7 @@ class SEWDIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states

View File

@@ -16,6 +16,7 @@
import math import math
from typing import List, Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@@ -69,8 +70,13 @@ class SplinterEmbeddings(nn.Module):
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
def forward( def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 self,
): input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: Optional[int] = 0,
) -> Tuple:
if input_ids is not None: if input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
else: else:
@@ -132,14 +138,14 @@ class SplinterSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
@@ -233,7 +239,7 @@ class SplinterSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -268,14 +274,14 @@ class SplinterAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
@@ -300,7 +306,7 @@ class SplinterIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -314,7 +320,7 @@ class SplinterOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -339,14 +345,14 @@ class SplinterLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
@@ -417,17 +423,17 @@ class SplinterEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -643,20 +649,20 @@ class SplinterModel(SplinterPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.Tensor] = None,
position_ids=None, position_ids: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
r""" r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -773,7 +779,7 @@ class SplinterFullyConnectedLayer(nn.Module):
self.act_fn = ACT2FN[hidden_act] self.act_fn = ACT2FN[hidden_act]
self.LayerNorm = nn.LayerNorm(self.output_dim) self.LayerNorm = nn.LayerNorm(self.output_dim)
def forward(self, inputs): def forward(self, inputs: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(inputs) hidden_states = self.dense(inputs)
hidden_states = self.act_fn(hidden_states) hidden_states = self.act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
@@ -845,19 +851,19 @@ class SplinterForQuestionAnswering(SplinterPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.Tensor] = None,
position_ids=None, position_ids: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
start_positions=None, start_positions: Optional[torch.LongTensor] = None,
end_positions=None, end_positions: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
question_positions=None, question_positions: Optional[torch.LongTensor] = None,
): ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r""" r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss. Labels for position (index) of the start of the labelled span for computing the token classification loss.

View File

@@ -449,7 +449,7 @@ class TapasSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -485,14 +485,14 @@ class TapasAttention(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertAttention.forward # Copied from transformers.models.bert.modeling_bert.BertAttention.forward
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
@@ -517,7 +517,7 @@ class TapasIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -531,7 +531,7 @@ class TapasOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -556,14 +556,14 @@ class TapasLayer(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertLayer.forward # Copied from transformers.models.bert.modeling_bert.BertLayer.forward
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
@@ -700,7 +700,7 @@ class TapasPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
@@ -720,7 +720,7 @@ class TapasPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
@@ -754,7 +754,7 @@ class TapasOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = TapasLMPredictionHead(config) self.predictions = TapasLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores

View File

@@ -273,7 +273,7 @@ class VisualBertSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -333,7 +333,7 @@ class VisualBertIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -347,7 +347,7 @@ class VisualBertOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -464,7 +464,7 @@ class VisualBertPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
@@ -484,7 +484,7 @@ class VisualBertPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)

View File

@@ -187,14 +187,14 @@ class XLMRobertaXLSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
@@ -354,7 +354,7 @@ class XLMRobertaXLIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -565,7 +565,7 @@ class XLMRobertaXLPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]

View File

@@ -445,7 +445,7 @@ class YosoSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -494,7 +494,7 @@ class YosoIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
@@ -508,7 +508,7 @@ class YosoOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -610,7 +610,7 @@ class YosoPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
@@ -644,7 +644,7 @@ class YosoOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = YosoLMPredictionHead(config) self.predictions = YosoLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores

View File

@@ -25,6 +25,7 @@ import torch.utils.checkpoint
from packaging import version from packaging import version
from torch import nn from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from typing import Optional, Tuple, Union
from ...activations import ACT2FN from ...activations import ACT2FN
from ...file_utils import ( from ...file_utils import (