rename BertLayer to BertEncoderLayer
This commit is contained in:
@@ -315,9 +315,9 @@ class BertOutput(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class BertLayer(nn.Module):
|
class BertEncoderLayer(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertLayer, self).__init__()
|
super(BertEncoderLayer, self).__init__()
|
||||||
self.attention = BertAttention(config)
|
self.attention = BertAttention(config)
|
||||||
self.intermediate = BertIntermediate(config)
|
self.intermediate = BertIntermediate(config)
|
||||||
self.output = BertOutput(config)
|
self.output = BertOutput(config)
|
||||||
@@ -336,7 +336,7 @@ class BertEncoder(nn.Module):
|
|||||||
super(BertEncoder, self).__init__()
|
super(BertEncoder, self).__init__()
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layer = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||||
all_hidden_states = ()
|
all_hidden_states = ()
|
||||||
|
|||||||
Reference in New Issue
Block a user