Add plots and classification results using lib
This commit is contained in:
@@ -37,7 +37,7 @@ Date: 2026-01-09
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from glob import glob
|
||||
from typing import Dict, Any, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -45,6 +45,7 @@ import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from datasets import load_from_disk, Dataset
|
||||
from chronos import BaseChronosPipeline, Chronos2Pipeline
|
||||
from fire import Fire
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
# pipeline: Chronos2Pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map="cuda")
|
||||
@@ -125,24 +126,27 @@ class Chronos_2_Embedder:
|
||||
|
||||
@staticmethod
|
||||
def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
Discover all user/session directories under data_root.
|
||||
|
||||
Uses glob pattern matching for cleaner directory traversal.
|
||||
Expected structure: data_root/user_id/session_id/
|
||||
|
||||
Returns:
|
||||
List of (user_id, session_id, session_path) tuples
|
||||
"""
|
||||
discovered_paths = []
|
||||
|
||||
# Iterate through user directories
|
||||
for user_id in sorted(os.listdir(data_root)):
|
||||
user_path = os.path.join(data_root, user_id)
|
||||
|
||||
# Skip non-directory entries
|
||||
if not os.path.isdir(user_path):
|
||||
# Use glob to find all session directories (2 levels deep)
|
||||
for session_path in sorted(glob(os.path.join(data_root, "*", "*"))):
|
||||
if not os.path.isdir(session_path):
|
||||
continue
|
||||
|
||||
# Iterate through session directories within each user
|
||||
for session_id in sorted(os.listdir(user_path)):
|
||||
session_path = os.path.join(user_path, session_id)
|
||||
|
||||
if not os.path.isdir(session_path):
|
||||
continue
|
||||
|
||||
discovered_paths.append((user_id, session_id, session_path))
|
||||
# Extract user_id and session_id from path
|
||||
session_id = os.path.basename(session_path)
|
||||
user_id = os.path.basename(os.path.dirname(session_path))
|
||||
|
||||
discovered_paths.append((user_id, session_id, session_path))
|
||||
|
||||
return discovered_paths
|
||||
|
||||
@@ -284,15 +288,6 @@ class Chronos_2_Embedder:
|
||||
dataset: Dataset,
|
||||
output_dir: str
|
||||
) -> None:
|
||||
"""
|
||||
Save embeddings dataset to disk using HuggingFace format.
|
||||
|
||||
Uses Arrow format for efficient storage and memory-mapped loading.
|
||||
|
||||
Args:
|
||||
dataset: HuggingFace Dataset with embeddings and metadata
|
||||
output_dir: Directory to save the dataset (created if not exists)
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
dataset.save_to_disk(output_dir)
|
||||
@@ -302,15 +297,6 @@ class Chronos_2_Embedder:
|
||||
|
||||
@staticmethod
|
||||
def load_embeddings(embedding_dir: str) -> Dataset:
|
||||
"""
|
||||
Load previously saved embeddings dataset.
|
||||
|
||||
Args:
|
||||
embedding_dir: Directory containing the HuggingFace dataset
|
||||
|
||||
Returns:
|
||||
HuggingFace Dataset with embeddings and metadata
|
||||
"""
|
||||
dataset = load_from_disk(embedding_dir)
|
||||
print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
|
||||
return dataset
|
||||
@@ -409,131 +395,79 @@ def create_scatter_plot(
|
||||
# Command Line Interface
|
||||
# =============================================================================
|
||||
|
||||
def parse_arguments() -> argparse.Namespace:
|
||||
"""
|
||||
Parse command line arguments.
|
||||
|
||||
Supports two subcommands:
|
||||
- extract: Generate embeddings from raw data
|
||||
- plot: Visualize existing embeddings
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Chronos-2 Time Series Embedding Pipeline",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
extract_parser = subparsers.add_parser(
|
||||
"extract",
|
||||
help="Extract embeddings from time series data"
|
||||
)
|
||||
extract_parser.add_argument(
|
||||
"--data_root", type=str, required=True,
|
||||
help="Root directory containing user/session data folders"
|
||||
)
|
||||
extract_parser.add_argument(
|
||||
"--out_dir", type=str, required=True,
|
||||
help="Output directory for embeddings.npy and metadata.csv"
|
||||
)
|
||||
extract_parser.add_argument(
|
||||
"--model", type=str, default="amazon/chronos-2",
|
||||
help="Chronos-2 model name or path (default: amazon/chronos-2)"
|
||||
)
|
||||
extract_parser.add_argument(
|
||||
"--batch_size", type=int, default=32,
|
||||
help="Batch size for inference (default: 32)"
|
||||
)
|
||||
extract_parser.add_argument(
|
||||
"--pooling", type=str, default="mean",
|
||||
choices=["mean", "cls"],
|
||||
help="Pooling strategy: 'mean' or 'cls' (default: mean)"
|
||||
)
|
||||
|
||||
plot_parser = subparsers.add_parser(
|
||||
"plot",
|
||||
help="Visualize embeddings with t-SNE"
|
||||
)
|
||||
plot_parser.add_argument(
|
||||
"--emb_dir", type=str, required=True,
|
||||
help="Directory containing embeddings.npy and metadata.csv"
|
||||
)
|
||||
plot_parser.add_argument(
|
||||
"--out_dir", type=str, required=True,
|
||||
help="Output directory for visualization plots"
|
||||
)
|
||||
plot_parser.add_argument(
|
||||
"--perplexity", type=float, default=30.0,
|
||||
help="t-SNE perplexity parameter (default: 30.0)"
|
||||
)
|
||||
plot_parser.add_argument(
|
||||
"--users", type=str, default=None,
|
||||
help="Comma-separated user IDs to include (e.g., '00,01,02')"
|
||||
)
|
||||
plot_parser.add_argument(
|
||||
"--num_users", type=int, default=0,
|
||||
help="Include only first N users (0 = all users)"
|
||||
)
|
||||
plot_parser.add_argument(
|
||||
"--labels", type=str, default=None,
|
||||
help="Comma-separated sleep stage labels to include (e.g., '0,1,2')"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main entry point for the embedding pipeline.
|
||||
|
||||
Dispatches to either embedding extraction or visualization
|
||||
based on the subcommand.
|
||||
"""
|
||||
args = parse_arguments()
|
||||
|
||||
# =========================================================================
|
||||
# Extract Command: Generate embeddings from raw data
|
||||
# =========================================================================
|
||||
if args.command == "extract":
|
||||
# Initialize embedder with specified configuration
|
||||
class CLI:
|
||||
def extract(
|
||||
self,
|
||||
data_root: str,
|
||||
out_dir: str,
|
||||
model: str = "amazon/chronos-2",
|
||||
batch_size: int = 32,
|
||||
pooling: str = "mean",
|
||||
) -> None:
|
||||
"""
|
||||
Extract embeddings from time series data.
|
||||
|
||||
Args:
|
||||
data_root: Root directory containing user/session data folders
|
||||
out_dir: Output directory for HuggingFace dataset
|
||||
model: Chronos-2 model name or path (default: amazon/chronos-2)
|
||||
batch_size: Batch size for inference (default: 32)
|
||||
pooling: Pooling strategy - 'mean' or 'cls' (default: mean)
|
||||
"""
|
||||
# Validate pooling argument
|
||||
if pooling not in ["mean", "cls"]:
|
||||
raise ValueError(f"Invalid pooling strategy: {pooling}. Use 'mean' or 'cls'.")
|
||||
|
||||
# Initialize embedder
|
||||
embedder = Chronos_2_Embedder(
|
||||
model_name=args.model,
|
||||
pooling_strategy=args.pooling,
|
||||
model_name=model,
|
||||
pooling_strategy=pooling,
|
||||
)
|
||||
|
||||
# Extract embeddings from all data
|
||||
dataset = embedder.extract_embeddings(
|
||||
args.data_root,
|
||||
args.batch_size
|
||||
)
|
||||
|
||||
# Save results
|
||||
embedder.save_embeddings(dataset, args.out_dir)
|
||||
# Extract and save embeddings
|
||||
dataset = embedder.extract_embeddings(data_root, batch_size)
|
||||
embedder.save_embeddings(dataset, out_dir)
|
||||
|
||||
# =========================================================================
|
||||
# Plot Command: Visualize existing embeddings
|
||||
# =========================================================================
|
||||
elif args.command == "plot":
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
def plot(
|
||||
self,
|
||||
emb_dir: str,
|
||||
out_dir: str,
|
||||
perplexity: float = 30.0,
|
||||
users: str = None,
|
||||
num_users: int = 0,
|
||||
labels: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
Visualize embeddings with t-SNE.
|
||||
|
||||
Args:
|
||||
emb_dir: Directory containing the HuggingFace embeddings dataset
|
||||
out_dir: Output directory for visualization plots (PDF)
|
||||
perplexity: t-SNE perplexity parameter (default: 30.0)
|
||||
users: Comma-separated user IDs to include (e.g., '00,01,02')
|
||||
num_users: Include only first N users, 0 = all (default: 0)
|
||||
labels: Comma-separated sleep stage labels to include (e.g., '0,1,2')
|
||||
"""
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# Load saved embeddings dataset
|
||||
dataset = Chronos_2_Embedder.load_embeddings(args.emb_dir)
|
||||
dataset = Chronos_2_Embedder.load_embeddings(emb_dir)
|
||||
|
||||
# Apply user filtering if specified
|
||||
if args.users:
|
||||
user_list = [u.strip() for u in args.users.split(",")]
|
||||
# Apply user filtering
|
||||
if users:
|
||||
user_list = [u.strip() for u in users.split(",")]
|
||||
dataset = dataset.filter(lambda x: x["user_id"] in user_list)
|
||||
print(f"[INFO] Filtered to users: {user_list}")
|
||||
|
||||
elif args.num_users > 0:
|
||||
elif num_users > 0:
|
||||
all_users = sorted(set(dataset["user_id"]))
|
||||
selected_users = all_users[:args.num_users]
|
||||
selected_users = all_users[:num_users]
|
||||
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
|
||||
print(f"[INFO] Selected first {args.num_users} users: {selected_users}")
|
||||
print(f"[INFO] Selected first {num_users} users: {selected_users}")
|
||||
|
||||
# Filter by sleep stage labels if specified
|
||||
if args.labels:
|
||||
label_list = [l.strip() for l in args.labels.split(",")]
|
||||
# Filter by sleep stage labels
|
||||
if labels:
|
||||
label_list = [l.strip() for l in labels.split(",")]
|
||||
dataset = dataset.filter(lambda x: x["label"] in label_list)
|
||||
print(f"[INFO] Filtered to labels: {label_list}")
|
||||
|
||||
@@ -543,25 +477,23 @@ def main():
|
||||
embeddings = np.array(dataset["embedding"])
|
||||
|
||||
# Reduce to 2D with t-SNE
|
||||
coordinates_2d = reduce_to_2d_tsne(embeddings, args.perplexity)
|
||||
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
|
||||
|
||||
# Generate visualizations
|
||||
# Plot 1: Color by sleep stage label
|
||||
create_scatter_plot(
|
||||
coordinates_2d,
|
||||
np.array(dataset["label"]),
|
||||
"t-SNE Visualization (Colored by Sleep Stage)",
|
||||
os.path.join(args.out_dir, "tsne_by_label.pdf")
|
||||
os.path.join(out_dir, "tsne_by_label.pdf")
|
||||
)
|
||||
|
||||
# Plot 2: Color by user ID
|
||||
create_scatter_plot(
|
||||
coordinates_2d,
|
||||
np.array(dataset["user_id"]),
|
||||
"t-SNE Visualization (Colored by User ID)",
|
||||
os.path.join(args.out_dir, "tsne_by_user.pdf")
|
||||
os.path.join(out_dir, "tsne_by_user.pdf")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Fire(CLI)
|
||||
|
||||
@@ -43,7 +43,6 @@ Date: 2026-01-09
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -51,6 +50,7 @@ import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from datasets import load_from_disk, Dataset
|
||||
from fire import Fire
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
@@ -69,22 +69,6 @@ from sklearn.model_selection import train_test_split
|
||||
# =============================================================================
|
||||
|
||||
def load_embeddings_with_metadata(embedding_path: str) -> Dataset:
|
||||
"""
|
||||
Load embeddings dataset from disk using HuggingFace format.
|
||||
|
||||
Expects a HuggingFace Dataset directory with columns:
|
||||
- user_id, session_id, idx, label (metadata)
|
||||
- embedding (1024-dim vector)
|
||||
|
||||
Args:
|
||||
embedding_path: Path to the HuggingFace dataset directory
|
||||
|
||||
Returns:
|
||||
HuggingFace Dataset with embeddings and metadata
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If dataset directory is not found
|
||||
"""
|
||||
if not os.path.isdir(embedding_path):
|
||||
raise FileNotFoundError(
|
||||
f"Dataset directory not found: {embedding_path}. "
|
||||
@@ -96,7 +80,6 @@ def load_embeddings_with_metadata(embedding_path: str) -> Dataset:
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data Splitting
|
||||
# =============================================================================
|
||||
@@ -250,7 +233,7 @@ def evaluate_classifier(
|
||||
|
||||
Metrics Computed:
|
||||
- Accuracy: Overall fraction of correct predictions
|
||||
- Macro F1: Average F1 across all classes (treats all classes equally)
|
||||
- Macro F1: Average F1 across all classes
|
||||
- Per-class report: Precision, recall, F1 for each user
|
||||
- Confusion matrix: Detailed breakdown of predictions vs ground truth
|
||||
|
||||
@@ -268,8 +251,7 @@ def evaluate_classifier(
|
||||
# Generate detailed per-class report
|
||||
report = classification_report(y_true, y_pred, digits=4)
|
||||
|
||||
# Compute confusion matrix
|
||||
# Get all unique classes from both true and predicted labels
|
||||
# Compute confusion matrix: Get all unique classes from both true and predicted labels
|
||||
class_labels = sorted(set(y_true) | set(y_pred))
|
||||
cm = confusion_matrix(y_true, y_pred, labels=class_labels)
|
||||
|
||||
@@ -314,14 +296,6 @@ def save_confusion_matrix_plot(
|
||||
ax.set_xlabel("Predicted User")
|
||||
ax.set_ylabel("True User")
|
||||
|
||||
# Note: Axis ticks omitted for readability when many classes exist
|
||||
# Uncomment below to add tick labels if needed:
|
||||
# ax.set_xticks(range(len(class_labels)))
|
||||
# ax.set_xticklabels(class_labels, rotation=45, ha='right')
|
||||
# ax.set_yticks(range(len(class_labels)))
|
||||
# ax.set_yticklabels(class_labels)
|
||||
|
||||
# Save as vector PDF for publication quality
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, format="pdf", bbox_inches="tight")
|
||||
plt.close()
|
||||
@@ -340,20 +314,6 @@ def save_metrics_report(
|
||||
oob_score: float = None,
|
||||
silhouette: float = None,
|
||||
) -> None:
|
||||
"""
|
||||
Save classification metrics to a text file.
|
||||
|
||||
Args:
|
||||
output_path: File path for the metrics report
|
||||
split_description: Description of how data was split
|
||||
train_size: Number of training samples
|
||||
test_size: Number of test samples
|
||||
accuracy: Overall accuracy score
|
||||
macro_f1: Macro-averaged F1 score
|
||||
classification_report_str: Detailed per-class metrics from sklearn
|
||||
oob_score: Out-of-bag score from Random Forest (optional)
|
||||
silhouette: Silhouette score for user cluster quality (optional)
|
||||
"""
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
# Write summary statistics
|
||||
f.write(f"split_used : {split_description}\n")
|
||||
@@ -384,18 +344,7 @@ def save_feature_importance_plot(
|
||||
output_path: str,
|
||||
top_k: int = 50,
|
||||
) -> None:
|
||||
"""
|
||||
Create and save a feature importance bar chart.
|
||||
|
||||
Feature importance in Random Forest is computed based on mean decrease
|
||||
in impurity (Gini importance). Higher values indicate features that
|
||||
contribute more to classification decisions.
|
||||
|
||||
Args:
|
||||
feature_importances: Array of importance scores for each feature
|
||||
output_path: File path to save the plot (PDF format recommended)
|
||||
top_k: Number of top features to display (default: 50)
|
||||
"""
|
||||
|
||||
# Get indices of top-k most important features
|
||||
top_indices = np.argsort(feature_importances)[::-1][:top_k]
|
||||
top_importances = feature_importances[top_indices]
|
||||
@@ -428,30 +377,14 @@ def save_feature_importance_csv(
|
||||
feature_importances: np.ndarray,
|
||||
output_path: str,
|
||||
) -> None:
|
||||
"""
|
||||
Save feature importances to CSV for further analysis.
|
||||
|
||||
Args:
|
||||
feature_importances: Array of importance scores for each feature
|
||||
output_path: File path for the CSV file
|
||||
"""
|
||||
# Create DataFrame with feature names and importances
|
||||
importance_df = pd.DataFrame({
|
||||
"feature": [f"emb_{i}" for i in range(len(feature_importances))],
|
||||
"importance": feature_importances,
|
||||
})
|
||||
|
||||
# Sort by importance (descending)
|
||||
importance_df = importance_df.sort_values("importance", ascending=False)
|
||||
importance_df = importance_df.reset_index(drop=True)
|
||||
|
||||
# Add rank column
|
||||
importance_df["rank"] = range(1, len(importance_df) + 1)
|
||||
|
||||
# Reorder columns
|
||||
importance_df = importance_df[["rank", "feature", "importance"]]
|
||||
|
||||
# Save to CSV
|
||||
importance_df.to_csv(output_path, index=False)
|
||||
|
||||
print(f"[DONE] Saved feature importance CSV: {output_path}")
|
||||
@@ -461,181 +394,78 @@ def save_feature_importance_csv(
|
||||
# Command Line Interface
|
||||
# =============================================================================
|
||||
|
||||
def parse_arguments() -> argparse.Namespace:
|
||||
"""
|
||||
Parse command line arguments for the user classification experiment.
|
||||
|
||||
Arguments:
|
||||
--embeddings: Path to embeddings.npy file (required)
|
||||
--out_dir: Output directory for results (required)
|
||||
--split_mode: How to split data - "session" or "random" (default: session)
|
||||
--test_size: Fraction for test set when using random split (default: 0.2)
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train a user classifier on time series embeddings",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--embeddings",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to HuggingFace dataset directory containing embeddings",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--out_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output directory for metrics and confusion matrix",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--split_mode",
|
||||
type=str,
|
||||
default="session",
|
||||
choices=["session", "random"],
|
||||
help="Data splitting strategy: 'session' (train=1, test=2) or 'random'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--test_size",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="Test set fraction when using random split (default: 0.2)",
|
||||
)
|
||||
|
||||
# Random Forest hyperparameters
|
||||
parser.add_argument(
|
||||
"--n_estimators",
|
||||
type=int,
|
||||
default=200,
|
||||
help="Number of trees in Random Forest (default: 200)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max_depth",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum tree depth (default: None for unlimited)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--labels",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma-separated sleep stage labels to include (e.g., '0,1,2')",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
def main(
|
||||
embeddings: str,
|
||||
out_dir: str,
|
||||
split_mode: str = "session",
|
||||
test_size: float = 0.2,
|
||||
n_estimators: int = 200,
|
||||
max_depth: int = None,
|
||||
labels: str = None,
|
||||
) -> None:
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main Entry Point
|
||||
# =============================================================================
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main entry point for user classification experiment.
|
||||
|
||||
Workflow:
|
||||
1. Load embeddings and metadata
|
||||
2. Split data (by session or randomly)
|
||||
3. Train Random Forest classifier
|
||||
4. Evaluate and save results (metrics, confusion matrix, feature importance)
|
||||
"""
|
||||
# Parse command line arguments
|
||||
args = parse_arguments()
|
||||
# Validate split_mode argument
|
||||
if split_mode not in ["session", "random"]:
|
||||
raise ValueError(f"Invalid split_mode: {split_mode}. Use 'session' or 'random'.")
|
||||
|
||||
# Create output directory
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
|
||||
# =========================================================================
|
||||
# Step 1: Load Data
|
||||
# =========================================================================
|
||||
print(f"[INFO] Loading embeddings from: {args.embeddings}")
|
||||
dataset = load_embeddings_with_metadata(args.embeddings)
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
print(f"[INFO] Loading embeddings from: {embeddings}")
|
||||
dataset = load_embeddings_with_metadata(embeddings)
|
||||
|
||||
# Filter by sleep stage labels if specified
|
||||
if args.labels:
|
||||
label_list = [l.strip() for l in args.labels.split(",")]
|
||||
if labels:
|
||||
label_list = [l.strip() for l in labels.split(",")]
|
||||
dataset = dataset.filter(lambda x: str(x["label"]) in label_list)
|
||||
print(f"[INFO] Filtered to labels: {label_list}")
|
||||
|
||||
# Extract feature matrix (embeddings as numpy array)
|
||||
X = np.array(dataset["embedding"], dtype=np.float32)
|
||||
|
||||
# Extract labels (user IDs as strings)
|
||||
y = np.array([str(uid) for uid in dataset["user_id"]])
|
||||
|
||||
# Extract session IDs for session-based splitting
|
||||
session_ids = np.array([str(sid) for sid in dataset["session_id"]])
|
||||
|
||||
print(f"[INFO] Loaded {len(y)} samples with {X.shape[1]}-dimensional embeddings")
|
||||
print(f"[INFO] Number of unique users: {len(np.unique(y))}")
|
||||
|
||||
# =========================================================================
|
||||
# Step 2: Split Data
|
||||
# =========================================================================
|
||||
if args.split_mode == "session":
|
||||
# Check if both sessions exist
|
||||
if split_mode == "session":
|
||||
has_session_1 = (session_ids == "1").sum() > 0
|
||||
has_session_2 = (session_ids == "2").sum() > 0
|
||||
|
||||
if has_session_1 and has_session_2:
|
||||
# Perform session-based split
|
||||
X_train, X_test, y_train, y_test, split_desc = split_by_session(
|
||||
X, y, session_ids
|
||||
)
|
||||
else:
|
||||
# Fallback to random split if sessions are missing
|
||||
print("[WARN] Missing session data, falling back to random split")
|
||||
X_train, X_test, y_train, y_test, split_desc = split_random(
|
||||
X, y, args.test_size
|
||||
X, y, test_size
|
||||
)
|
||||
split_desc = "random(fallback)"
|
||||
else:
|
||||
# Random split
|
||||
X_train, X_test, y_train, y_test, split_desc = split_random(
|
||||
X, y, args.test_size
|
||||
X, y, test_size
|
||||
)
|
||||
|
||||
print(f"[INFO] Split: {split_desc}")
|
||||
print(f"[INFO] Train size: {len(y_train)}, Test size: {len(y_test)}")
|
||||
|
||||
# Compute silhouette score on full dataset
|
||||
# Higher score (closer to 1) = better user separation in embedding space
|
||||
silhouette_avg = silhouette_score(X, y)
|
||||
print(f"[INFO] Silhouette Score (user clusters): {silhouette_avg:.4f}")
|
||||
|
||||
# =========================================================================
|
||||
# Step 3: Train Classifier
|
||||
# =========================================================================
|
||||
print("[INFO] Training Random Forest classifier...")
|
||||
print(f"[INFO] Hyperparameters: n_estimators={args.n_estimators}, max_depth={args.max_depth}")
|
||||
print(f"[INFO] Hyperparameters: n_estimators={n_estimators}, max_depth={max_depth}")
|
||||
|
||||
classifier = create_classifier_pipeline(
|
||||
n_estimators=args.n_estimators,
|
||||
max_depth=args.max_depth,
|
||||
n_estimators=n_estimators,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
classifier.fit(X_train, y_train)
|
||||
|
||||
# Generate predictions on test set
|
||||
y_pred = classifier.predict(X_test)
|
||||
|
||||
# Get Random Forest specific metrics
|
||||
rf_model = classifier.named_steps["classifier"]
|
||||
oob_score = rf_model.oob_score_ if hasattr(rf_model, "oob_score_") else None
|
||||
feature_importances = rf_model.feature_importances_
|
||||
|
||||
# =========================================================================
|
||||
# Step 4: Evaluate and Save Results
|
||||
# =========================================================================
|
||||
print("[INFO] Evaluating classifier performance...")
|
||||
accuracy, macro_f1, report, cm, classes = evaluate_classifier(y_test, y_pred)
|
||||
|
||||
# Save metrics report (with OOB score and silhouette)
|
||||
metrics_path = os.path.join(args.out_dir, "user_cls_metrics.txt")
|
||||
metrics_path = os.path.join(out_dir, "user_cls_metrics.txt")
|
||||
save_metrics_report(
|
||||
output_path=metrics_path,
|
||||
split_description=split_desc,
|
||||
@@ -648,18 +478,15 @@ def main():
|
||||
silhouette=silhouette_avg,
|
||||
)
|
||||
|
||||
# Save confusion matrix plot
|
||||
confusion_path = os.path.join(args.out_dir, "user_cls_confusion.pdf")
|
||||
confusion_path = os.path.join(out_dir, "user_cls_confusion.pdf")
|
||||
save_confusion_matrix_plot(cm, classes, confusion_path)
|
||||
|
||||
# Save feature importance analysis
|
||||
importance_plot_path = os.path.join(args.out_dir, "feature_importance.pdf")
|
||||
importance_plot_path = os.path.join(out_dir, "feature_importance.pdf")
|
||||
save_feature_importance_plot(feature_importances, importance_plot_path)
|
||||
|
||||
importance_csv_path = os.path.join(args.out_dir, "feature_importance.csv")
|
||||
importance_csv_path = os.path.join(out_dir, "feature_importance.csv")
|
||||
save_feature_importance_csv(feature_importances, importance_csv_path)
|
||||
|
||||
# Print summary to console
|
||||
print("\n" + "=" * 50)
|
||||
print("RANDOM FOREST CLASSIFICATION RESULTS")
|
||||
print("=" * 50)
|
||||
@@ -677,4 +504,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Fire(main)
|
||||
|
||||
Reference in New Issue
Block a user