Migrate metrics used in flax examples to Evaluate (#18348)
Currently, tensorflow examples use the `load_metric` function from Datasets library, commit migrates function call to `load` function from Evaluate library.
This commit is contained in:
@@ -31,10 +31,11 @@ from typing import Callable, Optional
|
||||
import datasets
|
||||
import nltk # Here to have a nice missing dependency error message early on
|
||||
import numpy as np
|
||||
from datasets import Dataset, load_dataset, load_metric
|
||||
from datasets import Dataset, load_dataset
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
import evaluate
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
@@ -811,7 +812,7 @@ def main():
|
||||
yield batch
|
||||
|
||||
# Metric
|
||||
metric = load_metric("rouge")
|
||||
metric = evaluate.load("rouge")
|
||||
|
||||
def postprocess_text(preds, labels):
|
||||
preds = [pred.strip() for pred in preds]
|
||||
|
||||
Reference in New Issue
Block a user