Add type annotations for Rembert/Splinter and copies (#16338)
* undo black autoformat
* minor fix to rembert forward with default
* make fix-copies, make quality
* Adding types to template model
* Removing List from the template types
* Remove `Optional` from a couple of types that don't accept `None`

Co-authored-by: matt <rocketknight1@gmail.com>
This commit is contained in:
@@ -257,14 +257,14 @@ class BertSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
@@ -357,7 +357,7 @@ class BertSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -391,14 +391,14 @@ class BertAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -422,7 +422,7 @@ class BertIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -435,7 +435,7 @@ class BertOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -459,14 +459,14 @@ class BertLayer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
@@ -536,17 +536,17 @@ class BertEncoder(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
output_hidden_states=False,
|
output_hidden_states: Optional[bool] = False,
|
||||||
return_dict=True,
|
return_dict: Optional[bool] = True,
|
||||||
):
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
@@ -630,7 +630,7 @@ class BertPooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
@@ -649,7 +649,7 @@ class BertPredictionHeadTransform(nn.Module):
|
|||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.transform_act_fn(hidden_states)
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
@@ -681,7 +681,7 @@ class BertOnlyMLMHead(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.predictions = BertLMPredictionHead(config)
|
self.predictions = BertLMPredictionHead(config)
|
||||||
|
|
||||||
def forward(self, sequence_output):
|
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
||||||
prediction_scores = self.predictions(sequence_output)
|
prediction_scores = self.predictions(sequence_output)
|
||||||
return prediction_scores
|
return prediction_scores
|
||||||
|
|
||||||
|
|||||||
@@ -1311,7 +1311,7 @@ class BigBirdSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -1412,7 +1412,7 @@ class BigBirdIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -1426,7 +1426,7 @@ class BigBirdOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -1684,7 +1684,7 @@ class BigBirdPredictionHeadTransform(nn.Module):
|
|||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.transform_act_fn(hidden_states)
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
@@ -1718,7 +1718,7 @@ class BigBirdOnlyMLMHead(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.predictions = BigBirdLMPredictionHead(config)
|
self.predictions = BigBirdLMPredictionHead(config)
|
||||||
|
|
||||||
def forward(self, sequence_output):
|
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
||||||
prediction_scores = self.predictions(sequence_output)
|
prediction_scores = self.predictions(sequence_output)
|
||||||
return prediction_scores
|
return prediction_scores
|
||||||
|
|
||||||
|
|||||||
@@ -193,14 +193,14 @@ class Data2VecTextSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
@@ -294,7 +294,7 @@ class Data2VecTextSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -329,14 +329,14 @@ class Data2VecTextAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -361,7 +361,7 @@ class Data2VecTextIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -375,7 +375,7 @@ class Data2VecTextOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -400,14 +400,14 @@ class Data2VecTextLayer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
@@ -478,17 +478,17 @@ class Data2VecTextEncoder(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
output_hidden_states=False,
|
output_hidden_states: Optional[bool] = False,
|
||||||
return_dict=True,
|
return_dict: Optional[bool] = True,
|
||||||
):
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
@@ -573,7 +573,7 @@ class Data2VecTextPooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
|
|||||||
@@ -313,7 +313,7 @@ class DebertaIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
@@ -301,7 +301,7 @@ class DebertaV2Intermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
@@ -250,14 +250,14 @@ class ElectraSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
@@ -351,7 +351,7 @@ class ElectraSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -386,14 +386,14 @@ class ElectraAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -418,7 +418,7 @@ class ElectraIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -432,7 +432,7 @@ class ElectraOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -457,14 +457,14 @@ class ElectraLayer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
@@ -535,17 +535,17 @@ class ElectraEncoder(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
output_hidden_states=False,
|
output_hidden_states: Optional[bool] = False,
|
||||||
return_dict=True,
|
return_dict: Optional[bool] = True,
|
||||||
):
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ class FNetIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -245,7 +245,7 @@ class FNetOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -323,7 +323,7 @@ class FNetPooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
@@ -343,7 +343,7 @@ class FNetPredictionHeadTransform(nn.Module):
|
|||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.transform_act_fn(hidden_states)
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
|
|||||||
@@ -166,14 +166,14 @@ class LayoutLMSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
@@ -267,7 +267,7 @@ class LayoutLMSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -302,14 +302,14 @@ class LayoutLMAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -334,7 +334,7 @@ class LayoutLMIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -348,7 +348,7 @@ class LayoutLMOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -373,14 +373,14 @@ class LayoutLMLayer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
@@ -451,17 +451,17 @@ class LayoutLMEncoder(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
output_hidden_states=False,
|
output_hidden_states: Optional[bool] = False,
|
||||||
return_dict=True,
|
return_dict: Optional[bool] = True,
|
||||||
):
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
@@ -546,7 +546,7 @@ class LayoutLMPooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
@@ -566,7 +566,7 @@ class LayoutLMPredictionHeadTransform(nn.Module):
|
|||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.transform_act_fn(hidden_states)
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
@@ -600,7 +600,7 @@ class LayoutLMOnlyMLMHead(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.predictions = LayoutLMLMPredictionHead(config)
|
self.predictions = LayoutLMLMPredictionHead(config)
|
||||||
|
|
||||||
def forward(self, sequence_output):
|
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
||||||
prediction_scores = self.predictions(sequence_output)
|
prediction_scores = self.predictions(sequence_output)
|
||||||
return prediction_scores
|
return prediction_scores
|
||||||
|
|
||||||
|
|||||||
@@ -249,7 +249,7 @@ class LayoutLMv2Intermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -263,7 +263,7 @@ class LayoutLMv2Output(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
|||||||
@@ -1104,7 +1104,7 @@ class LongformerSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -1170,7 +1170,7 @@ class LongformerIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -1184,7 +1184,7 @@ class LongformerOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -1338,7 +1338,7 @@ class LongformerPooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
|
|||||||
@@ -480,7 +480,7 @@ class LukeSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -544,7 +544,7 @@ class LukeIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -558,7 +558,7 @@ class LukeOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -708,7 +708,7 @@ class LukePooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
|
|||||||
@@ -228,14 +228,14 @@ class MegatronBertSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
@@ -396,7 +396,7 @@ class MegatronBertIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -615,7 +615,7 @@ class MegatronBertPooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
@@ -635,7 +635,7 @@ class MegatronBertPredictionHeadTransform(nn.Module):
|
|||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.transform_act_fn(hidden_states)
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
@@ -669,7 +669,7 @@ class MegatronBertOnlyMLMHead(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.predictions = MegatronBertLMPredictionHead(config)
|
self.predictions = MegatronBertLMPredictionHead(config)
|
||||||
|
|
||||||
def forward(self, sequence_output):
|
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
||||||
prediction_scores = self.predictions(sequence_output)
|
prediction_scores = self.predictions(sequence_output)
|
||||||
return prediction_scores
|
return prediction_scores
|
||||||
|
|
||||||
|
|||||||
@@ -259,7 +259,7 @@ class MPNetIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -273,7 +273,7 @@ class MPNetOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -408,7 +408,7 @@ class MPNetPooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ class NystromformerSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -301,7 +301,7 @@ class NystromformerIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -315,7 +315,7 @@ class NystromformerOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -417,7 +417,7 @@ class NystromformerPredictionHeadTransform(nn.Module):
|
|||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.transform_act_fn(hidden_states)
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
@@ -451,7 +451,7 @@ class NystromformerOnlyMLMHead(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.predictions = NystromformerLMPredictionHead(config)
|
self.predictions = NystromformerLMPredictionHead(config)
|
||||||
|
|
||||||
def forward(self, sequence_output):
|
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
||||||
prediction_scores = self.predictions(sequence_output)
|
prediction_scores = self.predictions(sequence_output)
|
||||||
return prediction_scores
|
return prediction_scores
|
||||||
|
|
||||||
|
|||||||
@@ -641,7 +641,7 @@ class QDQBertPooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
@@ -661,7 +661,7 @@ class QDQBertPredictionHeadTransform(nn.Module):
|
|||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.transform_act_fn(hidden_states)
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -265,14 +265,14 @@ class RealmSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
@@ -366,7 +366,7 @@ class RealmSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -401,14 +401,14 @@ class RealmAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -433,7 +433,7 @@ class RealmIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -447,7 +447,7 @@ class RealmOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -472,14 +472,14 @@ class RealmLayer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
@@ -550,17 +550,17 @@ class RealmEncoder(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
output_hidden_states=False,
|
output_hidden_states: Optional[bool] = False,
|
||||||
return_dict=True,
|
return_dict: Optional[bool] = True,
|
||||||
):
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
@@ -645,7 +645,7 @@ class RealmPooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
@@ -165,8 +166,13 @@ class RemBertEmbeddings(nn.Module):
|
|||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
self,
|
||||||
):
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
past_key_values_length: int = 0,
|
||||||
|
) -> torch.Tensor:
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
else:
|
else:
|
||||||
@@ -199,7 +205,7 @@ class RemBertPooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
@@ -236,14 +242,14 @@ class RemBertSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Tuple[Tuple[torch.FloatTensor]] = None,
|
||||||
output_attentions=False,
|
output_attentions: bool = False,
|
||||||
):
|
) -> Tuple:
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
@@ -321,7 +327,7 @@ class RemBertSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -357,14 +363,14 @@ class RemBertAttention(nn.Module):
|
|||||||
# Copied from transformers.models.bert.modeling_bert.BertAttention.forward
|
# Copied from transformers.models.bert.modeling_bert.BertAttention.forward
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -389,7 +395,7 @@ class RemBertIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -403,7 +409,7 @@ class RemBertOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -428,14 +434,14 @@ class RemBertLayer(nn.Module):
|
|||||||
# Copied from transformers.models.bert.modeling_bert.BertLayer.forward
|
# Copied from transformers.models.bert.modeling_bert.BertLayer.forward
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
@@ -508,17 +514,18 @@ class RemBertEncoder(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=False,
|
output_attentions: bool = False,
|
||||||
output_hidden_states=False,
|
output_hidden_states: bool = False,
|
||||||
return_dict=True,
|
return_dict: bool = True,
|
||||||
):
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
|
|
||||||
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
|
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
@@ -608,7 +615,7 @@ class RemBertPredictionHeadTransform(nn.Module):
|
|||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.transform_act_fn(hidden_states)
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
@@ -623,7 +630,7 @@ class RemBertLMPredictionHead(nn.Module):
|
|||||||
self.activation = ACT2FN[config.hidden_act]
|
self.activation = ACT2FN[config.hidden_act]
|
||||||
self.LayerNorm = nn.LayerNorm(config.output_embedding_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.output_embedding_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.activation(hidden_states)
|
hidden_states = self.activation(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
@@ -637,7 +644,7 @@ class RemBertOnlyMLMHead(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.predictions = RemBertLMPredictionHead(config)
|
self.predictions = RemBertLMPredictionHead(config)
|
||||||
|
|
||||||
def forward(self, sequence_output):
|
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
||||||
prediction_scores = self.predictions(sequence_output)
|
prediction_scores = self.predictions(sequence_output)
|
||||||
return prediction_scores
|
return prediction_scores
|
||||||
|
|
||||||
@@ -788,20 +795,20 @@ class RemBertModel(RemBertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
token_type_ids=None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids=None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
@@ -941,19 +948,19 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
token_type_ids=None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids=None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, MaskedLMOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
||||||
@@ -1039,21 +1046,21 @@ class RemBertForCausalLM(RemBertPreTrainedModel):
|
|||||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
token_type_ids=None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids=None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
@@ -1186,17 +1193,17 @@ class RemBertForSequenceClassification(RemBertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.FloatTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
token_type_ids=None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids=None,
|
position_ids: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, SequenceClassifierOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||||
@@ -1283,17 +1290,17 @@ class RemBertForMultipleChoice(RemBertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.FloatTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
token_type_ids=None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids=None,
|
position_ids: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, MultipleChoiceModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
|
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
|
||||||
@@ -1376,17 +1383,17 @@ class RemBertForTokenClassification(RemBertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.FloatTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
token_type_ids=None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids=None,
|
position_ids: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, TokenClassifierOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
||||||
@@ -1455,18 +1462,18 @@ class RemBertForQuestionAnswering(RemBertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: torch.FloatTensor = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
token_type_ids=None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids=None,
|
position_ids: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
start_positions=None,
|
start_positions: Optional[torch.LongTensor] = None,
|
||||||
end_positions=None,
|
end_positions: Optional[torch.LongTensor] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||||
|
|||||||
@@ -193,14 +193,14 @@ class RobertaSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
@@ -294,7 +294,7 @@ class RobertaSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -329,14 +329,14 @@ class RobertaAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -361,7 +361,7 @@ class RobertaIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -375,7 +375,7 @@ class RobertaOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -400,14 +400,14 @@ class RobertaLayer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
@@ -478,17 +478,17 @@ class RobertaEncoder(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
output_hidden_states=False,
|
output_hidden_states: Optional[bool] = False,
|
||||||
return_dict=True,
|
return_dict: Optional[bool] = True,
|
||||||
):
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
@@ -573,7 +573,7 @@ class RobertaPooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
|
|||||||
@@ -359,7 +359,7 @@ class RoFormerSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -429,7 +429,7 @@ class RoFormerIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -443,7 +443,7 @@ class RoFormerOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -687,7 +687,7 @@ class RoFormerOnlyMLMHead(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.predictions = RoFormerLMPredictionHead(config)
|
self.predictions = RoFormerLMPredictionHead(config)
|
||||||
|
|
||||||
def forward(self, sequence_output):
|
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
||||||
prediction_scores = self.predictions(sequence_output)
|
prediction_scores = self.predictions(sequence_output)
|
||||||
return prediction_scores
|
return prediction_scores
|
||||||
|
|
||||||
|
|||||||
@@ -924,7 +924,7 @@ class SEWDIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
@@ -69,8 +70,13 @@ class SplinterEmbeddings(nn.Module):
|
|||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
self,
|
||||||
):
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
past_key_values_length: Optional[int] = 0,
|
||||||
|
) -> Tuple:
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
else:
|
else:
|
||||||
@@ -132,14 +138,14 @@ class SplinterSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
@@ -233,7 +239,7 @@ class SplinterSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -268,14 +274,14 @@ class SplinterAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -300,7 +306,7 @@ class SplinterIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -314,7 +320,7 @@ class SplinterOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -339,14 +345,14 @@ class SplinterLayer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
@@ -417,17 +423,17 @@ class SplinterEncoder(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
output_hidden_states=False,
|
output_hidden_states: Optional[bool] = False,
|
||||||
return_dict=True,
|
return_dict: Optional[bool] = True,
|
||||||
):
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
@@ -643,20 +649,20 @@ class SplinterModel(SplinterPreTrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
token_type_ids=None,
|
token_type_ids: Optional[torch.Tensor] = None,
|
||||||
position_ids=None,
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
@@ -773,7 +779,7 @@ class SplinterFullyConnectedLayer(nn.Module):
|
|||||||
self.act_fn = ACT2FN[hidden_act]
|
self.act_fn = ACT2FN[hidden_act]
|
||||||
self.LayerNorm = nn.LayerNorm(self.output_dim)
|
self.LayerNorm = nn.LayerNorm(self.output_dim)
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(inputs)
|
hidden_states = self.dense(inputs)
|
||||||
hidden_states = self.act_fn(hidden_states)
|
hidden_states = self.act_fn(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
@@ -845,19 +851,19 @@ class SplinterForQuestionAnswering(SplinterPreTrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
token_type_ids=None,
|
token_type_ids: Optional[torch.Tensor] = None,
|
||||||
position_ids=None,
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
start_positions=None,
|
start_positions: Optional[torch.LongTensor] = None,
|
||||||
end_positions=None,
|
end_positions: Optional[torch.LongTensor] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
question_positions=None,
|
question_positions: Optional[torch.LongTensor] = None,
|
||||||
):
|
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||||
|
|||||||
@@ -449,7 +449,7 @@ class TapasSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -485,14 +485,14 @@ class TapasAttention(nn.Module):
|
|||||||
# Copied from transformers.models.bert.modeling_bert.BertAttention.forward
|
# Copied from transformers.models.bert.modeling_bert.BertAttention.forward
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -517,7 +517,7 @@ class TapasIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -531,7 +531,7 @@ class TapasOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -556,14 +556,14 @@ class TapasLayer(nn.Module):
|
|||||||
# Copied from transformers.models.bert.modeling_bert.BertLayer.forward
|
# Copied from transformers.models.bert.modeling_bert.BertLayer.forward
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
@@ -700,7 +700,7 @@ class TapasPooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
@@ -720,7 +720,7 @@ class TapasPredictionHeadTransform(nn.Module):
|
|||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.transform_act_fn(hidden_states)
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
@@ -754,7 +754,7 @@ class TapasOnlyMLMHead(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.predictions = TapasLMPredictionHead(config)
|
self.predictions = TapasLMPredictionHead(config)
|
||||||
|
|
||||||
def forward(self, sequence_output):
|
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
||||||
prediction_scores = self.predictions(sequence_output)
|
prediction_scores = self.predictions(sequence_output)
|
||||||
return prediction_scores
|
return prediction_scores
|
||||||
|
|
||||||
|
|||||||
@@ -273,7 +273,7 @@ class VisualBertSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -333,7 +333,7 @@ class VisualBertIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -347,7 +347,7 @@ class VisualBertOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -464,7 +464,7 @@ class VisualBertPooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
@@ -484,7 +484,7 @@ class VisualBertPredictionHeadTransform(nn.Module):
|
|||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.transform_act_fn(hidden_states)
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
|
|||||||
@@ -187,14 +187,14 @@ class XLMRobertaXLSelfAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_value=None,
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
output_attentions=False,
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
) -> Tuple:
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
@@ -354,7 +354,7 @@ class XLMRobertaXLIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -565,7 +565,7 @@ class XLMRobertaXLPooler(nn.Module):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
|
|||||||
@@ -445,7 +445,7 @@ class YosoSelfOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -494,7 +494,7 @@ class YosoIntermediate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -508,7 +508,7 @@ class YosoOutput(nn.Module):
|
|||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def forward(self, hidden_states, input_tensor):
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -610,7 +610,7 @@ class YosoPredictionHeadTransform(nn.Module):
|
|||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.transform_act_fn(hidden_states)
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
@@ -644,7 +644,7 @@ class YosoOnlyMLMHead(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.predictions = YosoLMPredictionHead(config)
|
self.predictions = YosoLMPredictionHead(config)
|
||||||
|
|
||||||
def forward(self, sequence_output):
|
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
||||||
prediction_scores = self.predictions(sequence_output)
|
prediction_scores = self.predictions(sequence_output)
|
||||||
return prediction_scores
|
return prediction_scores
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import torch.utils.checkpoint
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
|
|||||||
Reference in New Issue
Block a user