Addressing review comment
This commit is contained in:
@@ -150,7 +150,7 @@ def main():
|
||||
|
||||
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
|
||||
def compute_metrics_fn(p: EvalPrediction):
|
||||
preds = p.predictions[0] if type(p.predictions) == tuple else p.predictions
|
||||
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
|
||||
if output_mode == "classification":
|
||||
preds = np.argmax(preds, axis=1)
|
||||
else: # regression
|
||||
|
||||
Reference in New Issue
Block a user