ConvBERT fix torch <> tf weights conversion (#10314)

* convbert conversion test

* fin

* fin

* fin

* clean up tf<->pt conversion

* remove from_pt

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
abhishek thakur
2021-02-24 12:55:34 +01:00
committed by GitHub
parent 3437d12134
commit 2d458b2c7d
4 changed files with 14 additions and 9 deletions

View File

@@ -399,14 +399,12 @@ class TFConvBertModelIntegrationTest(unittest.TestCase):
expected_shape = [1, 6, 768]
self.assertEqual(output.shape, expected_shape)
print(output[:, :3, :3])
expected_slice = tf.constant(
[
[
[-0.10334751, -0.37152207, -0.2682219],
[0.20078957, -0.3918426, -0.78811496],
[0.08000169, -0.509474, -0.59314483],
[-0.03475493, -0.4686034, -0.30638832],
[0.22637248, -0.26988646, -0.7423424],
[0.10324868, -0.45013508, -0.58280784],
]
]
)