Add counters for dataset classes (#37636)

* add counters for dataset classes

* fix failed code style
This commit is contained in:
Ken J
2025-04-22 09:30:43 -07:00
committed by GitHub
parent d47cdae27e
commit ca4c114dc4

View File

@@ -19,6 +19,7 @@ import logging
import os import os
import random import random
import sys import sys
from collections import Counter
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional
@@ -467,6 +468,14 @@ def main():
load_from_cache_file=not data_args.overwrite_cache, load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset", desc="Running tokenizer on dataset",
) )
def print_class_distribution(dataset, split_name):
    """Log how often each label value occurs in one dataset split.

    Args:
        dataset: Object supporting ``dataset["label"]`` and yielding an
            iterable of hashable label values (presumably a Hugging Face
            ``datasets.Dataset`` -- confirm against the caller).
        split_name: Human-readable split name used in the log messages,
            e.g. ``"train"``, ``"validation"`` or ``"test"``.

    Returns:
        None. Output is emitted through the module-level ``logger`` only.
    """
    # This helper is purely informational, so a split without a "label"
    # column (e.g. an unlabeled prediction file) must not abort the run.
    # NOTE(review): relies on the dataset raising KeyError for a missing column.
    try:
        labels = dataset["label"]
    except KeyError:
        logger.info(f"No 'label' column in {split_name} set; skipping class distribution.")
        return

    label_counts = Counter(labels)
    total = sum(label_counts.values())
    logger.info(f"Class distribution in {split_name} set:")
    # An empty split yields no iterations, so ``total`` is never zero
    # when the division below is evaluated.
    for label, count in label_counts.items():
        logger.info(f" Label {label}: {count} ({count / total:.2%})")
if training_args.do_train: if training_args.do_train:
if "train" not in raw_datasets: if "train" not in raw_datasets:
raise ValueError("--do_train requires a train dataset") raise ValueError("--do_train requires a train dataset")
@@ -474,6 +483,7 @@ def main():
if data_args.max_train_samples is not None: if data_args.max_train_samples is not None:
max_train_samples = min(len(train_dataset), data_args.max_train_samples) max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples)) train_dataset = train_dataset.select(range(max_train_samples))
print_class_distribution(train_dataset, "train")
if training_args.do_eval: if training_args.do_eval:
if "validation" not in raw_datasets and "validation_matched" not in raw_datasets: if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
@@ -482,6 +492,7 @@ def main():
if data_args.max_eval_samples is not None: if data_args.max_eval_samples is not None:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples)) eval_dataset = eval_dataset.select(range(max_eval_samples))
print_class_distribution(eval_dataset, "validation")
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
if "test" not in raw_datasets and "test_matched" not in raw_datasets: if "test" not in raw_datasets and "test_matched" not in raw_datasets:
@@ -490,6 +501,7 @@ def main():
if data_args.max_predict_samples is not None: if data_args.max_predict_samples is not None:
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
predict_dataset = predict_dataset.select(range(max_predict_samples)) predict_dataset = predict_dataset.select(range(max_predict_samples))
print_class_distribution(predict_dataset, "test")
# Log a few random samples from the training set: # Log a few random samples from the training set:
if training_args.do_train: if training_args.do_train: