Store transformers version info when saving the model (#9421)
* Store transformers version info when saving the model * Store transformers version info when saving the model * fix format * fix format * fix format * Update src/transformers/configuration_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update configuration_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -21,6 +21,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Tuple, Union
|
from typing import Any, Dict, Tuple, Union
|
||||||
|
|
||||||
|
from . import __version__
|
||||||
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
|
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
@@ -234,6 +235,9 @@ class PretrainedConfig(object):
|
|||||||
# Name or path to the pretrained checkpoint
|
# Name or path to the pretrained checkpoint
|
||||||
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
||||||
|
|
||||||
|
# Drop the transformers version info
|
||||||
|
kwargs.pop("transformers_version", None)
|
||||||
|
|
||||||
# Additional attributes without default values
|
# Additional attributes without default values
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
try:
|
try:
|
||||||
@@ -520,6 +524,7 @@ class PretrainedConfig(object):
|
|||||||
for key, value in config_dict.items():
|
for key, value in config_dict.items():
|
||||||
if (
|
if (
|
||||||
key not in default_config_dict
|
key not in default_config_dict
|
||||||
|
or key == "transformers_version"
|
||||||
or value != default_config_dict[key]
|
or value != default_config_dict[key]
|
||||||
or (key in class_config_dict and value != class_config_dict[key])
|
or (key in class_config_dict and value != class_config_dict[key])
|
||||||
):
|
):
|
||||||
@@ -537,6 +542,10 @@ class PretrainedConfig(object):
|
|||||||
output = copy.deepcopy(self.__dict__)
|
output = copy.deepcopy(self.__dict__)
|
||||||
if hasattr(self.__class__, "model_type"):
|
if hasattr(self.__class__, "model_type"):
|
||||||
output["model_type"] = self.__class__.model_type
|
output["model_type"] = self.__class__.model_type
|
||||||
|
|
||||||
|
# Transformers version when serializing the model
|
||||||
|
output["transformers_version"] = __version__
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def to_json_string(self, use_diff: bool = True) -> str:
|
def to_json_string(self, use_diff: bool = True) -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user