From 15579e2d553df6588caf8ecc60f5f26a6d144df3 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 21 Jan 2020 11:36:46 -0500 Subject: [PATCH] [SQuAD v2] Code quality --- src/transformers/data/processors/squad.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/transformers/data/processors/squad.py b/src/transformers/data/processors/squad.py index 2353008b5c..f2e63e9394 100644 --- a/src/transformers/data/processors/squad.py +++ b/src/transformers/data/processors/squad.py @@ -242,7 +242,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q token_to_orig_map=span["token_to_orig_map"], start_position=start_position, end_position=end_position, - is_impossible=span_is_impossible + is_impossible=span_is_impossible, ) ) return features @@ -351,7 +351,7 @@ def squad_convert_examples_to_features( all_end_positions, all_cls_index, all_p_mask, - all_is_impossible + all_is_impossible, ) return features, dataset @@ -372,7 +372,7 @@ def squad_convert_examples_to_features( "end_position": ex.end_position, "cls_index": ex.cls_index, "p_mask": ex.p_mask, - "is_impossible": ex.is_impossible + "is_impossible": ex.is_impossible, }, ) @@ -380,7 +380,13 @@ def squad_convert_examples_to_features( gen, ( {"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, - {"start_position": tf.int64, "end_position": tf.int64, "cls_index": tf.int64, "p_mask": tf.int32, "is_impossible": tf.int32}, + { + "start_position": tf.int64, + "end_position": tf.int64, + "cls_index": tf.int64, + "p_mask": tf.int32, + "is_impossible": tf.int32, + }, ), ( { @@ -393,7 +399,7 @@ def squad_convert_examples_to_features( "end_position": tf.TensorShape([]), "cls_index": tf.TensorShape([]), "p_mask": tf.TensorShape([None]), - "is_impossible": tf.TensorShape([]) + "is_impossible": tf.TensorShape([]), }, ), ) @@ -663,7 +669,7 @@ class SquadFeatures(object): token_to_orig_map, start_position, end_position, - is_impossible + is_impossible, ): self.input_ids = input_ids self.attention_mask = attention_mask @@ -682,6 +688,7 @@ class SquadFeatures(object): self.end_position = end_position self.is_impossible = is_impossible + class SquadResult(object): """ Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.