Fix label attribution in token classification examples (#14055)
This commit is contained in:
@@ -303,6 +303,14 @@ def main():
|
|||||||
label_to_id = {l: i for i, l in enumerate(label_list)}
|
label_to_id = {l: i for i, l in enumerate(label_list)}
|
||||||
num_labels = len(label_list)
|
num_labels = len(label_list)
|
||||||
|
|
||||||
|
# Map that sends B-Xxx label to its I-Xxx counterpart
|
||||||
|
b_to_i_label = []
|
||||||
|
for idx, label in enumerate(label_list):
|
||||||
|
if label.startswith("B-") and label.replace("B-", "I-") in label_list:
|
||||||
|
b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
|
||||||
|
else:
|
||||||
|
b_to_i_label.append(idx)
|
||||||
|
|
||||||
# Load pretrained model and tokenizer
|
# Load pretrained model and tokenizer
|
||||||
#
|
#
|
||||||
# Distributed training:
|
# Distributed training:
|
||||||
@@ -385,7 +393,10 @@ def main():
|
|||||||
# For the other tokens in a word, we set the label to either the current label or -100, depending on
|
# For the other tokens in a word, we set the label to either the current label or -100, depending on
|
||||||
# the label_all_tokens flag.
|
# the label_all_tokens flag.
|
||||||
else:
|
else:
|
||||||
label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100)
|
if data_args.label_all_tokens:
|
||||||
|
label_ids.append(b_to_i_label[label_to_id[label[word_idx]]])
|
||||||
|
else:
|
||||||
|
label_ids.append(-100)
|
||||||
previous_word_idx = word_idx
|
previous_word_idx = word_idx
|
||||||
|
|
||||||
labels.append(label_ids)
|
labels.append(label_ids)
|
||||||
|
|||||||
@@ -328,6 +328,14 @@ def main():
|
|||||||
label_to_id = {l: i for i, l in enumerate(label_list)}
|
label_to_id = {l: i for i, l in enumerate(label_list)}
|
||||||
num_labels = len(label_list)
|
num_labels = len(label_list)
|
||||||
|
|
||||||
|
# Map that sends B-Xxx label to its I-Xxx counterpart
|
||||||
|
b_to_i_label = []
|
||||||
|
for idx, label in enumerate(label_list):
|
||||||
|
if label.startswith("B-") and label.replace("B-", "I-") in label_list:
|
||||||
|
b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
|
||||||
|
else:
|
||||||
|
b_to_i_label.append(idx)
|
||||||
|
|
||||||
# Load pretrained model and tokenizer
|
# Load pretrained model and tokenizer
|
||||||
#
|
#
|
||||||
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
|
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
|
||||||
@@ -396,7 +404,10 @@ def main():
|
|||||||
# For the other tokens in a word, we set the label to either the current label or -100, depending on
|
# For the other tokens in a word, we set the label to either the current label or -100, depending on
|
||||||
# the label_all_tokens flag.
|
# the label_all_tokens flag.
|
||||||
else:
|
else:
|
||||||
label_ids.append(label_to_id[label[word_idx]] if args.label_all_tokens else -100)
|
if args.label_all_tokens:
|
||||||
|
label_ids.append(b_to_i_label[label_to_id[label[word_idx]]])
|
||||||
|
else:
|
||||||
|
label_ids.append(-100)
|
||||||
previous_word_idx = word_idx
|
previous_word_idx = word_idx
|
||||||
|
|
||||||
labels.append(label_ids)
|
labels.append(label_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user