From 7defc6670fa76e857109e1b99f3e919da8d11f42 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 14 May 2020 17:07:52 -0400 Subject: [PATCH] p_mask in SQuAD pre-processing (#4049) * Better p_mask building * Adressing @mfuntowicz comments --- src/transformers/data/processors/squad.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/transformers/data/processors/squad.py b/src/transformers/data/processors/squad.py index 2ed78ccb12..3b39041fd6 100644 --- a/src/transformers/data/processors/squad.py +++ b/src/transformers/data/processors/squad.py @@ -195,18 +195,22 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q cls_index = span["input_ids"].index(tokenizer.cls_token_id) # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer) - # Original TF implem also keep the classification token (set to 0) (not sure why...) - p_mask = np.array(span["token_type_ids"]) - - p_mask = np.minimum(p_mask, 1) - + # Original TF implem also keep the classification token (set to 0) + p_mask = np.ones_like(span["token_type_ids"]) if tokenizer.padding_side == "right": - # Limit positive values to one - p_mask = 1 - p_mask + p_mask[len(truncated_query) + sequence_added_tokens :] = 0 + else: + p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0 - p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1 + pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id) + special_token_indices = np.asarray( + tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True) + ).nonzero() - # Set the CLS index to '0' + p_mask[pad_token_indices] = 1 + p_mask[special_token_indices] = 1 + + # Set the cls index to 0: the CLS index can be used for impossible answers p_mask[cls_index] = 0 span_is_impossible = example.is_impossible