Configs: saner num_labels in configs. (#3967)
This commit is contained in:
@@ -86,11 +86,13 @@ class PretrainedConfig(object):
|
||||
# Fine-tuning task arguments
|
||||
self.architectures = kwargs.pop("architectures", None)
|
||||
self.finetuning_task = kwargs.pop("finetuning_task", None)
|
||||
self.num_labels = kwargs.pop("num_labels", 2)
|
||||
self.id2label = kwargs.pop("id2label", {i: f"LABEL_{i}" for i in range(self.num_labels)})
|
||||
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
|
||||
self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
|
||||
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
|
||||
self.id2label = kwargs.pop("id2label", None)
|
||||
self.label2id = kwargs.pop("label2id", None)
|
||||
if self.id2label is not None:
|
||||
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
|
||||
# Keys are always strings in JSON so convert ids to int here.
|
||||
else:
|
||||
self.num_labels = kwargs.pop("num_labels", 2)
|
||||
|
||||
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
|
||||
self.prefix = kwargs.pop("prefix", None)
|
||||
@@ -115,15 +117,12 @@ class PretrainedConfig(object):
|
||||
|
||||
@property
|
||||
def num_labels(self):
|
||||
return self._num_labels
|
||||
return len(self.id2label)
|
||||
|
||||
@num_labels.setter
|
||||
def num_labels(self, num_labels):
|
||||
self._num_labels = num_labels
|
||||
self.id2label = {i: "LABEL_{}".format(i) for i in range(self.num_labels)}
|
||||
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
|
||||
self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
|
||||
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
|
||||
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user