[Examples] Correct run ner label2id for fine-tuned models (#15017)
* up * up * make style * apply sylvains suggestions * apply changes to accelerate as well * more changes * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
8d6acc6c29
commit
457dd4392b
@@ -36,6 +36,7 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
DataCollatorForTokenClassification,
|
||||
HfArgumentParser,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizerFast,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
@@ -296,20 +297,12 @@ def main():
|
||||
|
||||
if isinstance(features[label_column_name].feature, ClassLabel):
|
||||
label_list = features[label_column_name].feature.names
|
||||
# No need to convert the labels since they are already ints.
|
||||
label_to_id = {i: i for i in range(len(label_list))}
|
||||
label_keys = list(range(len(label_list)))
|
||||
else:
|
||||
label_list = get_label_list(raw_datasets["train"][label_column_name])
|
||||
label_to_id = {l: i for i, l in enumerate(label_list)}
|
||||
num_labels = len(label_list)
|
||||
label_keys = label_list
|
||||
|
||||
# Map that sends B-Xxx label to its I-Xxx counterpart
|
||||
b_to_i_label = []
|
||||
for idx, label in enumerate(label_list):
|
||||
if label.startswith("B-") and label.replace("B-", "I-") in label_list:
|
||||
b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
|
||||
else:
|
||||
b_to_i_label.append(idx)
|
||||
num_labels = len(label_list)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
@@ -319,8 +312,6 @@ def main():
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
label2id=label_to_id,
|
||||
id2label={i: l for l, i in label_to_id.items()},
|
||||
finetuning_task=data_args.task_name,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
@@ -363,6 +354,30 @@ def main():
|
||||
"requirement"
|
||||
)
|
||||
|
||||
if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
|
||||
label_name_to_id = {k: v for k, v in model.config.label2id.items()}
|
||||
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
|
||||
label_to_id = {k: int(label_name_to_id[k]) for k in label_keys}
|
||||
else:
|
||||
logger.warning(
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
|
||||
"\nIgnoring the model labels as a result.",
|
||||
)
|
||||
else:
|
||||
label_to_id = {k: i for i, k in enumerate(label_keys)}
|
||||
|
||||
model.config.label2id = label_to_id
|
||||
model.config.id2label = {i: l for l, i in label_to_id.items()}
|
||||
|
||||
# Map that sends B-Xxx label to its I-Xxx counterpart
|
||||
b_to_i_label = []
|
||||
for idx, label in enumerate(label_list):
|
||||
if label.startswith("B-") and label.replace("B-", "I-") in label_list:
|
||||
b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
|
||||
else:
|
||||
b_to_i_label.append(idx)
|
||||
|
||||
# Preprocessing the dataset
|
||||
# Padding strategy
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
|
||||
Reference in New Issue
Block a user