CLI: convert sharded PT models (#17959)
* sharded conversion; add flag to control max hidden error
* better hidden name matching
* Add test: load TF from PT shards
* fix test (PT data must be local)
This commit is contained in:
@@ -27,7 +27,7 @@ from typing import List, Tuple
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from huggingface_hub import HfFolder, delete_repo, set_access_token
|
||||
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
@@ -1966,6 +1966,16 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
@is_pt_tf_cross_test
def test_checkpoint_sharding_local_from_pt(self):
    """Check that a TF model built from local sharded PyTorch weights matches the reference.

    The sharded PyTorch checkpoint is cloned into a temporary directory and
    loaded from that local path with ``from_pt=True``; every resulting weight
    is then compared against the reference checkpoint's weights.
    """
    with tempfile.TemporaryDirectory() as clone_dir:
        # Only the clone side effect is needed: it places the sharded PyTorch
        # checkpoint files on local disk.
        Repository(local_dir=clone_dir, clone_from="hf-internal-testing/tiny-random-bert-sharded")
        sharded_model = TFBertModel.from_pretrained(clone_dir, from_pt=True)

        # The reference checkpoint is the same model as the one loaded above;
        # the only difference is that the one above is a sharded PyTorch version.
        reference_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        weight_pairs = zip(sharded_model.weights, reference_model.weights)
        for converted, expected in weight_pairs:
            assert np.allclose(converted.numpy(), expected.numpy())
|
||||
|
||||
def test_shard_checkpoint(self):
|
||||
# This is the model we will use, total size 340,000 bytes.
|
||||
model = tf.keras.Sequential(
|
||||
|
||||
Reference in New Issue
Block a user