Handles multi layer and multi groups
This commit is contained in:
@@ -136,7 +136,6 @@ class AlbertModel(BertModel):
|
|||||||
head_mask=head_mask)
|
head_mask=head_mask)
|
||||||
sequence_output = encoder_outputs[0]
|
sequence_output = encoder_outputs[0]
|
||||||
|
|
||||||
print(sequence_output.shape, sequence_output[:, 0].shape, self.pooler(sequence_output[:, 0]).shape)
|
|
||||||
pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))
|
pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))
|
||||||
|
|
||||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
||||||
@@ -260,7 +259,6 @@ class AlbertLayer(nn.Module):
|
|||||||
self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
|
self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||||
for _ in range(self.config.inner_group_num):
|
|
||||||
attention_output = self.attention(hidden_states, attention_mask)[0]
|
attention_output = self.attention(hidden_states, attention_mask)[0]
|
||||||
ffn_output = self.ffn(attention_output)
|
ffn_output = self.ffn(attention_output)
|
||||||
ffn_output = gelu_new(ffn_output)
|
ffn_output = gelu_new(ffn_output)
|
||||||
@@ -303,16 +301,16 @@ class AlbertTransformer(nn.Module):
|
|||||||
return (hidden_states,)
|
return (hidden_states,)
|
||||||
|
|
||||||
|
|
||||||
model_size = 'base'
|
# model_size = 'base'
|
||||||
hidden_groups = 1
|
# hidden_groups = 1
|
||||||
inner_groups = 1
|
# inner_groups = 2
|
||||||
config = AlbertConfig.from_json_file("/home/hf/google-research/albert/config_{}-{}-hg-{}-ig.json".format(model_size, hidden_groups, inner_groups))
|
# config = AlbertConfig.from_json_file("/home/hf/google-research/albert/config_{}-{}-hg-{}-ig.json".format(model_size, hidden_groups, inner_groups))
|
||||||
model = AlbertModel(config)
|
# model = AlbertModel(config)
|
||||||
|
|
||||||
print(model)
|
# # print(model)
|
||||||
model = load_tf_weights_in_albert(model, config, "/home/hf/transformers/albert-{}-{}-hg-{}-ig/albert-{}-{}-hg-{}-ig".format(model_size, hidden_groups, inner_groups, model_size, hidden_groups, inner_groups))
|
# model = load_tf_weights_in_albert(model, config, "/home/hf/transformers/albert-{}-{}-hg-{}-ig/albert-{}-{}-hg-{}-ig".format(model_size, hidden_groups, inner_groups, model_size, hidden_groups, inner_groups))
|
||||||
model.eval()
|
# # model.eval()
|
||||||
print(sum(p.numel() for p in model.parameters() if p.requires_grad))
|
# # print(sum(p.numel() for p in model.parameters() if p.requires_grad))
|
||||||
|
|
||||||
|
|
||||||
# input_ids = [[31, 51, 99, 88, 54, 34, 23, 23, 12], [15, 5, 0, 88, 54, 34, 23, 23, 12]]
|
# input_ids = [[31, 51, 99, 88, 54, 34, 23, 23, 12], [15, 5, 0, 88, 54, 34, 23, 23, 12]]
|
||||||
|
|||||||
Reference in New Issue
Block a user