From 7a9f1b5c99e9a5d1772649d029acdf5160419239 Mon Sep 17 00:00:00 2001 From: Kevin Canwen Xu Date: Wed, 6 Jan 2021 23:34:48 +0800 Subject: [PATCH] Store transformers version info when saving the model (#9421) * Store transformers version info when saving the model * Store transformers version info when saving the model * fix format * fix format * fix format * Update src/transformers/configuration_utils.py Co-authored-by: Lysandre Debut * Update configuration_utils.py Co-authored-by: Lysandre Debut --- src/transformers/configuration_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index d85e62d269..ba53a860cb 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -21,6 +21,7 @@ import json import os from typing import Any, Dict, Tuple, Union +from . import __version__ from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url from .utils import logging @@ -234,6 +235,9 @@ class PretrainedConfig(object): # Name or path to the pretrained checkpoint self._name_or_path = str(kwargs.pop("name_or_path", "")) + # Drop the transformers version info + kwargs.pop("transformers_version", None) + # Additional attributes without default values for key, value in kwargs.items(): try: @@ -520,6 +524,7 @@ class PretrainedConfig(object): for key, value in config_dict.items(): if ( key not in default_config_dict + or key == "transformers_version" or value != default_config_dict[key] or (key in class_config_dict and value != class_config_dict[key]) ): @@ -537,6 +542,10 @@ class PretrainedConfig(object): output = copy.deepcopy(self.__dict__) if hasattr(self.__class__, "model_type"): output["model_type"] = self.__class__.model_type + + # Transformers version when serializing the model + output["transformers_version"] = __version__ + return output def to_json_string(self, use_diff: bool = True) -> str: