[Flax] Update no init test for Flax v0.7.1 (#28735)
This commit is contained in:
@@ -984,7 +984,7 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
# Check if we params can be properly initialized when calling init_weights
|
# Check if we params can be properly initialized when calling init_weights
|
||||||
params = model.init_weights(model.key, model.input_shape)
|
params = model.init_weights(model.key, model.input_shape)
|
||||||
self.assertIsInstance(params, FrozenDict)
|
assert isinstance(params, (dict, FrozenDict)), f"params are not an instance of {FrozenDict}"
|
||||||
# Check if all required parmas are initialized
|
# Check if all required parmas are initialized
|
||||||
keys = set(flatten_dict(unfreeze(params)).keys())
|
keys = set(flatten_dict(unfreeze(params)).keys())
|
||||||
self.assertTrue(all(k in keys for k in model.required_params))
|
self.assertTrue(all(k in keys for k in model.required_params))
|
||||||
|
|||||||
Reference in New Issue
Block a user