Fix typos in tests (#36547)
Signed-off-by: co63oc <co63oc@users.noreply.github.com>
This commit is contained in:
@@ -680,14 +680,14 @@ class FlaxModelTesterMixin:
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config, _do_init=False)
|
||||
|
||||
# Check that accesing parmas raises an ValueError when _do_init is False
|
||||
# Check that accessing params raises an ValueError when _do_init is False
|
||||
with self.assertRaises(ValueError):
|
||||
params = model.params
|
||||
|
||||
# Check if we params can be properly initialized when calling init_weights
|
||||
params = model.init_weights(model.key, model.input_shape)
|
||||
assert isinstance(params, (dict, FrozenDict)), f"params are not an instance of {FrozenDict}"
|
||||
# Check if all required parmas are initialized
|
||||
# Check if all required params are initialized
|
||||
keys = set(flatten_dict(unfreeze(params)).keys())
|
||||
self.assertTrue(all(k in keys for k in model.required_params))
|
||||
# Check if the shapes match
|
||||
@@ -713,7 +713,7 @@ class FlaxModelTesterMixin:
|
||||
config.return_dict = True
|
||||
|
||||
def _assert_all_params_initialised(model, params):
|
||||
# Check if all required parmas are loaded
|
||||
# Check if all required params are loaded
|
||||
keys = set(flatten_dict(unfreeze(params)).keys())
|
||||
self.assertTrue(all(k in keys for k in model.required_params))
|
||||
# Check if the shapes match
|
||||
@@ -735,11 +735,11 @@ class FlaxModelTesterMixin:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model, params = model_class.from_pretrained(tmpdirname, _do_init=False)
|
||||
|
||||
# Check that accesing parmas raises an ValueError when _do_init is False
|
||||
# Check that accessing params raises an ValueError when _do_init is False
|
||||
with self.assertRaises(ValueError):
|
||||
params = model.params
|
||||
|
||||
# Check if all required parmas are loaded
|
||||
# Check if all required params are loaded
|
||||
_assert_all_params_initialised(model, params)
|
||||
|
||||
# Check that setting params raises an ValueError when _do_init is False
|
||||
@@ -757,7 +757,7 @@ class FlaxModelTesterMixin:
|
||||
model, params = model_class.from_pretrained(tmpdirname, _do_init=False)
|
||||
|
||||
params = model.init_weights(model.key, model.input_shape, params=params)
|
||||
# Check if all required parmas are loaded
|
||||
# Check if all required params are loaded
|
||||
_assert_all_params_initialised(model, params)
|
||||
|
||||
def test_checkpoint_sharding_from_hub(self):
|
||||
|
||||
Reference in New Issue
Block a user