Add plots and classification results using lib

This commit is contained in:
ssum21
2026-01-09 21:32:38 +09:00
parent 02245cd946
commit 164e60349d
2 changed files with 115 additions and 356 deletions

View File

@@ -37,7 +37,7 @@ Date: 2026-01-09
"""
import os
import argparse
from glob import glob
from typing import Dict, Any, List, Tuple
import numpy as np
@@ -45,6 +45,7 @@ import matplotlib.pyplot as plt
import torch
from datasets import load_from_disk, Dataset
from chronos import BaseChronosPipeline, Chronos2Pipeline
from fire import Fire
from sklearn.manifold import TSNE
# pipeline: Chronos2Pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map="cuda")
@@ -125,24 +126,27 @@ class Chronos_2_Embedder:
@staticmethod
def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
    """
    Discover all user/session directories under data_root.

    Expected structure: data_root/user_id/session_id/

    Only directories exactly two levels below ``data_root`` are returned;
    plain files at the session level are skipped. ``data_root`` is matched
    literally even when it contains glob metacharacters (``[``, ``*``,
    ``?``). Names starting with a dot are not matched by the ``*``
    wildcard and are therefore ignored.

    Args:
        data_root: Root directory containing one sub-directory per user.

    Returns:
        List of (user_id, session_id, session_path) tuples, ordered by
        user_id and then by session_id. Empty if nothing matches.
    """
    # Local import: only the ``glob`` function itself is imported at file level.
    from glob import escape as escape_glob

    # Escape the root so that only the two trailing "*" act as wildcards;
    # otherwise a root such as "runs[2024]" would silently match nothing.
    pattern = os.path.join(escape_glob(data_root), "*", "*")

    discovered_paths = []
    for session_path in glob(pattern):
        # Skip files that happen to sit at the session level.
        if not os.path.isdir(session_path):
            continue

        # The last two path components are the session and the user.
        session_id = os.path.basename(session_path)
        user_id = os.path.basename(os.path.dirname(session_path))
        discovered_paths.append((user_id, session_id, session_path))

    # Sort on the (user_id, session_id) pair rather than on the raw path:
    # in a path string the separator takes part in the comparison, so user
    # "a-b" would otherwise be ordered before user "a".
    discovered_paths.sort(key=lambda entry: entry[:2])
    return discovered_paths
@@ -284,15 +288,6 @@ class Chronos_2_Embedder:
dataset: Dataset,
output_dir: str
) -> None:
"""
Save embeddings dataset to disk using HuggingFace format.
Uses Arrow format for efficient storage and memory-mapped loading.
Args:
dataset: HuggingFace Dataset with embeddings and metadata
output_dir: Directory to save the dataset (created if not exists)
"""
os.makedirs(output_dir, exist_ok=True)
dataset.save_to_disk(output_dir)
@@ -302,15 +297,6 @@ class Chronos_2_Embedder:
@staticmethod
def load_embeddings(embedding_dir: str) -> Dataset:
"""
Load previously saved embeddings dataset.
Args:
embedding_dir: Directory containing the HuggingFace dataset
Returns:
HuggingFace Dataset with embeddings and metadata
"""
dataset = load_from_disk(embedding_dir)
print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
return dataset
@@ -409,131 +395,79 @@ def create_scatter_plot(
# Command Line Interface
# =============================================================================
def parse_arguments() -> argparse.Namespace:
"""
Parse command line arguments.
Supports two subcommands:
- extract: Generate embeddings from raw data
- plot: Visualize existing embeddings
"""
parser = argparse.ArgumentParser(
description="Chronos-2 Time Series Embedding Pipeline",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
subparsers = parser.add_subparsers(dest="command", required=True)
extract_parser = subparsers.add_parser(
"extract",
help="Extract embeddings from time series data"
)
extract_parser.add_argument(
"--data_root", type=str, required=True,
help="Root directory containing user/session data folders"
)
extract_parser.add_argument(
"--out_dir", type=str, required=True,
help="Output directory for embeddings.npy and metadata.csv"
)
extract_parser.add_argument(
"--model", type=str, default="amazon/chronos-2",
help="Chronos-2 model name or path (default: amazon/chronos-2)"
)
extract_parser.add_argument(
"--batch_size", type=int, default=32,
help="Batch size for inference (default: 32)"
)
extract_parser.add_argument(
"--pooling", type=str, default="mean",
choices=["mean", "cls"],
help="Pooling strategy: 'mean' or 'cls' (default: mean)"
)
plot_parser = subparsers.add_parser(
"plot",
help="Visualize embeddings with t-SNE"
)
plot_parser.add_argument(
"--emb_dir", type=str, required=True,
help="Directory containing embeddings.npy and metadata.csv"
)
plot_parser.add_argument(
"--out_dir", type=str, required=True,
help="Output directory for visualization plots"
)
plot_parser.add_argument(
"--perplexity", type=float, default=30.0,
help="t-SNE perplexity parameter (default: 30.0)"
)
plot_parser.add_argument(
"--users", type=str, default=None,
help="Comma-separated user IDs to include (e.g., '00,01,02')"
)
plot_parser.add_argument(
"--num_users", type=int, default=0,
help="Include only first N users (0 = all users)"
)
plot_parser.add_argument(
"--labels", type=str, default=None,
help="Comma-separated sleep stage labels to include (e.g., '0,1,2')"
)
return parser.parse_args()
def main():
"""
Main entry point for the embedding pipeline.
Dispatches to either embedding extraction or visualization
based on the subcommand.
"""
args = parse_arguments()
# =========================================================================
# Extract Command: Generate embeddings from raw data
# =========================================================================
if args.command == "extract":
# Initialize embedder with specified configuration
class CLI:
def extract(
    self,
    data_root: str,
    out_dir: str,
    model: str = "amazon/chronos-2",
    batch_size: int = 32,
    pooling: str = "mean",
) -> None:
    """
    Extract embeddings from time series data and save them.

    Args:
        data_root: Root directory containing user/session data folders
        out_dir: Output directory for HuggingFace dataset
        model: Chronos-2 model name or path (default: amazon/chronos-2)
        batch_size: Batch size for inference (default: 32)
        pooling: Pooling strategy - 'mean' or 'cls' (default: mean)

    Raises:
        ValueError: If pooling is neither 'mean' nor 'cls'.
    """
    # Validate first so a bad flag fails before the embedder is constructed.
    if pooling not in ("mean", "cls"):
        raise ValueError(f"Invalid pooling strategy: {pooling}. Use 'mean' or 'cls'.")

    # Initialize embedder from the method parameters (no argparse namespace).
    embedder = Chronos_2_Embedder(
        model_name=model,
        pooling_strategy=pooling,
    )

    # Extract and save embeddings
    dataset = embedder.extract_embeddings(data_root, batch_size)
    embedder.save_embeddings(dataset, out_dir)
# =========================================================================
# Plot Command: Visualize existing embeddings
# =========================================================================
elif args.command == "plot":
os.makedirs(args.out_dir, exist_ok=True)
def plot(
self,
emb_dir: str,
out_dir: str,
perplexity: float = 30.0,
users: str = None,
num_users: int = 0,
labels: str = None,
) -> None:
"""
Visualize embeddings with t-SNE.
Args:
emb_dir: Directory containing the HuggingFace embeddings dataset
out_dir: Output directory for visualization plots (PDF)
perplexity: t-SNE perplexity parameter (default: 30.0)
users: Comma-separated user IDs to include (e.g., '00,01,02')
num_users: Include only first N users, 0 = all (default: 0)
labels: Comma-separated sleep stage labels to include (e.g., '0,1,2')
"""
os.makedirs(out_dir, exist_ok=True)
# Load saved embeddings dataset
dataset = Chronos_2_Embedder.load_embeddings(args.emb_dir)
dataset = Chronos_2_Embedder.load_embeddings(emb_dir)
# Apply user filtering if specified
if args.users:
user_list = [u.strip() for u in args.users.split(",")]
# Apply user filtering
if users:
user_list = [u.strip() for u in users.split(",")]
dataset = dataset.filter(lambda x: x["user_id"] in user_list)
print(f"[INFO] Filtered to users: {user_list}")
elif args.num_users > 0:
elif num_users > 0:
all_users = sorted(set(dataset["user_id"]))
selected_users = all_users[:args.num_users]
selected_users = all_users[:num_users]
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
print(f"[INFO] Selected first {args.num_users} users: {selected_users}")
print(f"[INFO] Selected first {num_users} users: {selected_users}")
# Filter by sleep stage labels if specified
if args.labels:
label_list = [l.strip() for l in args.labels.split(",")]
# Filter by sleep stage labels
if labels:
label_list = [l.strip() for l in labels.split(",")]
dataset = dataset.filter(lambda x: x["label"] in label_list)
print(f"[INFO] Filtered to labels: {label_list}")
@@ -543,25 +477,23 @@ def main():
embeddings = np.array(dataset["embedding"])
# Reduce to 2D with t-SNE
coordinates_2d = reduce_to_2d_tsne(embeddings, args.perplexity)
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
# Generate visualizations
# Plot 1: Color by sleep stage label
create_scatter_plot(
coordinates_2d,
np.array(dataset["label"]),
"t-SNE Visualization (Colored by Sleep Stage)",
os.path.join(args.out_dir, "tsne_by_label.pdf")
os.path.join(out_dir, "tsne_by_label.pdf")
)
# Plot 2: Color by user ID
create_scatter_plot(
coordinates_2d,
np.array(dataset["user_id"]),
"t-SNE Visualization (Colored by User ID)",
os.path.join(args.out_dir, "tsne_by_user.pdf")
os.path.join(out_dir, "tsne_by_user.pdf")
)
if __name__ == "__main__":
main()
Fire(CLI)

View File

@@ -43,7 +43,6 @@ Date: 2026-01-09
"""
import os
import argparse
from typing import Tuple
import numpy as np
@@ -51,6 +50,7 @@ import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_from_disk, Dataset
from fire import Fire
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
@@ -69,22 +69,6 @@ from sklearn.model_selection import train_test_split
# =============================================================================
def load_embeddings_with_metadata(embedding_path: str) -> Dataset:
"""
Load embeddings dataset from disk using HuggingFace format.
Expects a HuggingFace Dataset directory with columns:
- user_id, session_id, idx, label (metadata)
- embedding (1024-dim vector)
Args:
embedding_path: Path to the HuggingFace dataset directory
Returns:
HuggingFace Dataset with embeddings and metadata
Raises:
FileNotFoundError: If dataset directory is not found
"""
if not os.path.isdir(embedding_path):
raise FileNotFoundError(
f"Dataset directory not found: {embedding_path}. "
@@ -96,7 +80,6 @@ def load_embeddings_with_metadata(embedding_path: str) -> Dataset:
return dataset
# =============================================================================
# Data Splitting
# =============================================================================
@@ -250,7 +233,7 @@ def evaluate_classifier(
Metrics Computed:
- Accuracy: Overall fraction of correct predictions
- Macro F1: Average F1 across all classes (treats all classes equally)
- Macro F1: Average F1 across all classes
- Per-class report: Precision, recall, F1 for each user
- Confusion matrix: Detailed breakdown of predictions vs ground truth
@@ -268,8 +251,7 @@ def evaluate_classifier(
# Generate detailed per-class report
report = classification_report(y_true, y_pred, digits=4)
# Compute confusion matrix
# Get all unique classes from both true and predicted labels
# Compute confusion matrix: Get all unique classes from both true and predicted labels
class_labels = sorted(set(y_true) | set(y_pred))
cm = confusion_matrix(y_true, y_pred, labels=class_labels)
@@ -314,14 +296,6 @@ def save_confusion_matrix_plot(
ax.set_xlabel("Predicted User")
ax.set_ylabel("True User")
# Note: Axis ticks omitted for readability when many classes exist
# Uncomment below to add tick labels if needed:
# ax.set_xticks(range(len(class_labels)))
# ax.set_xticklabels(class_labels, rotation=45, ha='right')
# ax.set_yticks(range(len(class_labels)))
# ax.set_yticklabels(class_labels)
# Save as vector PDF for publication quality
plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.close()
@@ -340,20 +314,6 @@ def save_metrics_report(
oob_score: float = None,
silhouette: float = None,
) -> None:
"""
Save classification metrics to a text file.
Args:
output_path: File path for the metrics report
split_description: Description of how data was split
train_size: Number of training samples
test_size: Number of test samples
accuracy: Overall accuracy score
macro_f1: Macro-averaged F1 score
classification_report_str: Detailed per-class metrics from sklearn
oob_score: Out-of-bag score from Random Forest (optional)
silhouette: Silhouette score for user cluster quality (optional)
"""
with open(output_path, "w", encoding="utf-8") as f:
# Write summary statistics
f.write(f"split_used : {split_description}\n")
@@ -384,18 +344,7 @@ def save_feature_importance_plot(
output_path: str,
top_k: int = 50,
) -> None:
"""
Create and save a feature importance bar chart.
Feature importance in Random Forest is computed based on mean decrease
in impurity (Gini importance). Higher values indicate features that
contribute more to classification decisions.
Args:
feature_importances: Array of importance scores for each feature
output_path: File path to save the plot (PDF format recommended)
top_k: Number of top features to display (default: 50)
"""
# Get indices of top-k most important features
top_indices = np.argsort(feature_importances)[::-1][:top_k]
top_importances = feature_importances[top_indices]
@@ -428,30 +377,14 @@ def save_feature_importance_csv(
feature_importances: np.ndarray,
output_path: str,
) -> None:
"""
Save feature importances to CSV for further analysis.
Args:
feature_importances: Array of importance scores for each feature
output_path: File path for the CSV file
"""
# Create DataFrame with feature names and importances
importance_df = pd.DataFrame({
"feature": [f"emb_{i}" for i in range(len(feature_importances))],
"importance": feature_importances,
})
# Sort by importance (descending)
importance_df = importance_df.sort_values("importance", ascending=False)
importance_df = importance_df.reset_index(drop=True)
# Add rank column
importance_df["rank"] = range(1, len(importance_df) + 1)
# Reorder columns
importance_df = importance_df[["rank", "feature", "importance"]]
# Save to CSV
importance_df.to_csv(output_path, index=False)
print(f"[DONE] Saved feature importance CSV: {output_path}")
@@ -461,181 +394,78 @@ def save_feature_importance_csv(
# Command Line Interface
# =============================================================================
def parse_arguments() -> argparse.Namespace:
"""
Parse command line arguments for the user classification experiment.
Arguments:
--embeddings: Path to embeddings.npy file (required)
--out_dir: Output directory for results (required)
--split_mode: How to split data - "session" or "random" (default: session)
--test_size: Fraction for test set when using random split (default: 0.2)
"""
parser = argparse.ArgumentParser(
description="Train a user classifier on time series embeddings",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--embeddings",
type=str,
required=True,
help="Path to HuggingFace dataset directory containing embeddings",
)
parser.add_argument(
"--out_dir",
type=str,
required=True,
help="Output directory for metrics and confusion matrix",
)
parser.add_argument(
"--split_mode",
type=str,
default="session",
choices=["session", "random"],
help="Data splitting strategy: 'session' (train=1, test=2) or 'random'",
)
parser.add_argument(
"--test_size",
type=float,
default=0.2,
help="Test set fraction when using random split (default: 0.2)",
)
# Random Forest hyperparameters
parser.add_argument(
"--n_estimators",
type=int,
default=200,
help="Number of trees in Random Forest (default: 200)",
)
parser.add_argument(
"--max_depth",
type=int,
default=None,
help="Maximum tree depth (default: None for unlimited)",
)
parser.add_argument(
"--labels",
type=str,
default=None,
help="Comma-separated sleep stage labels to include (e.g., '0,1,2')",
)
return parser.parse_args()
def main(
embeddings: str,
out_dir: str,
split_mode: str = "session",
test_size: float = 0.2,
n_estimators: int = 200,
max_depth: int = None,
labels: str = None,
) -> None:
# =============================================================================
# Main Entry Point
# =============================================================================
def main():
"""
Main entry point for user classification experiment.
Workflow:
1. Load embeddings and metadata
2. Split data (by session or randomly)
3. Train Random Forest classifier
4. Evaluate and save results (metrics, confusion matrix, feature importance)
"""
# Parse command line arguments
args = parse_arguments()
# Validate split_mode argument
if split_mode not in ["session", "random"]:
raise ValueError(f"Invalid split_mode: {split_mode}. Use 'session' or 'random'.")
# Create output directory
os.makedirs(args.out_dir, exist_ok=True)
# =========================================================================
# Step 1: Load Data
# =========================================================================
print(f"[INFO] Loading embeddings from: {args.embeddings}")
dataset = load_embeddings_with_metadata(args.embeddings)
os.makedirs(out_dir, exist_ok=True)
print(f"[INFO] Loading embeddings from: {embeddings}")
dataset = load_embeddings_with_metadata(embeddings)
# Filter by sleep stage labels if specified
if args.labels:
label_list = [l.strip() for l in args.labels.split(",")]
if labels:
label_list = [l.strip() for l in labels.split(",")]
dataset = dataset.filter(lambda x: str(x["label"]) in label_list)
print(f"[INFO] Filtered to labels: {label_list}")
# Extract feature matrix (embeddings as numpy array)
X = np.array(dataset["embedding"], dtype=np.float32)
# Extract labels (user IDs as strings)
y = np.array([str(uid) for uid in dataset["user_id"]])
# Extract session IDs for session-based splitting
session_ids = np.array([str(sid) for sid in dataset["session_id"]])
print(f"[INFO] Loaded {len(y)} samples with {X.shape[1]}-dimensional embeddings")
print(f"[INFO] Number of unique users: {len(np.unique(y))}")
# =========================================================================
# Step 2: Split Data
# =========================================================================
if args.split_mode == "session":
# Check if both sessions exist
if split_mode == "session":
has_session_1 = (session_ids == "1").sum() > 0
has_session_2 = (session_ids == "2").sum() > 0
if has_session_1 and has_session_2:
# Perform session-based split
X_train, X_test, y_train, y_test, split_desc = split_by_session(
X, y, session_ids
)
else:
# Fallback to random split if sessions are missing
print("[WARN] Missing session data, falling back to random split")
X_train, X_test, y_train, y_test, split_desc = split_random(
X, y, args.test_size
X, y, test_size
)
split_desc = "random(fallback)"
else:
# Random split
X_train, X_test, y_train, y_test, split_desc = split_random(
X, y, args.test_size
X, y, test_size
)
print(f"[INFO] Split: {split_desc}")
print(f"[INFO] Train size: {len(y_train)}, Test size: {len(y_test)}")
# Compute silhouette score on full dataset
# Higher score (closer to 1) = better user separation in embedding space
silhouette_avg = silhouette_score(X, y)
print(f"[INFO] Silhouette Score (user clusters): {silhouette_avg:.4f}")
# =========================================================================
# Step 3: Train Classifier
# =========================================================================
print("[INFO] Training Random Forest classifier...")
print(f"[INFO] Hyperparameters: n_estimators={args.n_estimators}, max_depth={args.max_depth}")
print(f"[INFO] Hyperparameters: n_estimators={n_estimators}, max_depth={max_depth}")
classifier = create_classifier_pipeline(
n_estimators=args.n_estimators,
max_depth=args.max_depth,
n_estimators=n_estimators,
max_depth=max_depth,
)
classifier.fit(X_train, y_train)
# Generate predictions on test set
y_pred = classifier.predict(X_test)
# Get Random Forest specific metrics
rf_model = classifier.named_steps["classifier"]
oob_score = rf_model.oob_score_ if hasattr(rf_model, "oob_score_") else None
feature_importances = rf_model.feature_importances_
# =========================================================================
# Step 4: Evaluate and Save Results
# =========================================================================
print("[INFO] Evaluating classifier performance...")
accuracy, macro_f1, report, cm, classes = evaluate_classifier(y_test, y_pred)
# Save metrics report (with OOB score and silhouette)
metrics_path = os.path.join(args.out_dir, "user_cls_metrics.txt")
metrics_path = os.path.join(out_dir, "user_cls_metrics.txt")
save_metrics_report(
output_path=metrics_path,
split_description=split_desc,
@@ -648,18 +478,15 @@ def main():
silhouette=silhouette_avg,
)
# Save confusion matrix plot
confusion_path = os.path.join(args.out_dir, "user_cls_confusion.pdf")
confusion_path = os.path.join(out_dir, "user_cls_confusion.pdf")
save_confusion_matrix_plot(cm, classes, confusion_path)
# Save feature importance analysis
importance_plot_path = os.path.join(args.out_dir, "feature_importance.pdf")
importance_plot_path = os.path.join(out_dir, "feature_importance.pdf")
save_feature_importance_plot(feature_importances, importance_plot_path)
importance_csv_path = os.path.join(args.out_dir, "feature_importance.csv")
importance_csv_path = os.path.join(out_dir, "feature_importance.csv")
save_feature_importance_csv(feature_importances, importance_csv_path)
# Print summary to console
print("\n" + "=" * 50)
print("RANDOM FOREST CLASSIFICATION RESULTS")
print("=" * 50)
@@ -677,4 +504,4 @@ def main():
if __name__ == "__main__":
main()
Fire(main)