Set weights_only in torch.load (#36991)
This commit is contained in:
@@ -277,12 +277,7 @@ def convert_pt_checkpoint_to_tf(
|
||||
if compare_with_pt_model:
|
||||
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
|
||||
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
state_dict = torch.load(
|
||||
pytorch_checkpoint_path,
|
||||
map_location="cpu",
|
||||
**weights_only_kwarg,
|
||||
)
|
||||
state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu", weights_only=True)
|
||||
pt_model = pt_model_class.from_pretrained(
|
||||
pretrained_model_name_or_path=None, config=config, state_dict=state_dict
|
||||
)
|
||||
|
||||
@@ -148,7 +148,7 @@ class SquadDataset(Dataset):
|
||||
with FileLock(lock_path):
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
start = time.time()
|
||||
self.old_features = torch.load(cached_features_file)
|
||||
self.old_features = torch.load(cached_features_file, weights_only=True)
|
||||
|
||||
# Legacy cache files have only features, while new cache files
|
||||
# will have dataset and examples also.
|
||||
|
||||
@@ -71,8 +71,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
|
||||
)
|
||||
raise
|
||||
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
|
||||
pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
|
||||
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
|
||||
|
||||
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
|
||||
@@ -248,8 +247,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
|
||||
flax_state_dict = {}
|
||||
for shard_file in shard_filenames:
|
||||
# load using msgpack utils
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
|
||||
pt_state_dict = torch.load(shard_file, weights_only=True)
|
||||
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
|
||||
pt_state_dict = {
|
||||
k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
|
||||
|
||||
@@ -198,8 +198,7 @@ def load_pytorch_checkpoint_in_tf2_model(
|
||||
if pt_path.endswith(".safetensors"):
|
||||
state_dict = safe_load_file(pt_path)
|
||||
else:
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
|
||||
state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
|
||||
|
||||
pt_state_dict.update(state_dict)
|
||||
|
||||
|
||||
@@ -504,8 +504,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
|
||||
error_message += f"\nMissing key(s): {str_unexpected_keys}."
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)
|
||||
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)
|
||||
|
||||
for shard_file in shard_files:
|
||||
state_dict = loader(os.path.join(folder, shard_file))
|
||||
@@ -598,11 +597,10 @@ def load_state_dict(
|
||||
and is_zipfile(checkpoint_file)
|
||||
):
|
||||
extra_args = {"mmap": True}
|
||||
weights_only_kwarg = {"weights_only": weights_only}
|
||||
return torch.load(
|
||||
checkpoint_file,
|
||||
map_location=map_location,
|
||||
**weights_only_kwarg,
|
||||
weights_only=weights_only,
|
||||
**extra_args,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -1216,7 +1214,7 @@ def _get_torch_dtype(
|
||||
weights_only: bool,
|
||||
) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
|
||||
"""Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the
|
||||
infered dtype. We do the following:
|
||||
inferred dtype. We do the following:
|
||||
1. If torch_dtype is not None, we use that dtype
|
||||
2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
||||
weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
|
||||
|
||||
@@ -207,7 +207,7 @@ def convert_wav2vec2_checkpoint(
|
||||
hf_wav2vec = Data2VecAudioModel(config)
|
||||
data2vec_checkpoint_dir = os.path.dirname(checkpoint_path)
|
||||
|
||||
state_dict = torch.load(checkpoint_path)
|
||||
state_dict = torch.load(checkpoint_path, weights_only=True)
|
||||
state_dict["model"]["final_proj.weight"] = state_dict["model"].pop("final_proj.0.weight")
|
||||
state_dict["model"]["final_proj.bias"] = state_dict["model"].pop("final_proj.0.bias")
|
||||
converted_ckpt = os.path.join(data2vec_checkpoint_dir, "converted.pt")
|
||||
|
||||
@@ -121,7 +121,7 @@ def convert_phi_weights(
|
||||
if model_path.endswith("safetensors"):
|
||||
loaded_weights = safetensors.torch.load_file(model_path, device=device)
|
||||
else:
|
||||
loaded_weights = torch.load(model_path, map_location=device)
|
||||
loaded_weights = torch.load(model_path, map_location=device, weights_only=True)
|
||||
model_checkpoint.update(**loaded_weights)
|
||||
|
||||
model_type = model_name.split("/")[1] # phi-1 or phi-1_5 or phi-2
|
||||
|
||||
@@ -1589,11 +1589,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
state_dict = torch.load(
|
||||
weight_path,
|
||||
map_location="cpu",
|
||||
**weights_only_kwarg,
|
||||
weights_only=True,
|
||||
)
|
||||
|
||||
except EnvironmentError:
|
||||
|
||||
@@ -2820,7 +2820,6 @@ class Trainer:
|
||||
)
|
||||
|
||||
if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
# If the model is on the GPU, it still works!
|
||||
if is_sagemaker_mp_enabled():
|
||||
if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
|
||||
@@ -2836,11 +2835,7 @@ class Trainer:
|
||||
logger.warning(
|
||||
"Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
|
||||
)
|
||||
state_dict = torch.load(
|
||||
weights_file,
|
||||
map_location="cpu",
|
||||
**weights_only_kwarg,
|
||||
)
|
||||
state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
|
||||
# Required for smp to not auto-translate state_dict from hf to smp (is already smp).
|
||||
state_dict["_smp_is_partial"] = False
|
||||
load_result = model.load_state_dict(state_dict, strict=True)
|
||||
@@ -2859,11 +2854,7 @@ class Trainer:
|
||||
if self.args.save_safetensors and os.path.isfile(safe_weights_file):
|
||||
state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(
|
||||
weights_file,
|
||||
map_location="cpu",
|
||||
**weights_only_kwarg,
|
||||
)
|
||||
state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
|
||||
|
||||
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
|
||||
# which takes *args instead of **kwargs
|
||||
@@ -2941,7 +2932,6 @@ class Trainer:
|
||||
or os.path.exists(best_safe_adapter_model_path)
|
||||
):
|
||||
has_been_loaded = True
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
if is_sagemaker_mp_enabled():
|
||||
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
|
||||
# If the 'user_content.pt' file exists, load with the new smp api.
|
||||
@@ -2958,11 +2948,7 @@ class Trainer:
|
||||
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
|
||||
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(
|
||||
best_model_path,
|
||||
map_location="cpu",
|
||||
**weights_only_kwarg,
|
||||
)
|
||||
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
|
||||
|
||||
state_dict["_smp_is_partial"] = False
|
||||
load_result = model.load_state_dict(state_dict, strict=True)
|
||||
@@ -3017,11 +3003,7 @@ class Trainer:
|
||||
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
|
||||
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(
|
||||
best_model_path,
|
||||
map_location="cpu",
|
||||
**weights_only_kwarg,
|
||||
)
|
||||
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
|
||||
|
||||
# If the model is on the GPU, it still works!
|
||||
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
|
||||
@@ -3142,7 +3124,7 @@ class Trainer:
|
||||
return
|
||||
|
||||
with safe_globals():
|
||||
checkpoint_rng_state = torch.load(rng_file)
|
||||
checkpoint_rng_state = torch.load(rng_file, weights_only=True)
|
||||
random.setstate(checkpoint_rng_state["python"])
|
||||
np.random.set_state(checkpoint_rng_state["numpy"])
|
||||
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
|
||||
@@ -3375,7 +3357,9 @@ class Trainer:
|
||||
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
|
||||
if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||
self.lr_scheduler.load_state_dict(
|
||||
torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
|
||||
)
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
return
|
||||
|
||||
@@ -3410,13 +3394,18 @@ class Trainer:
|
||||
checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
|
||||
),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)
|
||||
# We only need `optimizer` when resuming from checkpoint
|
||||
optimizer_state = optimizer_state["optimizer"]
|
||||
else:
|
||||
optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
|
||||
optimizer_state = torch.load(
|
||||
os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True
|
||||
)
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
|
||||
lr_scheduler_state = torch.load(
|
||||
os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu", weights_only=True
|
||||
)
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
xm.send_cpu_data_to_device(optimizer_state, self.args.device)
|
||||
@@ -3458,10 +3447,14 @@ class Trainer:
|
||||
)
|
||||
else:
|
||||
self.optimizer.load_state_dict(
|
||||
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
|
||||
torch.load(
|
||||
os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True
|
||||
)
|
||||
)
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||
self.lr_scheduler.load_state_dict(
|
||||
torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
|
||||
)
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
def _save_scaler(self, output_dir):
|
||||
@@ -3496,13 +3489,17 @@ class Trainer:
|
||||
# Load in scaler states
|
||||
if is_torch_xla_available():
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
scaler_state = torch.load(os.path.join(checkpoint, SCALER_NAME), map_location="cpu")
|
||||
scaler_state = torch.load(
|
||||
os.path.join(checkpoint, SCALER_NAME), map_location="cpu", weights_only=True
|
||||
)
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
xm.send_cpu_data_to_device(scaler_state, self.args.device)
|
||||
self.accelerator.scaler.load_state_dict(scaler_state)
|
||||
else:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.accelerator.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
|
||||
self.accelerator.scaler.load_state_dict(
|
||||
torch.load(os.path.join(checkpoint, SCALER_NAME), weights_only=True)
|
||||
)
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
def _load_callback_state(self):
|
||||
|
||||
Reference in New Issue
Block a user