Reorganize ALBERT conversion script
This commit is contained in:
@@ -68,14 +68,36 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
|||||||
|
|
||||||
for name, array in zip(names, arrays):
|
for name, array in zip(names, arrays):
|
||||||
original_name = name
|
original_name = name
|
||||||
|
|
||||||
|
# If saved from the TF HUB module
|
||||||
|
name = name.replace("module/", "")
|
||||||
|
|
||||||
|
# Renaming and simplifying
|
||||||
name = name.replace("ffn_1", "ffn")
|
name = name.replace("ffn_1", "ffn")
|
||||||
name = name.replace("/bert/", "/albert/")
|
name = name.replace("bert/", "albert/")
|
||||||
name = name.replace("ffn/intermediate/output", "ffn_output")
|
|
||||||
name = name.replace("attention_1", "attention")
|
name = name.replace("attention_1", "attention")
|
||||||
name = name.replace("cls/predictions", "predictions")
|
|
||||||
name = name.replace("transform/", "")
|
name = name.replace("transform/", "")
|
||||||
name = name.replace("LayerNorm_1", "full_layer_layer_norm")
|
name = name.replace("LayerNorm_1", "full_layer_layer_norm")
|
||||||
name = name.replace("LayerNorm", "attention/LayerNorm")
|
name = name.replace("LayerNorm", "attention/LayerNorm")
|
||||||
|
name = name.replace("transformer/", "")
|
||||||
|
|
||||||
|
# The feed forward layer had an 'intermediate' step which has been abstracted away
|
||||||
|
name = name.replace("intermediate/dense/", "")
|
||||||
|
name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")
|
||||||
|
|
||||||
|
# ALBERT attention was split between self and output which have been abstracted away
|
||||||
|
name = name.replace("/output/", "/")
|
||||||
|
name = name.replace("/self/", "/")
|
||||||
|
|
||||||
|
# The pooler is a linear layer
|
||||||
|
name = name.replace("pooler/dense", "pooler")
|
||||||
|
|
||||||
|
# The classifier was simplified to predictions from cls/predictions
|
||||||
|
name = name.replace("cls/predictions", "predictions")
|
||||||
|
name = name.replace("predictions/attention", "predictions")
|
||||||
|
|
||||||
|
# Naming was changed to be more explicit
|
||||||
|
name = name.replace("embeddings/attention", "embeddings")
|
||||||
name = name.replace("inner_group_", "albert_layers/")
|
name = name.replace("inner_group_", "albert_layers/")
|
||||||
name = name.replace("group_", "albert_layer_groups/")
|
name = name.replace("group_", "albert_layer_groups/")
|
||||||
name = name.split('/')
|
name = name.split('/')
|
||||||
|
|||||||
Reference in New Issue
Block a user