Fix run_glue evaluation when model has a label correspondence (#10401)
This commit is contained in:
@@ -324,7 +324,7 @@ def main():
|
|||||||
# Some have all caps in their config, some don't.
|
# Some have all caps in their config, some don't.
|
||||||
label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
|
label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
|
||||||
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
|
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
|
||||||
label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
|
label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
|
||||||
else:
|
else:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||||
@@ -350,7 +350,7 @@ def main():
|
|||||||
|
|
||||||
# Map labels to IDs (not necessary for GLUE tasks)
|
# Map labels to IDs (not necessary for GLUE tasks)
|
||||||
if label_to_id is not None and "label" in examples:
|
if label_to_id is not None and "label" in examples:
|
||||||
result["label"] = [label_to_id[l] for l in examples["label"]]
|
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
|
datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
|
||||||
|
|||||||
@@ -289,6 +289,7 @@ class PretrainedConfig(object):
|
|||||||
|
|
||||||
@num_labels.setter
|
@num_labels.setter
|
||||||
def num_labels(self, num_labels: int):
|
def num_labels(self, num_labels: int):
|
||||||
|
if self.id2label is None or len(self.id2label) != num_labels:
|
||||||
self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
|
self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
|
||||||
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
|
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user