Pass kwargs to configuration (#3147)
* Pass kwargs to configuration * Setter * test
This commit is contained in:
@@ -98,6 +98,18 @@ class PretrainedConfig(object):
|
|||||||
logger.error("Can't set {} with value {} for {}".format(key, value, self))
|
logger.error("Can't set {} with value {} for {}".format(key, value, self))
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_labels(self):
|
||||||
|
return self._num_labels
|
||||||
|
|
||||||
|
@num_labels.setter
|
||||||
|
def num_labels(self, num_labels):
|
||||||
|
self._num_labels = num_labels
|
||||||
|
self.id2label = {i: "LABEL_{}".format(i) for i in range(self.num_labels)}
|
||||||
|
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
|
||||||
|
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
|
||||||
|
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
|
||||||
|
|
||||||
def save_pretrained(self, save_directory):
|
def save_pretrained(self, save_directory):
|
||||||
"""
|
"""
|
||||||
Save a configuration object to the directory `save_directory`, so that it
|
Save a configuration object to the directory `save_directory`, so that it
|
||||||
|
|||||||
@@ -57,8 +57,18 @@ class ConfigTester(object):
|
|||||||
|
|
||||||
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
|
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
|
||||||
|
|
||||||
|
def create_and_test_config_with_num_labels(self):
|
||||||
|
config = self.config_class(**self.inputs_dict, num_labels=5)
|
||||||
|
self.parent.assertEqual(len(config.id2label), 5)
|
||||||
|
self.parent.assertEqual(len(config.label2id), 5)
|
||||||
|
|
||||||
|
config.num_labels = 3
|
||||||
|
self.parent.assertEqual(len(config.id2label), 3)
|
||||||
|
self.parent.assertEqual(len(config.label2id), 3)
|
||||||
|
|
||||||
def run_common_tests(self):
|
def run_common_tests(self):
|
||||||
self.create_and_test_config_common_properties()
|
self.create_and_test_config_common_properties()
|
||||||
self.create_and_test_config_to_json_string()
|
self.create_and_test_config_to_json_string()
|
||||||
self.create_and_test_config_to_json_file()
|
self.create_and_test_config_to_json_file()
|
||||||
self.create_and_test_config_from_and_save_pretrained()
|
self.create_and_test_config_from_and_save_pretrained()
|
||||||
|
self.create_and_test_config_with_num_labels()
|
||||||
|
|||||||
Reference in New Issue
Block a user