Migrate metric to Evaluate in Pytorch examples (#18369)
* Migrate metric to Evaluate in pytorch examples * Remove unused imports
This commit is contained in:
@@ -19,7 +19,6 @@ import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
@@ -34,6 +33,7 @@ from torchvision.transforms import (
|
||||
ToTensor,
|
||||
)
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from transformers import (
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
@@ -252,7 +252,7 @@ def main():
|
||||
id2label[str(i)] = label
|
||||
|
||||
# Load the accuracy metric from the datasets package
|
||||
metric = datasets.load_metric("accuracy")
|
||||
metric = evaluate.load("accuracy")
|
||||
|
||||
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
|
||||
# predictions and label_ids field) and has to return a dictionary string to float.
|
||||
|
||||
@@ -22,7 +22,7 @@ from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import load_dataset, load_metric
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import (
|
||||
CenterCrop,
|
||||
@@ -35,6 +35,7 @@ from torchvision.transforms import (
|
||||
)
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
@@ -415,7 +416,7 @@ def main():
|
||||
accelerator.init_trackers("image_classification_no_trainer", experiment_config)
|
||||
|
||||
# Get the metric function
|
||||
metric = load_metric("accuracy")
|
||||
metric = evaluate.load("accuracy")
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
Reference in New Issue
Block a user