From 073219b43f432a0e223b4bfebbfb2702547b7acc Mon Sep 17 00:00:00 2001
From: Lysandre
Date: Tue, 21 Jan 2020 11:15:22 -0500
Subject: [PATCH] Manage impossible examples SQuAD v2

---
 src/transformers/data/processors/squad.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/transformers/data/processors/squad.py b/src/transformers/data/processors/squad.py
index e2dc3f85b3..2353008b5c 100644
--- a/src/transformers/data/processors/squad.py
+++ b/src/transformers/data/processors/squad.py
@@ -242,6 +242,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
                 token_to_orig_map=span["token_to_orig_map"],
                 start_position=start_position,
                 end_position=end_position,
+                is_impossible=span_is_impossible
             )
         )
     return features
@@ -332,6 +333,7 @@ def squad_convert_examples_to_features(
         all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
         all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
         all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
+        all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
 
         if not is_training:
             all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
@@ -349,6 +351,7 @@ def squad_convert_examples_to_features(
                 all_end_positions,
                 all_cls_index,
                 all_p_mask,
+                all_is_impossible
             )
 
         return features, dataset
@@ -369,6 +372,7 @@ def squad_convert_examples_to_features(
                         "end_position": ex.end_position,
                         "cls_index": ex.cls_index,
                         "p_mask": ex.p_mask,
+                        "is_impossible": ex.is_impossible
                     },
                 )
 
@@ -376,7 +380,7 @@ def squad_convert_examples_to_features(
             gen,
             (
                 {"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32},
-                {"start_position": tf.int64, "end_position": tf.int64, "cls_index": tf.int64, "p_mask": tf.int32},
+                {"start_position": tf.int64, "end_position": tf.int64, "cls_index": tf.int64, "p_mask": tf.int32, "is_impossible": tf.int32},
             ),
             (
                 {
@@ -389,6 +393,7 @@ def squad_convert_examples_to_features(
                     "end_position": tf.TensorShape([]),
                     "cls_index": tf.TensorShape([]),
                     "p_mask": tf.TensorShape([None]),
+                    "is_impossible": tf.TensorShape([])
                 },
             ),
         )
@@ -658,6 +663,7 @@ class SquadFeatures(object):
         token_to_orig_map,
         start_position,
         end_position,
+        is_impossible
     ):
         self.input_ids = input_ids
         self.attention_mask = attention_mask
@@ -674,7 +680,7 @@ class SquadFeatures(object):
 
         self.start_position = start_position
         self.end_position = end_position
-
+        self.is_impossible = is_impossible
 
 
 class SquadResult(object):
     """