From b08259a12086dbd3b572ea71a7f08ba21518f355 Mon Sep 17 00:00:00 2001 From: Funtowicz Morgan Date: Fri, 27 Mar 2020 14:59:55 +0000 Subject: [PATCH] run_ner.py / bert-base-multilingual-cased can output empty tokens (#2991) * Use tokenizer.num_added_tokens to count the number of added special tokens instead of hardcoded numbers. Signed-off-by: Morgan Funtowicz * run_ner.py - Do not add a label to label_ids if word_tokens is empty. This can happen when using bert-base-multilingual-cased with an input word consisting of a single space. In this case, the tokenizer returns an empty word_tokens list, which made label_ids inconsistent with tokens: one label was still appended, leaving label_ids one entry longer than the tokens vector. Signed-off-by: Morgan Funtowicz --- examples/ner/utils_ner.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/ner/utils_ner.py b/examples/ner/utils_ner.py index 510749c2f5..bda1b65a7c 100644 --- a/examples/ner/utils_ner.py +++ b/examples/ner/utils_ner.py @@ -112,12 +112,15 @@ def convert_examples_to_features( label_ids = [] for word, label in zip(example.words, example.labels): word_tokens = tokenizer.tokenize(word) - tokens.extend(word_tokens) - # Use the real label id for the first token of the word, and padding ids for the remaining tokens - label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1)) + + # bert-base-multilingual-cased sometimes output "nothing ([]) when calling tokenize with just a space. + if len(word_tokens) > 0: + tokens.extend(word_tokens) + # Use the real label id for the first token of the word, and padding ids for the remaining tokens + label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1)) # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. 
- special_tokens_count = 3 if sep_token_extra else 2 + special_tokens_count = tokenizer.num_added_tokens() if len(tokens) > max_seq_length - special_tokens_count: tokens = tokens[: (max_seq_length - special_tokens_count)] label_ids = label_ids[: (max_seq_length - special_tokens_count)]