[Flax] Update no init test for Flax v0.7.1 (#28735)

This commit is contained in:
Sanchit Gandhi
2024-01-26 18:20:39 +00:00
committed by GitHub
parent abe0289e6d
commit de13a951b3

View File

@@ -984,7 +984,7 @@ class FlaxModelTesterMixin:
# Check if we params can be properly initialized when calling init_weights
params = model.init_weights(model.key, model.input_shape)
self.assertIsInstance(params, FrozenDict)
assert isinstance(params, (dict, FrozenDict)), f"params are not an instance of {FrozenDict}"
# Check if all required parmas are initialized
keys = set(flatten_dict(unfreeze(params)).keys())
self.assertTrue(all(k in keys for k in model.required_params))