From 45cac3fade34cb7134b080c5060c250f810db5e2 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 2 Feb 2022 14:23:43 -0500 Subject: [PATCH] Fix labels stored in model config for token classification examples (#15482) * Playing * Properly set labels in model config for token classification example * Port to run_ner_no_trainer * Quality --- .../pytorch/token-classification/run_ner.py | 30 ++++++++++++------- .../run_ner_no_trainer.py | 30 ++++++++++++------- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/examples/pytorch/token-classification/run_ner.py b/examples/pytorch/token-classification/run_ner.py index 8fb0a9ba6d..c5718e82fc 100755 --- a/examples/pytorch/token-classification/run_ner.py +++ b/examples/pytorch/token-classification/run_ner.py @@ -295,12 +295,15 @@ def main(): label_list.sort() return label_list - if isinstance(features[label_column_name].feature, ClassLabel): + # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere. + # Otherwise, we have to get the list of labels manually. + labels_are_int = isinstance(features[label_column_name].feature, ClassLabel) + if labels_are_int: label_list = features[label_column_name].feature.names - label_keys = list(range(len(label_list))) + label_to_id = {i: i for i in range(len(label_list))} else: label_list = get_label_list(raw_datasets["train"][label_column_name]) - label_keys = label_list + label_to_id = {l: i for i, l in enumerate(label_list)} num_labels = len(label_list) @@ -354,21 +357,26 @@ def main(): "requirement" ) + # Model has labels -> use them. if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id: - label_name_to_id = {k: v for k, v in model.config.label2id.items()} - if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): - label_to_id = {k: int(label_name_to_id[k]) for k in label_keys} + if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)): + # Reorganize `label_list` to match the ordering of the model. + if labels_are_int: + label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)} + label_list = [model.config.id2label[i] for i in range(num_labels)] + else: + label_list = [model.config.id2label[i] for i in range(num_labels)] + label_to_id = {l: i for i, l in enumerate(label_list)} else: logger.warning( "Your model seems to have been trained with labels, but they don't match the dataset: ", - f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." + f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}." "\nIgnoring the model labels as a result.", ) - else: - label_to_id = {k: i for i, k in enumerate(label_keys)} - model.config.label2id = label_to_id - model.config.id2label = {i: l for l, i in label_to_id.items()} + # Set the correspondences label/ID inside the model config + model.config.label2id = {l: i for i, l in enumerate(label_list)} + model.config.id2label = {i: l for i, l in enumerate(label_list)} # Map that sends B-Xxx label to its I-Xxx counterpart b_to_i_label = [] diff --git a/examples/pytorch/token-classification/run_ner_no_trainer.py b/examples/pytorch/token-classification/run_ner_no_trainer.py index af1959aa51..e292331ea4 100755 --- a/examples/pytorch/token-classification/run_ner_no_trainer.py +++ b/examples/pytorch/token-classification/run_ner_no_trainer.py @@ -320,12 +320,15 @@ def main(): label_list.sort() return label_list - if isinstance(features[label_column_name].feature, ClassLabel): + # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere. + # Otherwise, we have to get the list of labels manually. + labels_are_int = isinstance(features[label_column_name].feature, ClassLabel) + if labels_are_int: label_list = features[label_column_name].feature.names - label_keys = list(range(len(label_list))) + label_to_id = {i: i for i in range(len(label_list))} else: label_list = get_label_list(raw_datasets["train"][label_column_name]) - label_keys = label_list + label_to_id = {l: i for i, l in enumerate(label_list)} num_labels = len(label_list) @@ -365,21 +368,26 @@ def main(): model.resize_token_embeddings(len(tokenizer)) + # Model has labels -> use them. if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id: - label_name_to_id = {k: v for k, v in model.config.label2id.items()} - if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): - label_to_id = {k: int(label_name_to_id[k]) for k in label_keys} + if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)): + # Reorganize `label_list` to match the ordering of the model. + if labels_are_int: + label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)} + label_list = [model.config.id2label[i] for i in range(num_labels)] + else: + label_list = [model.config.id2label[i] for i in range(num_labels)] + label_to_id = {l: i for i, l in enumerate(label_list)} else: logger.warning( "Your model seems to have been trained with labels, but they don't match the dataset: ", - f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." + f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}." "\nIgnoring the model labels as a result.", ) - else: - label_to_id = {k: i for i, k in enumerate(label_keys)} - model.config.label2id = label_to_id - model.config.id2label = {i: l for l, i in label_to_id.items()} + # Set the correspondences label/ID inside the model config + model.config.label2id = {l: i for i, l in enumerate(label_list)} + model.config.id2label = {i: l for i, l in enumerate(label_list)} # Map that sends B-Xxx label to its I-Xxx counterpart b_to_i_label = []