QDQBert example update (#16395)
* update Dockerfile and utils_qa * Update README.md
This commit is contained in:
@@ -73,10 +73,12 @@ def postprocess_qa_predictions(
|
||||
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
|
||||
``logging`` log level (e.g., ``logging.WARNING``)
|
||||
"""
|
||||
assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
|
||||
if len(predictions) != 2:
|
||||
raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).")
|
||||
all_start_logits, all_end_logits = predictions
|
||||
|
||||
assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."
|
||||
if len(predictions[0]) != len(features):
|
||||
raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
|
||||
|
||||
# Build a map example to its corresponding features.
|
||||
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
|
||||
@@ -135,7 +137,9 @@ def postprocess_qa_predictions(
|
||||
start_index >= len(offset_mapping)
|
||||
or end_index >= len(offset_mapping)
|
||||
or offset_mapping[start_index] is None
|
||||
or len(offset_mapping[start_index]) < 2
|
||||
or offset_mapping[end_index] is None
|
||||
or len(offset_mapping[end_index]) < 2
|
||||
):
|
||||
continue
|
||||
# Don't consider answers with a length that is either < 0 or > max_answer_length.
|
||||
@@ -145,6 +149,7 @@ def postprocess_qa_predictions(
|
||||
# provided).
|
||||
if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
|
||||
continue
|
||||
|
||||
prelim_predictions.append(
|
||||
{
|
||||
"offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
|
||||
@@ -212,7 +217,8 @@ def postprocess_qa_predictions(
|
||||
|
||||
# If we have an output_dir, let's save all those dicts.
|
||||
if output_dir is not None:
|
||||
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
|
||||
if not os.path.isdir(output_dir):
|
||||
raise EnvironmentError(f"{output_dir} is not a directory.")
|
||||
|
||||
prediction_file = os.path.join(
|
||||
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
|
||||
@@ -283,12 +289,12 @@ def postprocess_qa_predictions_with_beam_search(
|
||||
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
|
||||
``logging`` log level (e.g., ``logging.WARNING``)
|
||||
"""
|
||||
assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
|
||||
if len(predictions) != 5:
|
||||
raise ValueError("`predictions` should be a tuple with five elements.")
|
||||
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
|
||||
|
||||
assert len(predictions[0]) == len(
|
||||
features
|
||||
), f"Got {len(predictions[0])} predicitions and {len(features)} features."
|
||||
if len(predictions[0]) != len(features):
|
||||
raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
|
||||
|
||||
# Build a map example to its corresponding features.
|
||||
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
|
||||
@@ -400,7 +406,8 @@ def postprocess_qa_predictions_with_beam_search(
|
||||
|
||||
# If we have an output_dir, let's save all those dicts.
|
||||
if output_dir is not None:
|
||||
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
|
||||
if not os.path.isdir(output_dir):
|
||||
raise EnvironmentError(f"{output_dir} is not a directory.")
|
||||
|
||||
prediction_file = os.path.join(
|
||||
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
|
||||
|
||||
Reference in New Issue
Block a user