CLI: convert sharded PT models (#17959)

* sharded conversion; add flag to control max hidden error

* better hidden name matching

* Add test: load TF from PT shards

* fix test (PT data must be local)
This commit is contained in:
Joao Gante
2022-06-30 16:51:03 +01:00
committed by GitHub
parent f25457b273
commit 91e1f24ef3
4 changed files with 102 additions and 38 deletions

View File

@@ -27,7 +27,7 @@ from typing import List, Tuple
from datasets import Dataset
from huggingface_hub import HfFolder, delete_repo, set_access_token
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
from requests.exceptions import HTTPError
from transformers import is_tf_available, is_torch_available
from transformers.configuration_utils import PretrainedConfig
@@ -1966,6 +1966,16 @@ class UtilsFunctionsTest(unittest.TestCase):
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())
@is_pt_tf_cross_test
def test_checkpoint_sharding_local_from_pt(self):
with tempfile.TemporaryDirectory() as tmp_dir:
_ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-bert-sharded")
model = TFBertModel.from_pretrained(tmp_dir, from_pt=True)
# the model above is the same as the model below, just a sharded pytorch version.
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())
def test_shard_checkpoint(self):
# This is the model we will use, total size 340,000 bytes.
model = tf.keras.Sequential(