[PretrainedConfig] Fix save pretrained config for edge case (#7943)
* fix config save
* add test
* add config class variable and another test
* line break
* fix fsmt and typo
* god am I making many errors today :-/
* Update src/transformers/configuration_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
cc2e312ca3
commit
f34372a9ff
@@ -70,6 +70,7 @@ class EncoderDecoderConfig(PretrainedConfig):
|
|||||||
>>> model = EncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config)
|
>>> model = EncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config)
|
||||||
"""
|
"""
|
||||||
model_type = "encoder_decoder"
|
model_type = "encoder_decoder"
|
||||||
|
is_composition = True
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|||||||
@@ -126,9 +126,9 @@ class FSMTConfig(PretrainedConfig):
|
|||||||
# update the defaults from config file
|
# update the defaults from config file
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
langs,
|
langs=["en", "de"],
|
||||||
src_vocab_size,
|
src_vocab_size=42024,
|
||||||
tgt_vocab_size,
|
tgt_vocab_size=42024,
|
||||||
activation_function="relu",
|
activation_function="relu",
|
||||||
d_model=1024,
|
d_model=1024,
|
||||||
max_length=200,
|
max_length=200,
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ RAG_CONFIG_DOC = r"""
|
|||||||
@add_start_docstrings(RAG_CONFIG_DOC)
|
@add_start_docstrings(RAG_CONFIG_DOC)
|
||||||
class RagConfig(PretrainedConfig):
|
class RagConfig(PretrainedConfig):
|
||||||
model_type = "rag"
|
model_type = "rag"
|
||||||
|
is_composition = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -41,6 +41,10 @@ class PretrainedConfig(object):
|
|||||||
Class attributes (overridden by derived classes)
|
Class attributes (overridden by derived classes)
|
||||||
- **model_type** (:obj:`str`): An identifier for the model type, serialized into the JSON file, and used to
|
- **model_type** (:obj:`str`): An identifier for the model type, serialized into the JSON file, and used to
|
||||||
recreate the correct object in :class:`~transformers.AutoConfig`.
|
recreate the correct object in :class:`~transformers.AutoConfig`.
|
||||||
|
- **is_composition** (:obj:`bool`): Whether the config class is composed of multiple
|
||||||
|
sub-configs. In this case the config has to be initialized from two or more configs of
|
||||||
|
type :class:`~transformers.PretrainedConfig` like: :class:`~transformers.EncoderDecoderConfig` or
|
||||||
|
:class:`~transformers.RagConfig`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name_or_path (:obj:`str`, `optional`, defaults to :obj:`""`):
|
name_or_path (:obj:`str`, `optional`, defaults to :obj:`""`):
|
||||||
@@ -145,6 +149,7 @@ class PretrainedConfig(object):
|
|||||||
use BFloat16 scalars (only used by some TensorFlow models).
|
use BFloat16 scalars (only used by some TensorFlow models).
|
||||||
"""
|
"""
|
||||||
model_type: str = ""
|
model_type: str = ""
|
||||||
|
is_composition: bool = False
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
# Attributes with defaults
|
# Attributes with defaults
|
||||||
@@ -476,11 +481,18 @@ class PretrainedConfig(object):
|
|||||||
# get the default config dict
|
# get the default config dict
|
||||||
default_config_dict = PretrainedConfig().to_dict()
|
default_config_dict = PretrainedConfig().to_dict()
|
||||||
|
|
||||||
|
# get class specific config dict
|
||||||
|
class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
|
||||||
|
|
||||||
serializable_config_dict = {}
|
serializable_config_dict = {}
|
||||||
|
|
||||||
# only serialize values that differ from the default config
|
# only serialize values that differ from the default config
|
||||||
for key, value in config_dict.items():
|
for key, value in config_dict.items():
|
||||||
if key not in default_config_dict or value != default_config_dict[key]:
|
if (
|
||||||
|
key not in default_config_dict
|
||||||
|
or value != default_config_dict[key]
|
||||||
|
or (key in class_config_dict and value != class_config_dict[key])
|
||||||
|
):
|
||||||
serializable_config_dict[key] = value
|
serializable_config_dict[key] = value
|
||||||
|
|
||||||
return serializable_config_dict
|
return serializable_config_dict
|
||||||
|
|||||||
@@ -66,9 +66,16 @@ class ConfigTester(object):
|
|||||||
self.parent.assertEqual(len(config.id2label), 3)
|
self.parent.assertEqual(len(config.id2label), 3)
|
||||||
self.parent.assertEqual(len(config.label2id), 3)
|
self.parent.assertEqual(len(config.label2id), 3)
|
||||||
|
|
||||||
|
def check_config_can_be_init_without_params(self):
|
||||||
|
if self.config_class.is_composition:
|
||||||
|
return
|
||||||
|
config = self.config_class()
|
||||||
|
self.parent.assertIsNotNone(config)
|
||||||
|
|
||||||
def run_common_tests(self):
|
def run_common_tests(self):
|
||||||
self.create_and_test_config_common_properties()
|
self.create_and_test_config_common_properties()
|
||||||
self.create_and_test_config_to_json_string()
|
self.create_and_test_config_to_json_string()
|
||||||
self.create_and_test_config_to_json_file()
|
self.create_and_test_config_to_json_file()
|
||||||
self.create_and_test_config_from_and_save_pretrained()
|
self.create_and_test_config_from_and_save_pretrained()
|
||||||
self.create_and_test_config_with_num_labels()
|
self.create_and_test_config_with_num_labels()
|
||||||
|
self.check_config_can_be_init_without_params()
|
||||||
|
|||||||
@@ -901,6 +901,15 @@ class ProphetNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_model_with_attn_mask(*config_and_inputs)
|
self.model_tester.check_model_with_attn_mask(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_config_save(self):
|
||||||
|
config = self.model_tester.prepare_config_and_inputs()[0]
|
||||||
|
config.add_cross_attention = False
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
|
config.save_pretrained(tmp_dirname)
|
||||||
|
config = ProphetNetConfig.from_pretrained(tmp_dirname)
|
||||||
|
|
||||||
|
self.assertFalse(config.add_cross_attention)
|
||||||
|
|
||||||
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
|
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
|
||||||
def test_fp16_forward(self):
|
def test_fp16_forward(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
|||||||
Reference in New Issue
Block a user