From b3d834ae11381ca493da97f717d77d185ca7d780 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 2 Dec 2019 15:01:52 -0500 Subject: [PATCH] Reorganize ALBERT conversion script --- transformers/modeling_albert.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/transformers/modeling_albert.py b/transformers/modeling_albert.py index 5b7b2d3900..49d120ffae 100644 --- a/transformers/modeling_albert.py +++ b/transformers/modeling_albert.py @@ -68,14 +68,36 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path): for name, array in zip(names, arrays): original_name = name + + # If saved from the TF HUB module + name = name.replace("module/", "") + + # Renaming and simplifying name = name.replace("ffn_1", "ffn") - name = name.replace("/bert/", "/albert/") - name = name.replace("ffn/intermediate/output", "ffn_output") + name = name.replace("bert/", "albert/") name = name.replace("attention_1", "attention") - name = name.replace("cls/predictions", "predictions") name = name.replace("transform/", "") name = name.replace("LayerNorm_1", "full_layer_layer_norm") - name = name.replace("LayerNorm", "attention/LayerNorm") + name = name.replace("LayerNorm", "attention/LayerNorm") + name = name.replace("transformer/", "") + + # The feed forward layer had an 'intermediate' step which has been abstracted away + name = name.replace("intermediate/dense/", "") + name = name.replace("ffn/intermediate/output/dense/", "ffn_output/") + + # ALBERT attention was split between self and output which have been abstracted away + name = name.replace("/output/", "/") + name = name.replace("/self/", "/") + + # The pooler is a linear layer + name = name.replace("pooler/dense", "pooler") + + # The classifier was simplified to predictions from cls/predictions + name = name.replace("cls/predictions", "predictions") + name = name.replace("predictions/attention", "predictions") + + # Naming was changed to be more explicit + name = name.replace("embeddings/attention", "embeddings") name = name.replace("inner_group_", "albert_layers/") name = name.replace("group_", "albert_layer_groups/") name = name.split('/')