Migrate metrics used in flax examples to Evaluate (#18348)
Currently, the Flax examples use the `load_metric` function from the Datasets library; this commit migrates those calls to the `load` function from the Evaluate library.
This commit is contained in:
@@ -4,4 +4,5 @@ conllu
|
|||||||
nltk
|
nltk
|
||||||
rouge-score
|
rouge-score
|
||||||
seqeval
|
seqeval
|
||||||
tensorboard
|
tensorboard
|
||||||
|
evaluate >= 0.2.0
|
||||||
@@ -31,10 +31,11 @@ from typing import Callable, Optional
|
|||||||
import datasets
|
import datasets
|
||||||
import nltk # Here to have a nice missing dependency error message early on
|
import nltk # Here to have a nice missing dependency error message early on
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import Dataset, load_dataset, load_metric
|
from datasets import Dataset, load_dataset
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import evaluate
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import optax
|
import optax
|
||||||
@@ -811,7 +812,7 @@ def main():
|
|||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
# Metric
|
# Metric
|
||||||
metric = load_metric("rouge")
|
metric = evaluate.load("rouge")
|
||||||
|
|
||||||
def postprocess_text(preds, labels):
|
def postprocess_text(preds, labels):
|
||||||
preds = [pred.strip() for pred in preds]
|
preds = [pred.strip() for pred in preds]
|
||||||
|
|||||||
@@ -32,9 +32,10 @@ from typing import Any, Callable, Dict, Optional, Tuple
|
|||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import evaluate
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import optax
|
import optax
|
||||||
@@ -776,7 +777,7 @@ def main():
|
|||||||
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
|
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
|
||||||
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
||||||
|
|
||||||
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
|
metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
|
||||||
|
|
||||||
def compute_metrics(p: EvalPrediction):
|
def compute_metrics(p: EvalPrediction):
|
||||||
return metric.compute(predictions=p.predictions, references=p.label_ids)
|
return metric.compute(predictions=p.predictions, references=p.label_ids)
|
||||||
|
|||||||
@@ -33,9 +33,10 @@ from typing import Callable, Optional
|
|||||||
import datasets
|
import datasets
|
||||||
import nltk # Here to have a nice missing dependency error message early on
|
import nltk # Here to have a nice missing dependency error message early on
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import Dataset, load_dataset, load_metric
|
from datasets import Dataset, load_dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import evaluate
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import optax
|
import optax
|
||||||
@@ -656,7 +657,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Metric
|
# Metric
|
||||||
metric = load_metric("rouge")
|
metric = evaluate.load("rouge")
|
||||||
|
|
||||||
def postprocess_text(preds, labels):
|
def postprocess_text(preds, labels):
|
||||||
preds = [pred.strip() for pred in preds]
|
preds = [pred.strip() for pred in preds]
|
||||||
|
|||||||
@@ -27,9 +27,10 @@ from typing import Any, Callable, Dict, Optional, Tuple
|
|||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import evaluate
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import optax
|
import optax
|
||||||
@@ -570,9 +571,9 @@ def main():
|
|||||||
p_eval_step = jax.pmap(eval_step, axis_name="batch")
|
p_eval_step = jax.pmap(eval_step, axis_name="batch")
|
||||||
|
|
||||||
if data_args.task_name is not None:
|
if data_args.task_name is not None:
|
||||||
metric = load_metric("glue", data_args.task_name)
|
metric = evaluate.load("glue", data_args.task_name)
|
||||||
else:
|
else:
|
||||||
metric = load_metric("accuracy")
|
metric = evaluate.load("accuracy")
|
||||||
|
|
||||||
logger.info(f"===== Starting training ({num_epochs} epochs) =====")
|
logger.info(f"===== Starting training ({num_epochs} epochs) =====")
|
||||||
train_time = 0
|
train_time = 0
|
||||||
|
|||||||
@@ -29,9 +29,10 @@ from typing import Any, Callable, Dict, Optional, Tuple
|
|||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import ClassLabel, load_dataset, load_metric
|
from datasets import ClassLabel, load_dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import evaluate
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import optax
|
import optax
|
||||||
@@ -646,7 +647,7 @@ def main():
|
|||||||
|
|
||||||
p_eval_step = jax.pmap(eval_step, axis_name="batch")
|
p_eval_step = jax.pmap(eval_step, axis_name="batch")
|
||||||
|
|
||||||
metric = load_metric("seqeval")
|
metric = evaluate.load("seqeval")
|
||||||
|
|
||||||
def get_labels(y_pred, y_true):
|
def get_labels(y_pred, y_true):
|
||||||
# Transform predictions and references tensors to numpy arrays
|
# Transform predictions and references tensors to numpy arrays
|
||||||
|
|||||||
Reference in New Issue
Block a user