Fix GroupedLinearLayer in TF ConvBERT (#9972)
This commit is contained in:
@@ -435,9 +435,10 @@ class GroupedLinearLayer(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def call(self, hidden_states):
|
def call(self, hidden_states):
|
||||||
batch_size = shape_list(tensor=hidden_states)[1]
|
batch_size = shape_list(hidden_states)[0]
|
||||||
x = tf.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim])
|
x = tf.transpose(tf.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim]), [1, 0, 2])
|
||||||
x = tf.matmul(a=x, b=self.kernel, transpose_b=True)
|
x = tf.matmul(x, self.kernel)
|
||||||
|
x = tf.transpose(x, [1, 0, 2])
|
||||||
x = tf.reshape(x, [batch_size, -1, self.output_size])
|
x = tf.reshape(x, [batch_size, -1, self.output_size])
|
||||||
x = tf.nn.bias_add(value=x, bias=self.bias)
|
x = tf.nn.bias_add(value=x, bias=self.bias)
|
||||||
return x
|
return x
|
||||||
|
|||||||
Reference in New Issue
Block a user