Load sharded pt to flax (#18419)

* initial commit

* add small test

* add cross pt tf flag to test

* fix quality

* style

* update test with new repo

* fix failing test

* update

* fix wrong param ordering

* style

* update based on review

* update related to recent new caching mechanism

* quality

* Update based on review

Co-authored-by: sgugger <sylvain.gugger@gmail.com>

* quality and style

* Update src/transformers/modeling_flax_utils.py
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Arthur
2022-08-12 09:48:10 +02:00
committed by GitHub
parent c8b6ae858d
commit bce36ee065
3 changed files with 94 additions and 8 deletions

View File

@@ -1099,6 +1099,14 @@ class FlaxModelTesterMixin:
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
@is_pt_flax_cross_test
def test_from_sharded_pt(self):
    """Check that a sharded PyTorch checkpoint loads into Flax with the expected weights.

    Loads ``hf-internal-testing/tiny-random-bert-sharded`` with ``from_pt=True`` and compares
    every parameter of ``hf-internal-testing/tiny-random-bert-fx-only`` (loaded without
    ``from_pt``) against the corresponding parameter of the converted model.
    """
    model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
    ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")

    # Flatten the converted parameters once instead of on every loop iteration.
    flat_params = flatten_dict(model.params)

    for key, ref_val in flatten_dict(ref_model.params).items():
        # Report a missing parameter as a test failure rather than a raw KeyError.
        self.assertIn(key, flat_params)
        self.assertTrue(
            np.allclose(np.array(flat_params[key]), np.array(ref_val)),
            msg=f"Parameter {key} differs between the sharded PT checkpoint and the reference model",
        )
def test_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()