Add qas_id to SquadResult and SquadExample (#3745)
* Add qas_id
* Fix incorrect name in squad.py
* Make output files optional for squad eval
This commit is contained in:
@@ -251,6 +251,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
|
||||
start_position=start_position,
|
||||
end_position=end_position,
|
||||
is_impossible=span_is_impossible,
|
||||
qas_id=example.qas_id,
|
||||
)
|
||||
)
|
||||
return features
|
||||
@@ -344,9 +345,9 @@ def squad_convert_examples_to_features(
|
||||
all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
|
||||
|
||||
if not is_training:
|
||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||
all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||
dataset = TensorDataset(
|
||||
all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask
|
||||
all_input_ids, all_attention_masks, all_token_type_ids, all_feature_index, all_cls_index, all_p_mask
|
||||
)
|
||||
else:
|
||||
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
||||
@@ -368,12 +369,14 @@ def squad_convert_examples_to_features(
|
||||
raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")
|
||||
|
||||
def gen():
|
||||
for ex in features:
|
||||
for i, ex in enumerate(features):
|
||||
yield (
|
||||
{
|
||||
"input_ids": ex.input_ids,
|
||||
"attention_mask": ex.attention_mask,
|
||||
"token_type_ids": ex.token_type_ids,
|
||||
"feature_index": i,
|
||||
"qas_id": ex.qas_id,
|
||||
},
|
||||
{
|
||||
"start_position": ex.start_position,
|
||||
@@ -384,35 +387,44 @@ def squad_convert_examples_to_features(
|
||||
},
|
||||
)
|
||||
|
||||
return tf.data.Dataset.from_generator(
|
||||
gen,
|
||||
(
|
||||
{"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32},
|
||||
{
|
||||
"start_position": tf.int64,
|
||||
"end_position": tf.int64,
|
||||
"cls_index": tf.int64,
|
||||
"p_mask": tf.int32,
|
||||
"is_impossible": tf.int32,
|
||||
},
|
||||
),
|
||||
(
|
||||
{
|
||||
"input_ids": tf.TensorShape([None]),
|
||||
"attention_mask": tf.TensorShape([None]),
|
||||
"token_type_ids": tf.TensorShape([None]),
|
||||
},
|
||||
{
|
||||
"start_position": tf.TensorShape([]),
|
||||
"end_position": tf.TensorShape([]),
|
||||
"cls_index": tf.TensorShape([]),
|
||||
"p_mask": tf.TensorShape([None]),
|
||||
"is_impossible": tf.TensorShape([]),
|
||||
},
|
||||
),
|
||||
# Why have we split the batch into a tuple? PyTorch just has a list of tensors.
|
||||
train_types = (
|
||||
{
|
||||
"input_ids": tf.int32,
|
||||
"attention_mask": tf.int32,
|
||||
"token_type_ids": tf.int32,
|
||||
"feature_index": tf.int64,
|
||||
"qas_id": tf.string,
|
||||
},
|
||||
{
|
||||
"start_position": tf.int64,
|
||||
"end_position": tf.int64,
|
||||
"cls_index": tf.int64,
|
||||
"p_mask": tf.int32,
|
||||
"is_impossible": tf.int32,
|
||||
},
|
||||
)
|
||||
|
||||
return features
|
||||
train_shapes = (
|
||||
{
|
||||
"input_ids": tf.TensorShape([None]),
|
||||
"attention_mask": tf.TensorShape([None]),
|
||||
"token_type_ids": tf.TensorShape([None]),
|
||||
"feature_index": tf.TensorShape([]),
|
||||
"qas_id": tf.TensorShape([]),
|
||||
},
|
||||
{
|
||||
"start_position": tf.TensorShape([]),
|
||||
"end_position": tf.TensorShape([]),
|
||||
"cls_index": tf.TensorShape([]),
|
||||
"p_mask": tf.TensorShape([None]),
|
||||
"is_impossible": tf.TensorShape([]),
|
||||
},
|
||||
)
|
||||
|
||||
return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
|
||||
else:
|
||||
return features
|
||||
|
||||
|
||||
class SquadProcessor(DataProcessor):
|
||||
@@ -678,6 +690,7 @@ class SquadFeatures(object):
|
||||
start_position,
|
||||
end_position,
|
||||
is_impossible,
|
||||
qas_id: str = None,
|
||||
):
|
||||
self.input_ids = input_ids
|
||||
self.attention_mask = attention_mask
|
||||
@@ -695,6 +708,7 @@ class SquadFeatures(object):
|
||||
self.start_position = start_position
|
||||
self.end_position = end_position
|
||||
self.is_impossible = is_impossible
|
||||
self.qas_id = qas_id
|
||||
|
||||
|
||||
class SquadResult(object):
|
||||
|
||||
Reference in New Issue
Block a user