From 2d458b2c7d6fb1dd5b2361938d1b5bd4c2106479 Mon Sep 17 00:00:00 2001 From: abhishek thakur Date: Wed, 24 Feb 2021 12:55:34 +0100 Subject: [PATCH] ConvBERT fix torch <> tf weights conversion (#10314) * convbert conversion test * fin * fin * fin * clean up tf<->pt conversion * remove from_pt Co-authored-by: patrickvonplaten --- src/transformers/modeling_tf_pytorch_utils.py | 6 +++++- ...onvbert_original_tf1_checkpoint_to_pytorch_and_tf2.py} | 5 ++++- src/transformers/models/convbert/modeling_tf_convbert.py | 4 ++-- tests/test_modeling_tf_convbert.py | 8 +++----- 4 files changed, 14 insertions(+), 9 deletions(-) rename src/transformers/models/convbert/{convert_convbert_original_tf1_checkpoint_to_pytorch.py => convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py} (88%) diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 1e3c5e49fd..465af5dd3a 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -56,7 +56,11 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="") tf_name = tf_name[1:] # Remove level zero # When should we transpose the weights - transpose = bool(tf_name[-1] == "kernel" or "emb_projs" in tf_name or "out_projs" in tf_name) + transpose = bool( + tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"] + or "emb_projs" in tf_name + or "out_projs" in tf_name + ) # Convert standard TF2.0 names in PyTorch names if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma": diff --git a/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch.py b/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py similarity index 88% rename from src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch.py rename to src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py index aaeb77784e..cdea57cc24 100644 --- a/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch.py +++ b/src/transformers/models/convbert/convert_convbert_original_tf1_checkpoint_to_pytorch_and_tf2.py @@ -16,7 +16,7 @@ import argparse -from transformers import ConvBertConfig, ConvBertModel, load_tf_weights_in_convbert +from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert from transformers.utils import logging @@ -30,6 +30,9 @@ def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_f model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path) model.save_pretrained(pytorch_dump_path) + tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True) + tf_model.save_pretrained(pytorch_dump_path) + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/src/transformers/models/convbert/modeling_tf_convbert.py b/src/transformers/models/convbert/modeling_tf_convbert.py index e5413be45f..b3441e2931 100644 --- a/src/transformers/models/convbert/modeling_tf_convbert.py +++ b/src/transformers/models/convbert/modeling_tf_convbert.py @@ -343,7 +343,7 @@ class GroupedLinearLayer(tf.keras.layers.Layer): def build(self, input_shape): self.kernel = self.add_weight( "kernel", - shape=[self.num_groups, self.group_in_dim, self.group_out_dim], + shape=[self.group_out_dim, self.group_in_dim, self.num_groups], initializer=self.kernel_initializer, trainable=True, ) @@ -355,7 +355,7 @@ class GroupedLinearLayer(tf.keras.layers.Layer): def call(self, hidden_states): batch_size = shape_list(hidden_states)[0] x = tf.transpose(tf.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim]), [1, 0, 2]) - x = tf.matmul(x, self.kernel) + x = tf.matmul(x, tf.transpose(self.kernel, [2, 1, 0])) x = tf.transpose(x, [1, 0, 2]) x = tf.reshape(x, [batch_size, -1, self.output_size]) x = tf.nn.bias_add(value=x, bias=self.bias) diff --git a/tests/test_modeling_tf_convbert.py b/tests/test_modeling_tf_convbert.py index 1a7768e700..e882bc64fd 100644 --- a/tests/test_modeling_tf_convbert.py +++ b/tests/test_modeling_tf_convbert.py @@ -399,14 +399,12 @@ class TFConvBertModelIntegrationTest(unittest.TestCase): expected_shape = [1, 6, 768] self.assertEqual(output.shape, expected_shape) - print(output[:, :3, :3]) - expected_slice = tf.constant( [ [ - [-0.10334751, -0.37152207, -0.2682219], - [0.20078957, -0.3918426, -0.78811496], - [0.08000169, -0.509474, -0.59314483], + [-0.03475493, -0.4686034, -0.30638832], + [0.22637248, -0.26988646, -0.7423424], + [0.10324868, -0.45013508, -0.58280784], ] ] )