From b623ddc0002aebe32e2b7a1203a6acbed61bf9a8 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 5 Mar 2020 17:16:57 -0500 Subject: [PATCH] Pass kwargs to configuration (#3147) * Pass kwargs to configuration * Setter * test --- src/transformers/configuration_utils.py | 12 ++++++++++++ tests/test_configuration_common.py | 10 ++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index d8cd0fe3e9..5ce23e2c88 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -98,6 +98,18 @@ class PretrainedConfig(object): logger.error("Can't set {} with value {} for {}".format(key, value, self)) raise err + @property + def num_labels(self): + return self._num_labels + + @num_labels.setter + def num_labels(self, num_labels): + self._num_labels = num_labels + self.id2label = {i: "LABEL_{}".format(i) for i in range(self.num_labels)} + self.id2label = dict((int(key), value) for key, value in self.id2label.items()) + self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) + self.label2id = dict((key, int(value)) for key, value in self.label2id.items()) + def save_pretrained(self, save_directory): """ Save a configuration object to the directory `save_directory`, so that it diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 471f0f012d..7498ae6caf 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -57,8 +57,18 @@ class ConfigTester(object): self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) + def create_and_test_config_with_num_labels(self): + config = self.config_class(**self.inputs_dict, num_labels=5) + self.parent.assertEqual(len(config.id2label), 5) + self.parent.assertEqual(len(config.label2id), 5) + + config.num_labels = 3 + self.parent.assertEqual(len(config.id2label), 3) + self.parent.assertEqual(len(config.label2id), 3) + def run_common_tests(self): self.create_and_test_config_common_properties() self.create_and_test_config_to_json_string() self.create_and_test_config_to_json_file() self.create_and_test_config_from_and_save_pretrained() + self.create_and_test_config_with_num_labels()