ConvBERT fix torch <> tf weights conversion (#10314)

* convbert conversion test

* fin

* fin

* fin

* clean up tf<->pt conversion

* remove from_pt

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
abhishek thakur
2021-02-24 12:55:34 +01:00
committed by Lysandre
parent cd48078ce5
commit 0d4c9808c4
4 changed files with 11 additions and 6 deletions

View File

@@ -56,7 +56,11 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
tf_name = tf_name[1:] # Remove level zero tf_name = tf_name[1:] # Remove level zero
# When should we transpose the weights # When should we transpose the weights
transpose = bool(tf_name[-1] == "kernel" or "emb_projs" in tf_name or "out_projs" in tf_name) transpose = bool(
tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
or "emb_projs" in tf_name
or "out_projs" in tf_name
)
# Convert standard TF2.0 names in PyTorch names # Convert standard TF2.0 names in PyTorch names
if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma": if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma":

View File

@@ -16,7 +16,7 @@
import argparse import argparse
from transformers import ConvBertConfig, ConvBertModel, load_tf_weights_in_convbert from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert
from transformers.utils import logging from transformers.utils import logging
@@ -30,6 +30,9 @@ def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_f
model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path) model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path)
model.save_pretrained(pytorch_dump_path) model.save_pretrained(pytorch_dump_path)
tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True)
tf_model.save_pretrained(pytorch_dump_path)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@@ -425,7 +425,7 @@ class GroupedLinearLayer(tf.keras.layers.Layer):
def build(self, input_shape): def build(self, input_shape):
self.kernel = self.add_weight( self.kernel = self.add_weight(
"kernel", "kernel",
shape=[self.num_groups, self.group_in_dim, self.group_out_dim], shape=[self.group_out_dim, self.group_in_dim, self.num_groups],
initializer=self.kernel_initializer, initializer=self.kernel_initializer,
trainable=True, trainable=True,
) )
@@ -437,7 +437,7 @@ class GroupedLinearLayer(tf.keras.layers.Layer):
def call(self, hidden_states): def call(self, hidden_states):
batch_size = shape_list(hidden_states)[0] batch_size = shape_list(hidden_states)[0]
x = tf.transpose(tf.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim]), [1, 0, 2]) x = tf.transpose(tf.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim]), [1, 0, 2])
x = tf.matmul(x, self.kernel) x = tf.matmul(x, tf.transpose(self.kernel, [2, 1, 0]))
x = tf.transpose(x, [1, 0, 2]) x = tf.transpose(x, [1, 0, 2])
x = tf.reshape(x, [batch_size, -1, self.output_size]) x = tf.reshape(x, [batch_size, -1, self.output_size])
x = tf.nn.bias_add(value=x, bias=self.bias) x = tf.nn.bias_add(value=x, bias=self.bias)

View File

@@ -384,8 +384,6 @@ class TFConvBertModelIntegrationTest(unittest.TestCase):
expected_shape = [1, 6, 768] expected_shape = [1, 6, 768]
self.assertEqual(output.shape, expected_shape) self.assertEqual(output.shape, expected_shape)
print(output[:, :3, :3])
expected_slice = tf.constant( expected_slice = tf.constant(
[ [
[ [