Migrate metrics used in flax examples to Evaluate (#18348)

Currently, the Flax examples use the `load_metric` function from the
Datasets library. This commit migrates those calls to the `load` function
from the Evaluate library.
This commit is contained in:
Vijay S Kalmath
2022-07-28 15:06:23 -04:00
committed by GitHub
parent a2586795e5
commit da503ea02f
6 changed files with 18 additions and 12 deletions

View File

@@ -5,3 +5,4 @@ nltk
rouge-score rouge-score
seqeval seqeval
tensorboard tensorboard
evaluate >= 0.2.0

View File

@@ -31,10 +31,11 @@ from typing import Callable, Optional
import datasets import datasets
import nltk # Here to have a nice missing dependency error message early on import nltk # Here to have a nice missing dependency error message early on
import numpy as np import numpy as np
from datasets import Dataset, load_dataset, load_metric from datasets import Dataset, load_dataset
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
import evaluate
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import optax import optax
@@ -811,7 +812,7 @@ def main():
yield batch yield batch
# Metric # Metric
metric = load_metric("rouge") metric = evaluate.load("rouge")
def postprocess_text(preds, labels): def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds] preds = [pred.strip() for pred in preds]

View File

@@ -32,9 +32,10 @@ from typing import Any, Callable, Dict, Optional, Tuple
import datasets import datasets
import numpy as np import numpy as np
from datasets import load_dataset, load_metric from datasets import load_dataset
from tqdm import tqdm from tqdm import tqdm
import evaluate
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import optax import optax
@@ -776,7 +777,7 @@ def main():
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references) return EvalPrediction(predictions=formatted_predictions, label_ids=references)
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad") metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
def compute_metrics(p: EvalPrediction): def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids) return metric.compute(predictions=p.predictions, references=p.label_ids)

View File

@@ -33,9 +33,10 @@ from typing import Callable, Optional
import datasets import datasets
import nltk # Here to have a nice missing dependency error message early on import nltk # Here to have a nice missing dependency error message early on
import numpy as np import numpy as np
from datasets import Dataset, load_dataset, load_metric from datasets import Dataset, load_dataset
from tqdm import tqdm from tqdm import tqdm
import evaluate
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import optax import optax
@@ -656,7 +657,7 @@ def main():
) )
# Metric # Metric
metric = load_metric("rouge") metric = evaluate.load("rouge")
def postprocess_text(preds, labels): def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds] preds = [pred.strip() for pred in preds]

View File

@@ -27,9 +27,10 @@ from typing import Any, Callable, Dict, Optional, Tuple
import datasets import datasets
import numpy as np import numpy as np
from datasets import load_dataset, load_metric from datasets import load_dataset
from tqdm import tqdm from tqdm import tqdm
import evaluate
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import optax import optax
@@ -570,9 +571,9 @@ def main():
p_eval_step = jax.pmap(eval_step, axis_name="batch") p_eval_step = jax.pmap(eval_step, axis_name="batch")
if data_args.task_name is not None: if data_args.task_name is not None:
metric = load_metric("glue", data_args.task_name) metric = evaluate.load("glue", data_args.task_name)
else: else:
metric = load_metric("accuracy") metric = evaluate.load("accuracy")
logger.info(f"===== Starting training ({num_epochs} epochs) =====") logger.info(f"===== Starting training ({num_epochs} epochs) =====")
train_time = 0 train_time = 0

View File

@@ -29,9 +29,10 @@ from typing import Any, Callable, Dict, Optional, Tuple
import datasets import datasets
import numpy as np import numpy as np
from datasets import ClassLabel, load_dataset, load_metric from datasets import ClassLabel, load_dataset
from tqdm import tqdm from tqdm import tqdm
import evaluate
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import optax import optax
@@ -646,7 +647,7 @@ def main():
p_eval_step = jax.pmap(eval_step, axis_name="batch") p_eval_step = jax.pmap(eval_step, axis_name="batch")
metric = load_metric("seqeval") metric = evaluate.load("seqeval")
def get_labels(y_pred, y_true): def get_labels(y_pred, y_true):
# Transform predictions and references tensos to numpy arrays # Transform predictions and references tensos to numpy arrays