Add qas_id to SquadResult and SquadExample (#3745)
* Add qas_id
* Fix incorrect name in squad.py
* Make output files optional for squad eval
This commit is contained in:
@@ -307,7 +307,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
|
if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
|
||||||
del inputs["token_type_ids"]
|
del inputs["token_type_ids"]
|
||||||
|
|
||||||
example_indices = batch[3]
|
feature_indices = batch[3]
|
||||||
|
|
||||||
# XLNet and XLM use more arguments for their predictions
|
# XLNet and XLM use more arguments for their predictions
|
||||||
if args.model_type in ["xlnet", "xlm"]:
|
if args.model_type in ["xlnet", "xlm"]:
|
||||||
@@ -320,8 +320,9 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
|
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
for i, example_index in enumerate(example_indices):
|
for i, feature_index in enumerate(feature_indices):
|
||||||
eval_feature = features[example_index.item()]
|
# TODO: i and feature_index are the same number! Simplify by removing enumerate?
|
||||||
|
eval_feature = features[feature_index.item()]
|
||||||
unique_id = int(eval_feature.unique_id)
|
unique_id = int(eval_feature.unique_id)
|
||||||
|
|
||||||
output = [to_list(output[i]) for output in outputs]
|
output = [to_list(output[i]) for output in outputs]
|
||||||
|
|||||||
@@ -384,8 +384,12 @@ def compute_predictions_logits(
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
):
|
):
|
||||||
"""Write final predictions to the json file and log-odds of null if needed."""
|
"""Write final predictions to the json file and log-odds of null if needed."""
|
||||||
logger.info("Writing predictions to: %s" % (output_prediction_file))
|
if output_prediction_file:
|
||||||
logger.info("Writing nbest to: %s" % (output_nbest_file))
|
logger.info(f"Writing predictions to: {output_prediction_file}")
|
||||||
|
if output_nbest_file:
|
||||||
|
logger.info(f"Writing nbest to: {output_nbest_file}")
|
||||||
|
if output_null_log_odds_file and version_2_with_negative:
|
||||||
|
logger.info(f"Writing null_log_odds to: {output_null_log_odds_file}")
|
||||||
|
|
||||||
example_index_to_features = collections.defaultdict(list)
|
example_index_to_features = collections.defaultdict(list)
|
||||||
for feature in all_features:
|
for feature in all_features:
|
||||||
@@ -554,13 +558,15 @@ def compute_predictions_logits(
|
|||||||
all_predictions[example.qas_id] = best_non_null_entry.text
|
all_predictions[example.qas_id] = best_non_null_entry.text
|
||||||
all_nbest_json[example.qas_id] = nbest_json
|
all_nbest_json[example.qas_id] = nbest_json
|
||||||
|
|
||||||
with open(output_prediction_file, "w") as writer:
|
if output_prediction_file:
|
||||||
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
with open(output_prediction_file, "w") as writer:
|
||||||
|
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
||||||
|
|
||||||
with open(output_nbest_file, "w") as writer:
|
if output_nbest_file:
|
||||||
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
with open(output_nbest_file, "w") as writer:
|
||||||
|
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
||||||
|
|
||||||
if version_2_with_negative:
|
if output_null_log_odds_file and version_2_with_negative:
|
||||||
with open(output_null_log_odds_file, "w") as writer:
|
with open(output_null_log_odds_file, "w") as writer:
|
||||||
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -251,6 +251,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
|
|||||||
start_position=start_position,
|
start_position=start_position,
|
||||||
end_position=end_position,
|
end_position=end_position,
|
||||||
is_impossible=span_is_impossible,
|
is_impossible=span_is_impossible,
|
||||||
|
qas_id=example.qas_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return features
|
return features
|
||||||
@@ -344,9 +345,9 @@ def squad_convert_examples_to_features(
|
|||||||
all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
|
all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
|
||||||
|
|
||||||
if not is_training:
|
if not is_training:
|
||||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||||
dataset = TensorDataset(
|
dataset = TensorDataset(
|
||||||
all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask
|
all_input_ids, all_attention_masks, all_token_type_ids, all_feature_index, all_cls_index, all_p_mask
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
||||||
@@ -368,12 +369,14 @@ def squad_convert_examples_to_features(
|
|||||||
raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")
|
raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
for ex in features:
|
for i, ex in enumerate(features):
|
||||||
yield (
|
yield (
|
||||||
{
|
{
|
||||||
"input_ids": ex.input_ids,
|
"input_ids": ex.input_ids,
|
||||||
"attention_mask": ex.attention_mask,
|
"attention_mask": ex.attention_mask,
|
||||||
"token_type_ids": ex.token_type_ids,
|
"token_type_ids": ex.token_type_ids,
|
||||||
|
"feature_index": i,
|
||||||
|
"qas_id": ex.qas_id,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"start_position": ex.start_position,
|
"start_position": ex.start_position,
|
||||||
@@ -384,35 +387,44 @@ def squad_convert_examples_to_features(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return tf.data.Dataset.from_generator(
|
# Why have we split the batch into a tuple? PyTorch just has a list of tensors.
|
||||||
gen,
|
train_types = (
|
||||||
(
|
{
|
||||||
{"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32},
|
"input_ids": tf.int32,
|
||||||
{
|
"attention_mask": tf.int32,
|
||||||
"start_position": tf.int64,
|
"token_type_ids": tf.int32,
|
||||||
"end_position": tf.int64,
|
"feature_index": tf.int64,
|
||||||
"cls_index": tf.int64,
|
"qas_id": tf.string,
|
||||||
"p_mask": tf.int32,
|
},
|
||||||
"is_impossible": tf.int32,
|
{
|
||||||
},
|
"start_position": tf.int64,
|
||||||
),
|
"end_position": tf.int64,
|
||||||
(
|
"cls_index": tf.int64,
|
||||||
{
|
"p_mask": tf.int32,
|
||||||
"input_ids": tf.TensorShape([None]),
|
"is_impossible": tf.int32,
|
||||||
"attention_mask": tf.TensorShape([None]),
|
},
|
||||||
"token_type_ids": tf.TensorShape([None]),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"start_position": tf.TensorShape([]),
|
|
||||||
"end_position": tf.TensorShape([]),
|
|
||||||
"cls_index": tf.TensorShape([]),
|
|
||||||
"p_mask": tf.TensorShape([None]),
|
|
||||||
"is_impossible": tf.TensorShape([]),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return features
|
train_shapes = (
|
||||||
|
{
|
||||||
|
"input_ids": tf.TensorShape([None]),
|
||||||
|
"attention_mask": tf.TensorShape([None]),
|
||||||
|
"token_type_ids": tf.TensorShape([None]),
|
||||||
|
"feature_index": tf.TensorShape([]),
|
||||||
|
"qas_id": tf.TensorShape([]),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"start_position": tf.TensorShape([]),
|
||||||
|
"end_position": tf.TensorShape([]),
|
||||||
|
"cls_index": tf.TensorShape([]),
|
||||||
|
"p_mask": tf.TensorShape([None]),
|
||||||
|
"is_impossible": tf.TensorShape([]),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
|
||||||
|
else:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
class SquadProcessor(DataProcessor):
|
class SquadProcessor(DataProcessor):
|
||||||
@@ -678,6 +690,7 @@ class SquadFeatures(object):
|
|||||||
start_position,
|
start_position,
|
||||||
end_position,
|
end_position,
|
||||||
is_impossible,
|
is_impossible,
|
||||||
|
qas_id: str = None,
|
||||||
):
|
):
|
||||||
self.input_ids = input_ids
|
self.input_ids = input_ids
|
||||||
self.attention_mask = attention_mask
|
self.attention_mask = attention_mask
|
||||||
@@ -695,6 +708,7 @@ class SquadFeatures(object):
|
|||||||
self.start_position = start_position
|
self.start_position = start_position
|
||||||
self.end_position = end_position
|
self.end_position = end_position
|
||||||
self.is_impossible = is_impossible
|
self.is_impossible = is_impossible
|
||||||
|
self.qas_id = qas_id
|
||||||
|
|
||||||
|
|
||||||
class SquadResult(object):
|
class SquadResult(object):
|
||||||
|
|||||||
Reference in New Issue
Block a user