Addressing review comment

This commit is contained in:
Lysandre
2020-09-21 11:37:00 +02:00
parent 43b9d93875
commit aae4edb5f0

View File

@@ -150,7 +150,7 @@ def main():
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
def compute_metrics_fn(p: EvalPrediction):
preds = p.predictions[0] if type(p.predictions) == tuple else p.predictions
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
if output_mode == "classification":
preds = np.argmax(preds, axis=1)
else: # regression