Pass kwargs to configuration (#3147)
* Pass kwargs to configuration * Setter * test
This commit is contained in:
@@ -98,6 +98,18 @@ class PretrainedConfig(object):
|
||||
logger.error("Can't set {} with value {} for {}".format(key, value, self))
|
||||
raise err
|
||||
|
||||
@property
|
||||
def num_labels(self):
|
||||
return self._num_labels
|
||||
|
||||
@num_labels.setter
|
||||
def num_labels(self, num_labels):
|
||||
self._num_labels = num_labels
|
||||
self.id2label = {i: "LABEL_{}".format(i) for i in range(self.num_labels)}
|
||||
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
|
||||
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
|
||||
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
"""
|
||||
Save a configuration object to the directory `save_directory`, so that it
|
||||
|
||||
@@ -57,8 +57,18 @@ class ConfigTester(object):
|
||||
|
||||
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
|
||||
|
||||
def create_and_test_config_with_num_labels(self):
|
||||
config = self.config_class(**self.inputs_dict, num_labels=5)
|
||||
self.parent.assertEqual(len(config.id2label), 5)
|
||||
self.parent.assertEqual(len(config.label2id), 5)
|
||||
|
||||
config.num_labels = 3
|
||||
self.parent.assertEqual(len(config.id2label), 3)
|
||||
self.parent.assertEqual(len(config.label2id), 3)
|
||||
|
||||
def run_common_tests(self):
|
||||
self.create_and_test_config_common_properties()
|
||||
self.create_and_test_config_to_json_string()
|
||||
self.create_and_test_config_to_json_file()
|
||||
self.create_and_test_config_from_and_save_pretrained()
|
||||
self.create_and_test_config_with_num_labels()
|
||||
|
||||
Reference in New Issue
Block a user