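ALBERT: select the activation function from the config (`hidden_act`, looked up via `ACT2FN`) instead of hard-coding `gelu_new`; the config default for `hidden_act` changes from "gelu" to "gelu_new"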
This commit is contained in:
@@ -16,7 +16,7 @@ class AlbertConfig(PretrainedConfig):
|
|||||||
intermediate_size=16384,
|
intermediate_size=16384,
|
||||||
inner_group_num=1,
|
inner_group_num=1,
|
||||||
down_scale_factor=1,
|
down_scale_factor=1,
|
||||||
hidden_act="gelu",
|
hidden_act="gelu_new",
|
||||||
hidden_dropout_prob=0,
|
hidden_dropout_prob=0,
|
||||||
attention_probs_dropout_prob=0,
|
attention_probs_dropout_prob=0,
|
||||||
max_position_embeddings=512,
|
max_position_embeddings=512,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from transformers.configuration_albert import AlbertConfig
|
from transformers.configuration_albert import AlbertConfig
|
||||||
from transformers.modeling_bert import BertEmbeddings, BertModel, BertSelfAttention, prune_linear_layer, gelu_new
|
from transformers.modeling_bert import BertEmbeddings, BertModel, BertSelfAttention, prune_linear_layer, ACT2FN
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
|
|
||||||
@@ -190,11 +190,12 @@ class AlbertLayer(nn.Module):
|
|||||||
self.attention = AlbertAttention(config)
|
self.attention = AlbertAttention(config)
|
||||||
self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
|
self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||||
self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
|
self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||||
|
self.activation = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||||
attention_output = self.attention(hidden_states, attention_mask)
|
attention_output = self.attention(hidden_states, attention_mask)
|
||||||
ffn_output = self.ffn(attention_output)
|
ffn_output = self.ffn(attention_output)
|
||||||
ffn_output = gelu_new(ffn_output)
|
ffn_output = self.activation(ffn_output)
|
||||||
ffn_output = self.ffn_output(ffn_output)
|
ffn_output = self.ffn_output(ffn_output)
|
||||||
hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)
|
hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)
|
||||||
|
|
||||||
@@ -392,6 +393,7 @@ class AlbertForMaskedLM(PreTrainedModel):
|
|||||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||||
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
|
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
|
||||||
self.word_embeddings = nn.Linear(config.embedding_size, config.vocab_size)
|
self.word_embeddings = nn.Linear(config.embedding_size, config.vocab_size)
|
||||||
|
self.activation = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
""" Make sure we are sharing the input and output embeddings.
|
""" Make sure we are sharing the input and output embeddings.
|
||||||
@@ -405,7 +407,7 @@ class AlbertForMaskedLM(PreTrainedModel):
|
|||||||
outputs = self.bert(input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None)
|
outputs = self.bert(input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None)
|
||||||
sequence_outputs = outputs[0]
|
sequence_outputs = outputs[0]
|
||||||
hidden_states = self.dense(sequence_outputs)
|
hidden_states = self.dense(sequence_outputs)
|
||||||
hidden_states = gelu_new(hidden_states)
|
hidden_states = self.activation(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
prediction_scores = self.word_embeddings(hidden_states)
|
prediction_scores = self.word_embeddings(hidden_states)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user