Support for transformers explicit filename (#38152)
* Support for transformers explicit filename
* Tests
* Rerun tests
This commit is contained in:
@@ -881,6 +881,7 @@ def _get_resolved_checkpoint_files(
|
||||
user_agent: dict,
|
||||
revision: str,
|
||||
commit_hash: Optional[str],
|
||||
transformers_explicit_filename: Optional[str] = None,
|
||||
) -> Tuple[Optional[List[str]], Optional[Dict]]:
|
||||
"""Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
|
||||
checkpoints are sharded.
|
||||
@@ -892,7 +893,11 @@ def _get_resolved_checkpoint_files(
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if is_local:
|
||||
- if from_tf and os.path.isfile(
|
||||
+ if transformers_explicit_filename is not None:
|
||||
+ # If the filename is explicitly defined, load this by default.
|
||||
+ archive_file = os.path.join(pretrained_model_name_or_path, subfolder, transformers_explicit_filename)
|
||||
+ is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
|
||||
+ elif from_tf and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
|
||||
):
|
||||
# Load from a TF 1.0 checkpoint in priority if from_tf
|
||||
@@ -980,7 +985,10 @@ def _get_resolved_checkpoint_files(
|
||||
resolved_archive_file = download_url(pretrained_model_name_or_path)
|
||||
else:
|
||||
# set correct filename
|
||||
- if from_tf:
|
||||
+ if transformers_explicit_filename is not None:
|
||||
+ filename = transformers_explicit_filename
|
||||
+ is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
|
||||
+ elif from_tf:
|
||||
filename = TF2_WEIGHTS_NAME
|
||||
elif from_flax:
|
||||
filename = FLAX_WEIGHTS_NAME
|
||||
@@ -4362,6 +4370,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
model_kwargs = kwargs
|
||||
|
||||
transformers_explicit_filename = getattr(config, "transformers_weights", None)
|
||||
|
||||
if transformers_explicit_filename is not None:
|
||||
if not transformers_explicit_filename.endswith(
|
||||
".safetensors"
|
||||
) and not transformers_explicit_filename.endswith(".safetensors.index.json"):
|
||||
raise ValueError(
|
||||
"The transformers file in the config seems to be incorrect: it is neither a safetensors file "
|
||||
"(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
|
||||
f"{transformers_explicit_filename}"
|
||||
)
|
||||
|
||||
pre_quantized = hasattr(config, "quantization_config")
|
||||
if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
|
||||
pre_quantized = False
|
||||
@@ -4430,6 +4450,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
commit_hash=commit_hash,
|
||||
transformers_explicit_filename=transformers_explicit_filename,
|
||||
)
|
||||
|
||||
is_sharded = sharded_metadata is not None
|
||||
|
||||
Reference in New Issue
Block a user