Load sharded pt to flax (#18419)
* initial commit
* add small test
* add cross pt tf flag to test
* fix quality
* style
* update test with new repo
* fix failing test
* update
* fix wrong param ordering
* style
* update based on review
* update related to recent new caching mechanism
* quality
* Update based on review
  Co-authored-by: sgugger <sylvain.gugger@gmail.com>
* quality and style
* Update src/transformers/modeling_flax_utils.py
  Co-authored-by: sgugger <sylvain.gugger@gmail.com>
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -1099,6 +1099,14 @@ class FlaxModelTesterMixin:
|
||||
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
|
||||
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
|
||||
|
||||
@is_pt_flax_cross_test
def test_from_sharded_pt(self):
    """Check that loading the sharded checkpoint with ``from_pt=True`` yields the reference weights.

    The model loaded from ``hf-internal-testing/tiny-random-bert-sharded`` with
    ``from_pt=True`` is compared, parameter by parameter, against a reference
    model loaded from ``hf-internal-testing/tiny-random-bert-fx-only``.
    """
    model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
    ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only")

    # Flatten the converted parameters once instead of on every loop iteration.
    params = flatten_dict(model.params)

    for key, ref_val in flatten_dict(ref_model.params).items():
        # Report a missing parameter as a test failure rather than a raw KeyError.
        self.assertIn(key, params, msg=f"parameter {key} is missing from the converted model")
        # Use unittest assertions (as the neighbouring tests do) so the check is
        # not stripped under `python -O` and failures name the offending key.
        self.assertTrue(
            np.allclose(np.array(params[key]), np.array(ref_val)),
            msg=f"parameter {key} differs from the reference model",
        )
|
||||
|
||||
def test_gradient_checkpointing(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user