cross platform from_pretrained (#20538)
* add support for `from_pt` * add tf_flax utility file * Update src/transformers/modeling_tf_flax_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * remove flax related modifications * add test * remove FLAX related commits * fixup * remove safetensor todos * revert deletion Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -2127,6 +2127,14 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_checkpoint_sharding_hub_from_pt(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
|
||||
# the model above is the same as the model below, just a sharded pytorch version.
|
||||
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
def test_shard_checkpoint(self):
|
||||
# This is the model we will use, total size 340,000 bytes.
|
||||
model = tf.keras.Sequential(
|
||||
|
||||
Reference in New Issue
Block a user