Migrate metric to Evaluate library for tensorflow examples (#18327)
* Migrate metric to Evaluate library in tf examples Currently tensorflow examples use `load_metric` function from Datasets library , commit migrates function call to `load` function to Evaluate library. Fix for #18306 * Migrate metric to Evaluate library in tf examples Currently tensorflow examples use `load_metric` function from Datasets library , commit migrates function call to `load` function to Evaluate library. Fix for #18306 * Migrate `metric` to Evaluate for all tf examples Currently tensorflow examples use `load_metric` function from Datasets library , commit migrates function call to `load` function to Evaluate library.
This commit is contained in:
@@ -1,2 +1,3 @@
|
|||||||
datasets >= 1.4.0
|
datasets >= 1.4.0
|
||||||
tensorflow >= 2.3.0
|
tensorflow >= 2.3.0
|
||||||
|
evaluate >= 0.2.0
|
||||||
@@ -26,8 +26,9 @@ from pathlib import Path
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
import evaluate
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
@@ -600,7 +601,7 @@ def main():
|
|||||||
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
|
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
|
||||||
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
||||||
|
|
||||||
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
|
metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
|
||||||
|
|
||||||
def compute_metrics(p: EvalPrediction):
|
def compute_metrics(p: EvalPrediction):
|
||||||
return metric.compute(predictions=p.predictions, references=p.label_ids)
|
return metric.compute(predictions=p.predictions, references=p.label_ids)
|
||||||
|
|||||||
3
examples/tensorflow/summarization/requirements.txt
Normal file
3
examples/tensorflow/summarization/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
datasets >= 1.4.0
|
||||||
|
tensorflow >= 2.3.0
|
||||||
|
evaluate >= 0.2.0
|
||||||
@@ -29,9 +29,10 @@ import datasets
|
|||||||
import nltk # Here to have a nice missing dependency error message early on
|
import nltk # Here to have a nice missing dependency error message early on
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import evaluate
|
||||||
import transformers
|
import transformers
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -634,7 +635,7 @@ def main():
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Metric
|
# region Metric
|
||||||
metric = load_metric("rouge")
|
metric = evaluate.load("rouge")
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Training
|
# region Training
|
||||||
|
|||||||
@@ -2,3 +2,4 @@ datasets >= 1.1.3
|
|||||||
sentencepiece != 0.1.92
|
sentencepiece != 0.1.92
|
||||||
protobuf
|
protobuf
|
||||||
tensorflow >= 2.3
|
tensorflow >= 2.3
|
||||||
|
evaluate >= 0.2.0
|
||||||
@@ -24,8 +24,9 @@ from typing import Optional
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
import evaluate
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
@@ -366,7 +367,7 @@ def main():
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Metric function
|
# region Metric function
|
||||||
metric = load_metric("glue", data_args.task_name)
|
metric = evaluate.load("glue", data_args.task_name)
|
||||||
|
|
||||||
def compute_metrics(preds, label_ids):
|
def compute_metrics(preds, label_ids):
|
||||||
preds = preds["logits"]
|
preds = preds["logits"]
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
datasets >= 1.4.0
|
||||||
|
tensorflow >= 2.3.0
|
||||||
|
evaluate >= 0.2.0
|
||||||
@@ -27,8 +27,9 @@ from typing import Optional
|
|||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from datasets import ClassLabel, load_dataset, load_metric
|
from datasets import ClassLabel, load_dataset
|
||||||
|
|
||||||
|
import evaluate
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
@@ -478,7 +479,7 @@ def main():
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# Metrics
|
# Metrics
|
||||||
metric = load_metric("seqeval")
|
metric = evaluate.load("seqeval")
|
||||||
|
|
||||||
def get_labels(y_pred, y_true):
|
def get_labels(y_pred, y_true):
|
||||||
# Transform predictions and references tensos to numpy arrays
|
# Transform predictions and references tensos to numpy arrays
|
||||||
|
|||||||
3
examples/tensorflow/translation/requirements.txt
Normal file
3
examples/tensorflow/translation/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
datasets >= 1.4.0
|
||||||
|
tensorflow >= 2.3.0
|
||||||
|
evaluate >= 0.2.0
|
||||||
@@ -28,9 +28,10 @@ from typing import Optional
|
|||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import evaluate
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
@@ -590,7 +591,7 @@ def main():
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Metric and postprocessing
|
# region Metric and postprocessing
|
||||||
metric = load_metric("sacrebleu")
|
metric = evaluate.load("sacrebleu")
|
||||||
|
|
||||||
def postprocess_text(preds, labels):
|
def postprocess_text(preds, labels):
|
||||||
preds = [pred.strip() for pred in preds]
|
preds = [pred.strip() for pred in preds]
|
||||||
|
|||||||
Reference in New Issue
Block a user