Add counters for dataset classes (#37636)
* add counters for dataset classes * fix failed code style
This commit is contained in:
@@ -19,6 +19,7 @@ import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
@@ -467,6 +468,14 @@ def main():
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
|
||||
def print_class_distribution(dataset, split_name):
|
||||
label_counts = Counter(dataset["label"])
|
||||
total = sum(label_counts.values())
|
||||
logger.info(f"Class distribution in {split_name} set:")
|
||||
for label, count in label_counts.items():
|
||||
logger.info(f" Label {label}: {count} ({count / total:.2%})")
|
||||
|
||||
if training_args.do_train:
|
||||
if "train" not in raw_datasets:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
@@ -474,6 +483,7 @@ def main():
|
||||
if data_args.max_train_samples is not None:
|
||||
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||
train_dataset = train_dataset.select(range(max_train_samples))
|
||||
print_class_distribution(train_dataset, "train")
|
||||
|
||||
if training_args.do_eval:
|
||||
if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
|
||||
@@ -482,6 +492,7 @@ def main():
|
||||
if data_args.max_eval_samples is not None:
|
||||
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
||||
print_class_distribution(eval_dataset, "validation")
|
||||
|
||||
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
|
||||
if "test" not in raw_datasets and "test_matched" not in raw_datasets:
|
||||
@@ -490,6 +501,7 @@ def main():
|
||||
if data_args.max_predict_samples is not None:
|
||||
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
|
||||
predict_dataset = predict_dataset.select(range(max_predict_samples))
|
||||
print_class_distribution(predict_dataset, "test")
|
||||
|
||||
# Log a few random samples from the training set:
|
||||
if training_args.do_train:
|
||||
|
||||
Reference in New Issue
Block a user