Fix labels stored in model config for token classification examples (#15482)
* Playing * Properly set labels in model config for token classification example * Port to run_ner_no_trainer * Quality
This commit is contained in:
@@ -295,12 +295,15 @@ def main():
|
|||||||
label_list.sort()
|
label_list.sort()
|
||||||
return label_list
|
return label_list
|
||||||
|
|
||||||
if isinstance(features[label_column_name].feature, ClassLabel):
|
# If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere.
|
||||||
|
# Otherwise, we have to get the list of labels manually.
|
||||||
|
labels_are_int = isinstance(features[label_column_name].feature, ClassLabel)
|
||||||
|
if labels_are_int:
|
||||||
label_list = features[label_column_name].feature.names
|
label_list = features[label_column_name].feature.names
|
||||||
label_keys = list(range(len(label_list)))
|
label_to_id = {i: i for i in range(len(label_list))}
|
||||||
else:
|
else:
|
||||||
label_list = get_label_list(raw_datasets["train"][label_column_name])
|
label_list = get_label_list(raw_datasets["train"][label_column_name])
|
||||||
label_keys = label_list
|
label_to_id = {l: i for i, l in enumerate(label_list)}
|
||||||
|
|
||||||
num_labels = len(label_list)
|
num_labels = len(label_list)
|
||||||
|
|
||||||
@@ -354,21 +357,26 @@ def main():
|
|||||||
"requirement"
|
"requirement"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Model has labels -> use them.
|
||||||
if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
|
if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
|
||||||
label_name_to_id = {k: v for k, v in model.config.label2id.items()}
|
if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)):
|
||||||
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
|
# Reorganize `label_list` to match the ordering of the model.
|
||||||
label_to_id = {k: int(label_name_to_id[k]) for k in label_keys}
|
if labels_are_int:
|
||||||
|
label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)}
|
||||||
|
label_list = [model.config.id2label[i] for i in range(num_labels)]
|
||||||
|
else:
|
||||||
|
label_list = [model.config.id2label[i] for i in range(num_labels)]
|
||||||
|
label_to_id = {l: i for i, l in enumerate(label_list)}
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||||
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
|
f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
|
||||||
"\nIgnoring the model labels as a result.",
|
"\nIgnoring the model labels as a result.",
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
label_to_id = {k: i for i, k in enumerate(label_keys)}
|
|
||||||
|
|
||||||
model.config.label2id = label_to_id
|
# Set the correspondences label/ID inside the model config
|
||||||
model.config.id2label = {i: l for l, i in label_to_id.items()}
|
model.config.label2id = {l: i for i, l in enumerate(label_list)}
|
||||||
|
model.config.id2label = {i: l for i, l in enumerate(label_list)}
|
||||||
|
|
||||||
# Map that sends B-Xxx label to its I-Xxx counterpart
|
# Map that sends B-Xxx label to its I-Xxx counterpart
|
||||||
b_to_i_label = []
|
b_to_i_label = []
|
||||||
|
|||||||
@@ -320,12 +320,15 @@ def main():
|
|||||||
label_list.sort()
|
label_list.sort()
|
||||||
return label_list
|
return label_list
|
||||||
|
|
||||||
if isinstance(features[label_column_name].feature, ClassLabel):
|
# If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere.
|
||||||
|
# Otherwise, we have to get the list of labels manually.
|
||||||
|
labels_are_int = isinstance(features[label_column_name].feature, ClassLabel)
|
||||||
|
if labels_are_int:
|
||||||
label_list = features[label_column_name].feature.names
|
label_list = features[label_column_name].feature.names
|
||||||
label_keys = list(range(len(label_list)))
|
label_to_id = {i: i for i in range(len(label_list))}
|
||||||
else:
|
else:
|
||||||
label_list = get_label_list(raw_datasets["train"][label_column_name])
|
label_list = get_label_list(raw_datasets["train"][label_column_name])
|
||||||
label_keys = label_list
|
label_to_id = {l: i for i, l in enumerate(label_list)}
|
||||||
|
|
||||||
num_labels = len(label_list)
|
num_labels = len(label_list)
|
||||||
|
|
||||||
@@ -365,21 +368,26 @@ def main():
|
|||||||
|
|
||||||
model.resize_token_embeddings(len(tokenizer))
|
model.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
|
# Model has labels -> use them.
|
||||||
if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
|
if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
|
||||||
label_name_to_id = {k: v for k, v in model.config.label2id.items()}
|
if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)):
|
||||||
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
|
# Reorganize `label_list` to match the ordering of the model.
|
||||||
label_to_id = {k: int(label_name_to_id[k]) for k in label_keys}
|
if labels_are_int:
|
||||||
|
label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)}
|
||||||
|
label_list = [model.config.id2label[i] for i in range(num_labels)]
|
||||||
|
else:
|
||||||
|
label_list = [model.config.id2label[i] for i in range(num_labels)]
|
||||||
|
label_to_id = {l: i for i, l in enumerate(label_list)}
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||||
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
|
f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
|
||||||
"\nIgnoring the model labels as a result.",
|
"\nIgnoring the model labels as a result.",
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
label_to_id = {k: i for i, k in enumerate(label_keys)}
|
|
||||||
|
|
||||||
model.config.label2id = label_to_id
|
# Set the correspondences label/ID inside the model config
|
||||||
model.config.id2label = {i: l for l, i in label_to_id.items()}
|
model.config.label2id = {l: i for i, l in enumerate(label_list)}
|
||||||
|
model.config.id2label = {i: l for i, l in enumerate(label_list)}
|
||||||
|
|
||||||
# Map that sends B-Xxx label to its I-Xxx counterpart
|
# Map that sends B-Xxx label to its I-Xxx counterpart
|
||||||
b_to_i_label = []
|
b_to_i_label = []
|
||||||
|
|||||||
Reference in New Issue
Block a user