Apply several ruff SIM rules (#37283)
* Apply ruff SIM118 fix Signed-off-by: cyy <cyyever@outlook.com> * Apply ruff SIM910 fix Signed-off-by: cyy <cyyever@outlook.com> * Apply ruff SIM101 fix Signed-off-by: cyy <cyyever@outlook.com> * Format code Signed-off-by: cyy <cyyever@outlook.com> * More fixes Signed-off-by: cyy <cyyever@outlook.com> --------- Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -1357,12 +1357,12 @@ def _get_torch_dtype(
|
||||
elif hasattr(torch, torch_dtype):
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
config.torch_dtype = torch_dtype
|
||||
for sub_config_key in config.sub_configs.keys():
|
||||
for sub_config_key in config.sub_configs:
|
||||
sub_config = getattr(config, sub_config_key)
|
||||
sub_config.torch_dtype = torch_dtype
|
||||
elif isinstance(torch_dtype, torch.dtype):
|
||||
config.torch_dtype = torch_dtype
|
||||
for sub_config_key in config.sub_configs.keys():
|
||||
for sub_config_key in config.sub_configs:
|
||||
sub_config = getattr(config, sub_config_key)
|
||||
sub_config.torch_dtype = torch_dtype
|
||||
elif isinstance(torch_dtype, dict):
|
||||
@@ -1388,7 +1388,7 @@ def _get_torch_dtype(
|
||||
# set fp32 as the default dtype for BC
|
||||
default_dtype = torch.get_default_dtype()
|
||||
config.torch_dtype = default_dtype
|
||||
for key in config.sub_configs.keys():
|
||||
for key in config.sub_configs:
|
||||
value = getattr(config, key)
|
||||
value.torch_dtype = default_dtype
|
||||
|
||||
@@ -1446,7 +1446,7 @@ def _get_device_map(
|
||||
|
||||
# `inferred_max_memory` contains non-reserved memory. There may be *unused* reserved memory in the GPU,
|
||||
# which we can use to allocate parameters.
|
||||
for device_name in inferred_max_memory.keys():
|
||||
for device_name in inferred_max_memory:
|
||||
if isinstance(device_name, int): # it's a GPU device
|
||||
if is_torch_xpu_available():
|
||||
unused_memory = torch.xpu.memory_reserved(device_name) - torch.xpu.memory_allocated(device_name)
|
||||
@@ -3002,9 +3002,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
|
||||
)
|
||||
|
||||
all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()}
|
||||
all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules}
|
||||
encoder_layer_pos = 0
|
||||
for name in decoder_modules.keys():
|
||||
for name in decoder_modules:
|
||||
if name.isdigit():
|
||||
encoder_name = str(int(name) + encoder_layer_pos)
|
||||
decoder_name = name
|
||||
@@ -3942,7 +3942,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
# Handle the case where some state_dict keys shouldn't be saved
|
||||
if self._keys_to_ignore_on_save is not None:
|
||||
for ignore_key in self._keys_to_ignore_on_save:
|
||||
if ignore_key in state_dict.keys():
|
||||
if ignore_key in state_dict:
|
||||
del state_dict[ignore_key]
|
||||
|
||||
# Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
|
||||
@@ -4057,7 +4057,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
if (
|
||||
filename.startswith(weights_no_suffix)
|
||||
and os.path.isfile(full_filename)
|
||||
and filename not in state_dict_split.filename_to_tensors.keys()
|
||||
and filename not in state_dict_split.filename_to_tensors
|
||||
and is_main_process
|
||||
and reg.fullmatch(filename_no_suffix) is not None
|
||||
):
|
||||
@@ -5334,7 +5334,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
if device_map is not None:
|
||||
device_map = {k[len(_prefix) :] if k.startswith(_prefix) else k: v for k, v in device_map.items()}
|
||||
# small sanity check: the base model should not contain task-specific head keys
|
||||
task_specific_expected_keys = [s for s in model.state_dict().keys() if not s.startswith(_prefix)]
|
||||
task_specific_expected_keys = [s for s in model.state_dict() if not s.startswith(_prefix)]
|
||||
base_model_expected_keys = list(model_to_load.state_dict().keys())
|
||||
if any(
|
||||
key in task_specific_expected_keys and key not in base_model_expected_keys for key in checkpoint_keys
|
||||
|
||||
Reference in New Issue
Block a user