QDQBert example update (#16395)

* update Dockerfile and utils_qa

* Update README.md
This commit is contained in:
Shang Zhang
2022-03-28 02:47:52 -07:00
committed by GitHub
parent f6f6866e9e
commit 7ecbb9c5e4
3 changed files with 18 additions and 18 deletions

View File

@@ -73,10 +73,12 @@ def postprocess_qa_predictions(
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
``logging`` log level (e.g., ``logging.WARNING``)
"""
assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
if len(predictions) != 2:
raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).")
all_start_logits, all_end_logits = predictions
assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."
if len(predictions[0]) != len(features):
raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
# Build a map example to its corresponding features.
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
@@ -135,7 +137,9 @@ def postprocess_qa_predictions(
start_index >= len(offset_mapping)
or end_index >= len(offset_mapping)
or offset_mapping[start_index] is None
or len(offset_mapping[start_index]) < 2
or offset_mapping[end_index] is None
or len(offset_mapping[end_index]) < 2
):
continue
# Don't consider answers with a length that is either < 0 or > max_answer_length.
@@ -145,6 +149,7 @@ def postprocess_qa_predictions(
# provided).
if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
continue
prelim_predictions.append(
{
"offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
@@ -212,7 +217,8 @@ def postprocess_qa_predictions(
# If we have an output_dir, let's save all those dicts.
if output_dir is not None:
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
if not os.path.isdir(output_dir):
raise EnvironmentError(f"{output_dir} is not a directory.")
prediction_file = os.path.join(
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
@@ -283,12 +289,12 @@ def postprocess_qa_predictions_with_beam_search(
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
``logging`` log level (e.g., ``logging.WARNING``)
"""
assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
if len(predictions) != 5:
raise ValueError("`predictions` should be a tuple with five elements.")
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
assert len(predictions[0]) == len(
features
), f"Got {len(predictions[0])} predicitions and {len(features)} features."
if len(predictions[0]) != len(features):
raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
# Build a map example to its corresponding features.
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
@@ -400,7 +406,8 @@ def postprocess_qa_predictions_with_beam_search(
# If we have an output_dir, let's save all those dicts.
if output_dir is not None:
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
if not os.path.isdir(output_dir):
raise EnvironmentError(f"{output_dir} is not a directory.")
prediction_file = os.path.join(
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"