Configs: saner num_labels in configs. (#3967)

This commit is contained in:
Julien Chaumond
2020-05-01 11:28:55 -04:00
committed by GitHub
parent e80be7f1d0
commit 27d55125e6

View File

@@ -86,11 +86,13 @@ class PretrainedConfig(object):
# Fine-tuning task arguments # Fine-tuning task arguments
self.architectures = kwargs.pop("architectures", None) self.architectures = kwargs.pop("architectures", None)
self.finetuning_task = kwargs.pop("finetuning_task", None) self.finetuning_task = kwargs.pop("finetuning_task", None)
self.num_labels = kwargs.pop("num_labels", 2) self.id2label = kwargs.pop("id2label", None)
self.id2label = kwargs.pop("id2label", {i: f"LABEL_{i}" for i in range(self.num_labels)}) self.label2id = kwargs.pop("label2id", None)
if self.id2label is not None:
self.id2label = dict((int(key), value) for key, value in self.id2label.items()) self.id2label = dict((int(key), value) for key, value in self.id2label.items())
self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys()))) # Keys are always strings in JSON so convert ids to int here.
self.label2id = dict((key, int(value)) for key, value in self.label2id.items()) else:
self.num_labels = kwargs.pop("num_labels", 2)
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
self.prefix = kwargs.pop("prefix", None) self.prefix = kwargs.pop("prefix", None)
@@ -115,15 +117,12 @@ class PretrainedConfig(object):
@property @property
def num_labels(self): def num_labels(self):
return self._num_labels return len(self.id2label)
@num_labels.setter @num_labels.setter
def num_labels(self, num_labels): def num_labels(self, num_labels):
self._num_labels = num_labels self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
self.id2label = {i: "LABEL_{}".format(i) for i in range(self.num_labels)}
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory):
""" """