Migrate metric to Evaluate in Pytorch examples (#18369)
* Migrate metric to Evaluate in pytorch examples * Remove unused imports
This commit is contained in:
@@ -21,7 +21,6 @@ import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
@@ -30,6 +29,7 @@ from torch import nn
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import functional
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import (
|
||||
@@ -337,7 +337,7 @@ def main():
|
||||
label2id = {v: str(k) for k, v in id2label.items()}
|
||||
|
||||
# Load the mean IoU metric from the datasets package
|
||||
metric = datasets.load_metric("mean_iou")
|
||||
metric = evaluate.load("mean_iou")
|
||||
|
||||
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
|
||||
# predictions and label_ids field) and has to return a dictionary string to float.
|
||||
|
||||
@@ -24,13 +24,14 @@ from pathlib import Path
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset, load_metric
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import functional
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
@@ -500,7 +501,7 @@ def main():
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# Instantiate metric
|
||||
metric = load_metric("mean_iou")
|
||||
metric = evaluate.load("mean_iou")
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# We initialize the trackers only on main process because `accelerator.log`
|
||||
|
||||
Reference in New Issue
Block a user