From 27d55125e6c3cd5b2589ef9827d835f6d54efc58 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 1 May 2020 11:28:55 -0400 Subject: [PATCH] Configs: saner num_labels in configs. (#3967) --- src/transformers/configuration_utils.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index f42f75e6b7..8aafa6dcf2 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -86,11 +86,13 @@ class PretrainedConfig(object): # Fine-tuning task arguments self.architectures = kwargs.pop("architectures", None) self.finetuning_task = kwargs.pop("finetuning_task", None) - self.num_labels = kwargs.pop("num_labels", 2) - self.id2label = kwargs.pop("id2label", {i: f"LABEL_{i}" for i in range(self.num_labels)}) - self.id2label = dict((int(key), value) for key, value in self.id2label.items()) - self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys()))) - self.label2id = dict((key, int(value)) for key, value in self.label2id.items()) + self.id2label = kwargs.pop("id2label", None) + self.label2id = kwargs.pop("label2id", None) + if self.id2label is not None: + self.id2label = dict((int(key), value) for key, value in self.id2label.items()) + # Keys are always strings in JSON so convert ids to int here. + else: + self.num_labels = kwargs.pop("num_labels", 2) # Tokenizer arguments TODO: eventually tokenizer and models should share the same config self.prefix = kwargs.pop("prefix", None) @@ -115,15 +117,12 @@ class PretrainedConfig(object): @property def num_labels(self): - return self._num_labels + return len(self.id2label) @num_labels.setter def num_labels(self, num_labels): - self._num_labels = num_labels - self.id2label = {i: "LABEL_{}".format(i) for i in range(self.num_labels)} - self.id2label = dict((int(key), value) for key, value in self.id2label.items()) + self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)} self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) - self.label2id = dict((key, int(value)) for key, value in self.label2id.items()) def save_pretrained(self, save_directory): """