diff --git a/examples/flax/_tests_requirements.txt b/examples/flax/_tests_requirements.txt index f9de455f62..f1e0fb2d90 100644 --- a/examples/flax/_tests_requirements.txt +++ b/examples/flax/_tests_requirements.txt @@ -4,4 +4,5 @@ conllu nltk rouge-score seqeval -tensorboard \ No newline at end of file +tensorboard +evaluate >= 0.2.0 \ No newline at end of file diff --git a/examples/flax/image-captioning/run_image_captioning_flax.py b/examples/flax/image-captioning/run_image_captioning_flax.py index a4deab8041..4fe144db8b 100644 --- a/examples/flax/image-captioning/run_image_captioning_flax.py +++ b/examples/flax/image-captioning/run_image_captioning_flax.py @@ -31,10 +31,11 @@ from typing import Callable, Optional import datasets import nltk # Here to have a nice missing dependency error message early on import numpy as np -from datasets import Dataset, load_dataset, load_metric +from datasets import Dataset, load_dataset from PIL import Image from tqdm import tqdm +import evaluate import jax import jax.numpy as jnp import optax @@ -811,7 +812,7 @@ def main(): yield batch # Metric - metric = load_metric("rouge") + metric = evaluate.load("rouge") def postprocess_text(preds, labels): preds = [pred.strip() for pred in preds] diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py index bc3b5acc50..0873b19413 100644 --- a/examples/flax/question-answering/run_qa.py +++ b/examples/flax/question-answering/run_qa.py @@ -32,9 +32,10 @@ from typing import Any, Callable, Dict, Optional, Tuple import datasets import numpy as np -from datasets import load_dataset, load_metric +from datasets import load_dataset from tqdm import tqdm +import evaluate import jax import jax.numpy as jnp import optax @@ -776,7 +777,7 @@ def main(): references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] return EvalPrediction(predictions=formatted_predictions, label_ids=references) - metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad") + metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad") def compute_metrics(p: EvalPrediction): return metric.compute(predictions=p.predictions, references=p.label_ids) diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index bd17141a44..d6f8ec78ba 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -33,9 +33,10 @@ from typing import Callable, Optional import datasets import nltk # Here to have a nice missing dependency error message early on import numpy as np -from datasets import Dataset, load_dataset, load_metric +from datasets import Dataset, load_dataset from tqdm import tqdm +import evaluate import jax import jax.numpy as jnp import optax @@ -656,7 +657,7 @@ def main(): ) # Metric - metric = load_metric("rouge") + metric = evaluate.load("rouge") def postprocess_text(preds, labels): preds = [pred.strip() for pred in preds] diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index 84e1c85125..7f5524dbb4 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -27,9 +27,10 @@ from typing import Any, Callable, Dict, Optional, Tuple import datasets import numpy as np -from datasets import load_dataset, load_metric +from datasets import load_dataset from tqdm import tqdm +import evaluate import jax import jax.numpy as jnp import optax @@ -570,9 +571,9 @@ def main(): p_eval_step = jax.pmap(eval_step, axis_name="batch") if data_args.task_name is not None: - metric = load_metric("glue", data_args.task_name) + metric = evaluate.load("glue", data_args.task_name) else: - metric = load_metric("accuracy") + metric = evaluate.load("accuracy") logger.info(f"===== Starting training ({num_epochs} epochs) =====") train_time = 0 diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py index 79e1589e3f..0a66b5f199 100644 --- a/examples/flax/token-classification/run_flax_ner.py +++ b/examples/flax/token-classification/run_flax_ner.py @@ -29,9 +29,10 @@ from typing import Any, Callable, Dict, Optional, Tuple import datasets import numpy as np -from datasets import ClassLabel, load_dataset, load_metric +from datasets import ClassLabel, load_dataset from tqdm import tqdm +import evaluate import jax import jax.numpy as jnp import optax @@ -646,7 +647,7 @@ def main(): p_eval_step = jax.pmap(eval_step, axis_name="batch") - metric = load_metric("seqeval") + metric = evaluate.load("seqeval") def get_labels(y_pred, y_true): # Transform predictions and references tensos to numpy arrays