Migrate metrics used in flax examples to Evaluate (#18348)
Currently, tensorflow examples use the `load_metric` function from Datasets library, commit migrates function call to `load` function from Evaluate library.
This commit is contained in:
@@ -27,9 +27,10 @@ from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_metric
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import evaluate
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
@@ -570,9 +571,9 @@ def main():
|
||||
p_eval_step = jax.pmap(eval_step, axis_name="batch")
|
||||
|
||||
if data_args.task_name is not None:
|
||||
metric = load_metric("glue", data_args.task_name)
|
||||
metric = evaluate.load("glue", data_args.task_name)
|
||||
else:
|
||||
metric = load_metric("accuracy")
|
||||
metric = evaluate.load("accuracy")
|
||||
|
||||
logger.info(f"===== Starting training ({num_epochs} epochs) =====")
|
||||
train_time = 0
|
||||
|
||||
Reference in New Issue
Block a user