Add counters for dataset classes (#37636)

* add counters for dataset classes

* fix failed code style
This commit is contained in:
Ken J
2025-04-22 09:30:43 -07:00
committed by GitHub
parent d47cdae27e
commit ca4c114dc4

View File

@@ -19,6 +19,7 @@ import logging
import os
import random
import sys
from collections import Counter
from dataclasses import dataclass, field
from typing import Optional
@@ -467,6 +468,14 @@ def main():
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset",
)
def print_class_distribution(dataset, split_name):
    """Log how many examples of each label a dataset split contains.

    Args:
        dataset: Split to inspect; its labels are read with ``dataset["label"]``.
        split_name: Name of the split, used only in the log messages.

    A split without a ``"label"`` column (for example an unlabeled test file)
    is reported and skipped rather than aborting the run with ``KeyError``.
    """
    try:
        labels = dataset["label"]
    except KeyError:
        # NOTE(review): relies on the dataset raising KeyError for an unknown
        # column, as plain mappings do — confirm for the dataset type in use.
        logger.info(f"No 'label' column in {split_name} set; skipping class distribution.")
        return

    label_counts = Counter(labels)
    total = sum(label_counts.values())
    logger.info(f"Class distribution in {split_name} set:")
    for label, count in label_counts.items():
        # The loop only runs when at least one label was counted, so total > 0.
        logger.info(f" Label {label}: {count} ({count / total:.2%})")
if training_args.do_train:
if "train" not in raw_datasets:
raise ValueError("--do_train requires a train dataset")
@@ -474,6 +483,7 @@ def main():
if data_args.max_train_samples is not None:
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
print_class_distribution(train_dataset, "train")
if training_args.do_eval:
if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
@@ -482,6 +492,7 @@ def main():
if data_args.max_eval_samples is not None:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
print_class_distribution(eval_dataset, "validation")
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
if "test" not in raw_datasets and "test_matched" not in raw_datasets:
@@ -490,6 +501,7 @@ def main():
if data_args.max_predict_samples is not None:
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
predict_dataset = predict_dataset.select(range(max_predict_samples))
print_class_distribution(predict_dataset, "test")
# Log a few random samples from the training set:
if training_args.do_train: