ConvBERT fix torch <> tf weights conversion (#10314)
* convbert conversion test * fin * fin * fin * clean up tf<->pt conversion * remove from_pt Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import ConvBertConfig, ConvBertModel, load_tf_weights_in_convbert
|
||||
from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
@@ -30,6 +30,9 @@ def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_f
|
||||
model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path)
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True)
|
||||
tf_model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -343,7 +343,7 @@ class GroupedLinearLayer(tf.keras.layers.Layer):
|
||||
def build(self, input_shape):
|
||||
self.kernel = self.add_weight(
|
||||
"kernel",
|
||||
shape=[self.num_groups, self.group_in_dim, self.group_out_dim],
|
||||
shape=[self.group_out_dim, self.group_in_dim, self.num_groups],
|
||||
initializer=self.kernel_initializer,
|
||||
trainable=True,
|
||||
)
|
||||
@@ -355,7 +355,7 @@ class GroupedLinearLayer(tf.keras.layers.Layer):
|
||||
def call(self, hidden_states):
|
||||
batch_size = shape_list(hidden_states)[0]
|
||||
x = tf.transpose(tf.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim]), [1, 0, 2])
|
||||
x = tf.matmul(x, self.kernel)
|
||||
x = tf.matmul(x, tf.transpose(self.kernel, [2, 1, 0]))
|
||||
x = tf.transpose(x, [1, 0, 2])
|
||||
x = tf.reshape(x, [batch_size, -1, self.output_size])
|
||||
x = tf.nn.bias_add(value=x, bias=self.bias)
|
||||
|
||||
Reference in New Issue
Block a user