From 457dd4392bb0caa8c30b33fc458e6d5e2d5443c5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 24 Jan 2022 21:18:04 +0100 Subject: [PATCH] [Examples] Correct run ner label2id for fine-tuned models (#15017) * up * up * make style * apply sylvains suggestions * apply changes to accelerate as well * more changes * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .../pytorch/token-classification/run_ner.py | 41 +++++++++++++------ .../run_ner_no_trainer.py | 39 +++++++++++++----- 2 files changed, 56 insertions(+), 24 deletions(-) diff --git a/examples/pytorch/token-classification/run_ner.py b/examples/pytorch/token-classification/run_ner.py index 831ba9a156..224eaed7b4 100755 --- a/examples/pytorch/token-classification/run_ner.py +++ b/examples/pytorch/token-classification/run_ner.py @@ -36,6 +36,7 @@ from transformers import ( AutoTokenizer, DataCollatorForTokenClassification, HfArgumentParser, + PretrainedConfig, PreTrainedTokenizerFast, Trainer, TrainingArguments, @@ -296,20 +297,12 @@ def main(): if isinstance(features[label_column_name].feature, ClassLabel): label_list = features[label_column_name].feature.names - # No need to convert the labels since they are already ints. - label_to_id = {i: i for i in range(len(label_list))} + label_keys = list(range(len(label_list))) else: label_list = get_label_list(raw_datasets["train"][label_column_name]) - label_to_id = {l: i for i, l in enumerate(label_list)} - num_labels = len(label_list) + label_keys = label_list - # Map that sends B-Xxx label to its I-Xxx counterpart - b_to_i_label = [] - for idx, label in enumerate(label_list): - if label.startswith("B-") and label.replace("B-", "I-") in label_list: - b_to_i_label.append(label_list.index(label.replace("B-", "I-"))) - else: - b_to_i_label.append(idx) + num_labels = len(label_list) # Load pretrained model and tokenizer # @@ -319,8 +312,6 @@ def main(): config = AutoConfig.from_pretrained( model_args.config_name if model_args.config_name else model_args.model_name_or_path, num_labels=num_labels, - label2id=label_to_id, - id2label={i: l for l, i in label_to_id.items()}, finetuning_task=data_args.task_name, cache_dir=model_args.cache_dir, revision=model_args.model_revision, @@ -363,6 +354,30 @@ def main(): "requirement" ) + if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id: + label_name_to_id = {k: v for k, v in model.config.label2id.items()} + if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): + label_to_id = {k: int(label_name_to_id[k]) for k in label_keys} + else: + logger.warning( + "Your model seems to have been trained with labels, but they don't match the dataset: ", + f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." + "\nIgnoring the model labels as a result.", + ) + else: + label_to_id = {k: i for i, k in enumerate(label_keys)} + + model.config.label2id = label_to_id + model.config.id2label = {i: l for l, i in label_to_id.items()} + + # Map that sends B-Xxx label to its I-Xxx counterpart + b_to_i_label = [] + for idx, label in enumerate(label_list): + if label.startswith("B-") and label.replace("B-", "I-") in label_list: + b_to_i_label.append(label_list.index(label.replace("B-", "I-"))) + else: + b_to_i_label.append(idx) + # Preprocessing the dataset # Padding strategy padding = "max_length" if data_args.pad_to_max_length else False diff --git a/examples/pytorch/token-classification/run_ner_no_trainer.py b/examples/pytorch/token-classification/run_ner_no_trainer.py index 50b7645182..af1959aa51 100755 --- a/examples/pytorch/token-classification/run_ner_no_trainer.py +++ b/examples/pytorch/token-classification/run_ner_no_trainer.py @@ -42,6 +42,7 @@ from transformers import ( AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, + PretrainedConfig, SchedulerType, default_data_collator, get_scheduler, @@ -321,20 +322,12 @@ def main(): if isinstance(features[label_column_name].feature, ClassLabel): label_list = features[label_column_name].feature.names - # No need to convert the labels since they are already ints. - label_to_id = {i: i for i in range(len(label_list))} + label_keys = list(range(len(label_list))) else: label_list = get_label_list(raw_datasets["train"][label_column_name]) - label_to_id = {l: i for i, l in enumerate(label_list)} - num_labels = len(label_list) + label_keys = label_list - # Map that sends B-Xxx label to its I-Xxx counterpart - b_to_i_label = [] - for idx, label in enumerate(label_list): - if label.startswith("B-") and label.replace("B-", "I-") in label_list: - b_to_i_label.append(label_list.index(label.replace("B-", "I-"))) - else: - b_to_i_label.append(idx) + num_labels = len(label_list) # Load pretrained model and tokenizer # @@ -372,6 +365,30 @@ def main(): model.resize_token_embeddings(len(tokenizer)) + if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id: + label_name_to_id = {k: v for k, v in model.config.label2id.items()} + if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): + label_to_id = {k: int(label_name_to_id[k]) for k in label_keys} + else: + logger.warning( + "Your model seems to have been trained with labels, but they don't match the dataset: ", + f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." + "\nIgnoring the model labels as a result.", + ) + else: + label_to_id = {k: i for i, k in enumerate(label_keys)} + + model.config.label2id = label_to_id + model.config.id2label = {i: l for l, i in label_to_id.items()} + + # Map that sends B-Xxx label to its I-Xxx counterpart + b_to_i_label = [] + for idx, label in enumerate(label_list): + if label.startswith("B-") and label.replace("B-", "I-") in label_list: + b_to_i_label.append(label_list.index(label.replace("B-", "I-"))) + else: + b_to_i_label.append(idx) + # Preprocessing the datasets. # First we tokenize all the texts. padding = "max_length" if args.pad_to_max_length else False