fix block mask typing (#36661)

* fix block mask typing

* updated

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>

* gemma

* fix

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
Arthur
2025-03-12 11:29:11 +01:00
committed by GitHub
parent 89f6956015
commit 2829013d2d
4 changed files with 76 additions and 75 deletions

View File

@@ -849,13 +849,13 @@ def _load_state_dict_into_meta_model(
is_quantized = hf_quantizer is not None
for serialized_param_name, empty_param in state_dict.items():
if serialized_param_name not in expected_keys:
continue
# serialized_param_name is the raw, serialized name
# fixed_param_name is the model's equivalent
fixed_param_name, _ = model.rename_key(serialized_param_name)
if fixed_param_name not in expected_keys:
continue
# we need to use serialized_param_name as file pointer is untouched
if shard_file.endswith(".safetensors"):
param = file_pointer.get_slice(serialized_param_name)