[example/glue] fix compute_metrics_fn for bart like models (#7248)
* fix compute_metrics_fn * p.predictions -> preds * apply suggestions
This commit is contained in:
@@ -150,10 +150,11 @@ def main():
|
|||||||
|
|
||||||
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
|
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
|
||||||
def compute_metrics_fn(p: EvalPrediction):
|
def compute_metrics_fn(p: EvalPrediction):
|
||||||
|
preds = p.predictions[0] if type(p.predictions) == tuple else p.predictions
|
||||||
if output_mode == "classification":
|
if output_mode == "classification":
|
||||||
preds = np.argmax(p.predictions, axis=1)
|
preds = np.argmax(preds, axis=1)
|
||||||
elif output_mode == "regression":
|
else: # regression
|
||||||
preds = np.squeeze(p.predictions)
|
preds = np.squeeze(preds)
|
||||||
return glue_compute_metrics(task_name, preds, p.label_ids)
|
return glue_compute_metrics(task_name, preds, p.label_ids)
|
||||||
|
|
||||||
return compute_metrics_fn
|
return compute_metrics_fn
|
||||||
|
|||||||
Reference in New Issue
Block a user