Migrate metrics used in flax examples to Evaluate (#18348)

Currently, tensorflow examples use the `load_metric` function from
Datasets library, commit migrates function call to `load` function
from Evaluate library.
This commit is contained in:
Vijay S Kalmath
2022-07-28 15:06:23 -04:00
committed by GitHub
parent a2586795e5
commit da503ea02f
6 changed files with 18 additions and 12 deletions

View File

@@ -27,9 +27,10 @@ from typing import Any, Callable, Dict, Optional, Tuple
import datasets
import numpy as np
from datasets import load_dataset, load_metric
from datasets import load_dataset
from tqdm import tqdm
import evaluate
import jax
import jax.numpy as jnp
import optax
@@ -570,9 +571,9 @@ def main():
p_eval_step = jax.pmap(eval_step, axis_name="batch")
if data_args.task_name is not None:
metric = load_metric("glue", data_args.task_name)
metric = evaluate.load("glue", data_args.task_name)
else:
metric = load_metric("accuracy")
metric = evaluate.load("accuracy")
logger.info(f"===== Starting training ({num_epochs} epochs) =====")
train_time = 0