Support for transformers explicit filename (#38152)

* Support for transformers explicit filename

* Tests

* Rerun tests
This commit is contained in:
Lysandre Debut
2025-05-19 14:33:47 +02:00
committed by GitHub
parent dbb9813dff
commit 003deb16f1
3 changed files with 101 additions and 2 deletions

View File

@@ -881,6 +881,7 @@ def _get_resolved_checkpoint_files(
user_agent: dict,
revision: str,
commit_hash: Optional[str],
transformers_explicit_filename: Optional[str] = None,
) -> Tuple[Optional[List[str]], Optional[Dict]]:
"""Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
checkpoints are sharded.
@@ -892,7 +893,11 @@ def _get_resolved_checkpoint_files(
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local:
if from_tf and os.path.isfile(
if transformers_explicit_filename is not None:
# An explicitly specified filename takes priority over every other candidate checked below.
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, transformers_explicit_filename)
is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
elif from_tf and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
):
# Load from a TF 1.0 checkpoint in priority if from_tf
@@ -980,7 +985,10 @@ def _get_resolved_checkpoint_files(
resolved_archive_file = download_url(pretrained_model_name_or_path)
else:
# set correct filename
if from_tf:
if transformers_explicit_filename is not None:
filename = transformers_explicit_filename
is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
elif from_tf:
filename = TF2_WEIGHTS_NAME
elif from_flax:
filename = FLAX_WEIGHTS_NAME
@@ -4362,6 +4370,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
model_kwargs = kwargs
transformers_explicit_filename = getattr(config, "transformers_weights", None)
if transformers_explicit_filename is not None:
if not transformers_explicit_filename.endswith(
".safetensors"
) and not transformers_explicit_filename.endswith(".safetensors.index.json"):
raise ValueError(
"The transformers file in the config seems to be incorrect: it is neither a safetensors file "
"(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
f"{transformers_explicit_filename}"
)
pre_quantized = hasattr(config, "quantization_config")
if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
pre_quantized = False
@@ -4430,6 +4450,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
user_agent=user_agent,
revision=revision,
commit_hash=commit_hash,
transformers_explicit_filename=transformers_explicit_filename,
)
is_sharded = sharded_metadata is not None