Migrate metric to Evaluate in Pytorch examples (#18369)
* Migrate metric to Evaluate in pytorch examples * Remove unused imports
This commit is contained in:
@@ -27,8 +27,9 @@ from typing import Optional
|
||||
import datasets
|
||||
import nltk # Here to have a nice missing dependency error message early on
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_metric
|
||||
from datasets import load_dataset
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from filelock import FileLock
|
||||
from transformers import (
|
||||
@@ -598,7 +599,7 @@ def main():
|
||||
)
|
||||
|
||||
# Metric
|
||||
metric = load_metric("rouge")
|
||||
metric = evaluate.load("rouge")
|
||||
|
||||
def postprocess_text(preds, labels):
|
||||
preds = [pred.strip() for pred in preds]
|
||||
|
||||
@@ -30,10 +30,11 @@ import datasets
|
||||
import nltk
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset, load_metric
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
@@ -583,7 +584,7 @@ def main():
|
||||
accelerator.init_trackers("summarization_no_trainer", experiment_config)
|
||||
|
||||
# Metric
|
||||
metric = load_metric("rouge")
|
||||
metric = evaluate.load("rouge")
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
Reference in New Issue
Block a user