diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index e152cd9911..7f8a266b44 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -19,6 +19,7 @@ import logging import os import random import sys +from collections import Counter from dataclasses import dataclass, field from typing import Optional @@ -467,6 +468,14 @@ def main(): load_from_cache_file=not data_args.overwrite_cache, desc="Running tokenizer on dataset", ) + + def print_class_distribution(dataset, split_name): + label_counts = Counter(dataset["label"]) + total = sum(label_counts.values()) + logger.info(f"Class distribution in {split_name} set:") + for label, count in label_counts.items(): + logger.info(f" Label {label}: {count} ({count / total:.2%})") + if training_args.do_train: if "train" not in raw_datasets: raise ValueError("--do_train requires a train dataset") @@ -474,6 +483,7 @@ def main(): if data_args.max_train_samples is not None: max_train_samples = min(len(train_dataset), data_args.max_train_samples) train_dataset = train_dataset.select(range(max_train_samples)) + print_class_distribution(train_dataset, "train") if training_args.do_eval: if "validation" not in raw_datasets and "validation_matched" not in raw_datasets: @@ -482,6 +492,7 @@ def main(): if data_args.max_eval_samples is not None: max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) eval_dataset = eval_dataset.select(range(max_eval_samples)) + print_class_distribution(eval_dataset, "validation") if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: if "test" not in raw_datasets and "test_matched" not in raw_datasets: @@ -490,6 +501,7 @@ def main(): if data_args.max_predict_samples is not None: max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) predict_dataset = predict_dataset.select(range(max_predict_samples)) + print_class_distribution(predict_dataset, "test") # Log a few random samples from the training set: if training_args.do_train: