Fix typos in tests (#36547)

Signed-off-by: co63oc <co63oc@users.noreply.github.com>
This commit is contained in:
co63oc
2025-03-06 07:04:06 +08:00
committed by GitHub
parent 752ef3fd4e
commit 996f512d52
99 changed files with 282 additions and 282 deletions

View File

@@ -680,14 +680,14 @@ class FlaxModelTesterMixin:
for model_class in self.all_model_classes:
model = model_class(config, _do_init=False)
# Check that accesing parmas raises an ValueError when _do_init is False
# Check that accessing params raises an ValueError when _do_init is False
with self.assertRaises(ValueError):
params = model.params
# Check if we params can be properly initialized when calling init_weights
params = model.init_weights(model.key, model.input_shape)
assert isinstance(params, (dict, FrozenDict)), f"params are not an instance of {FrozenDict}"
# Check if all required parmas are initialized
# Check if all required params are initialized
keys = set(flatten_dict(unfreeze(params)).keys())
self.assertTrue(all(k in keys for k in model.required_params))
# Check if the shapes match
@@ -713,7 +713,7 @@ class FlaxModelTesterMixin:
config.return_dict = True
def _assert_all_params_initialised(model, params):
# Check if all required parmas are loaded
# Check if all required params are loaded
keys = set(flatten_dict(unfreeze(params)).keys())
self.assertTrue(all(k in keys for k in model.required_params))
# Check if the shapes match
@@ -735,11 +735,11 @@ class FlaxModelTesterMixin:
model.save_pretrained(tmpdirname)
model, params = model_class.from_pretrained(tmpdirname, _do_init=False)
# Check that accesing parmas raises an ValueError when _do_init is False
# Check that accessing params raises an ValueError when _do_init is False
with self.assertRaises(ValueError):
params = model.params
# Check if all required parmas are loaded
# Check if all required params are loaded
_assert_all_params_initialised(model, params)
# Check that setting params raises an ValueError when _do_init is False
@@ -757,7 +757,7 @@ class FlaxModelTesterMixin:
model, params = model_class.from_pretrained(tmpdirname, _do_init=False)
params = model.init_weights(model.key, model.input_shape, params=params)
# Check if all required parmas are loaded
# Check if all required params are loaded
_assert_all_params_initialised(model, params)
def test_checkpoint_sharding_from_hub(self):