Merge branch 'diamond264:user_similarity_analysis' into user_similarity_analysis

This commit is contained in:
SSUM
2026-01-13 18:21:54 +09:00
committed by GitHub
4 changed files with 1039 additions and 0 deletions

6
.gitignore vendored
View File

@@ -1,3 +1,9 @@
# Generated artifacts (plots, reports, tables, Arrow shards, JSON metadata).
# A bare ".pdf" only matches a path literally named ".pdf"; the "*" glob is
# required to ignore files by extension.
# NOTE(review): "*.txt" and "*.json" are broad - confirm no tracked file
# (e.g. requirements.txt) needs a "!" exception.
*.pdf
*.txt
*.csv
*.arrow
*.json
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]

View File

@@ -265,6 +265,8 @@ class Chronos_2_Embedder:
for user_id, session_id, session_path in session_paths:
# Load HuggingFace dataset from disk
dataset = load_from_disk(session_path)
# shuffle dataset
dataset = dataset.shuffle(seed=0)
num_samples = len(dataset)
print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")

View File

@@ -0,0 +1,504 @@
"""
SBERT Time Series Embedding Extraction and Visualization Pipeline
This module provides functionality to:
1. Extract embeddings from time series features using SBERT (Sentence-BERT)
2. Visualize embeddings using dimensionality reduction (t-SNE)
SBERT Overview:
SBERT (Sentence-BERT) is a modification of the BERT model that uses siamese
and triplet network structures to derive semantically meaningful sentence
embeddings. It's designed for semantic similarity tasks and produces
fixed-size dense vector representations of text.
Key Features:
- Semantic understanding: Captures meaning rather than just word presence
- Fixed-size embeddings: Outputs consistent vector dimensions (384 for all-MiniLM-L6-v2)
- Efficient: Optimized for sentence-level tasks
Embedding Strategy:
We convert time series features into textual descriptions, then use SBERT
to generate embeddings. This approach treats feature vectors as "sentences"
where each feature-value pair is a "word" in the description.
Processing Pipeline:
1. Feature extraction from time series (statistical features)
2. Textualization: Convert features to natural language description
3. SBERT encoding: Generate 384-dimensional embeddings
4. Aggregation: Store embeddings with metadata (user_id, session_id, label)
Usage:
# Extract embeddings
python gen_plot.py extract --data_root /path/to/data --out_dir ./embeddings
# Visualize with t-SNE
python gen_plot.py plot --emb_dir ./embeddings --out_dir ./plots
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from glob import glob
from typing import Dict, Any, List, Tuple
import numpy as np
import matplotlib.pyplot as plt
import torch
from datasets import load_from_disk, Dataset
from fire import Fire
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
# =============================================================================
# Constants
# =============================================================================
# EEG channel names used in the Sleep-EDF dataset.
#   Fpz-Cz: frontal-central electrode pair
#   Pz-Oz:  parietal-occipital electrode pair
# NOTE(review): neither constant is referenced elsewhere in the visible code
# (the prompt text in SBERT.textualize_sample spells the channel names out
# inline) - confirm whether they are still needed.
EEG_CHANNEL_1 = "EEG Fpz-Cz"
EEG_CHANNEL_2 = "EEG Pz-Oz"
# =============================================================================
# Embedding Extractor Class
# =============================================================================
class SBERT:
    """
    Extract fixed-dimensional embeddings from time series features using SBERT.

    Each sample's feature dictionary is rendered as a short natural-language
    description, which is then encoded by a Sentence-BERT model.

    Pipeline:
        feature dict -> textualize_sample() -> SentenceTransformer.encode()
        -> one embedding vector per sample

    Model: all-MiniLM-L6-v2 (384-dimensional sentence embeddings).

    Attributes:
        model: SentenceTransformer instance used to encode textualized features
    """

    def __init__(self):
        """Load the pre-trained "all-MiniLM-L6-v2" sentence encoder."""
        self.model = SentenceTransformer("all-MiniLM-L6-v2")

    @staticmethod
    def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
        """
        Discover all user/session directories under data_root.

        Expected structure: data_root/user_id/session_id/

        Args:
            data_root: Root directory containing user/session subfolders

        Returns:
            List of (user_id, session_id, session_path) tuples, sorted by path.
            Non-directory entries at the session level are ignored.
        """
        discovered_paths = []
        # Two wildcard levels: <data_root>/<user_id>/<session_id>
        for session_path in sorted(glob(os.path.join(data_root, "*", "*"))):
            if not os.path.isdir(session_path):
                continue
            session_id = os.path.basename(session_path)
            user_id = os.path.basename(os.path.dirname(session_path))
            discovered_paths.append((user_id, session_id, session_path))
        return discovered_paths

    def textualize_sample(self, sample: Dict[str, Any]) -> str:
        """
        Convert a feature dictionary into a natural language description.

        The text starts with a fixed context sentence and then lists every
        feature as " - name: value", one per line, in the dict's own order.

        Args:
            sample: Dictionary mapping feature names to their values
                Example: {"mean": 0.5, "std": 0.2, "max": 1.0, ...}

        Returns:
            The description with surrounding whitespace stripped
        """
        header = (
            "While sleeping (one of the stages: W, N1, N2, N3, REM), "
            "I have sensor data features measured from two channels: EEG Fpz-Cz and EEG Pz-Oz.\n"
            "The features are as follows: \n"
        )
        feature_lines = "".join(
            f" - {name}: {self.format_feature(value)}\n"
            for name, value in sample.items()
        )
        # str.strip() returns a new string; the stripped result must be the
        # value that is returned (the previous code discarded it).
        return (header + feature_lines).strip()

    def format_feature(self, value: Any) -> str:
        """
        Format a feature value for textual representation.

        Args:
            value: Feature value (float, int, or other type)

        Returns:
            Floats formatted with 2 decimal places; anything else via str()
        """
        if isinstance(value, float):
            return f"{value:.2f}"
        return str(value)

    def compute_embedding(self, batch: Dict[str, Any]) -> np.ndarray:
        """
        Generate embedding vectors for one batch of samples.

        Args:
            batch: Mapping with a "features" entry holding one feature
                dictionary per sample, e.g. the result of dataset[start:end]
                Format: {"features": [{"mean": 0.5, "std": 0.2, ...}, ...]}

        Returns:
            Whatever self.model.encode() returns for the textualized samples;
            for all-MiniLM-L6-v2 an array of shape (batch_size, 384)
        """
        text_samples = [
            self.textualize_sample(sample) for sample in batch["features"]
        ]
        return self.model.encode(text_samples)

    def extract_embeddings(
        self,
        data_root: str,
        batch_size: int = 32,
        label: str = "N1"
    ) -> Dataset:
        """
        Extract embeddings from all sessions under the data root directory.

        Each session dataset is loaded, shuffled with a fixed seed, filtered
        to the requested sleep-stage label, and encoded batch by batch.

        Args:
            data_root: Root directory containing user/session data folders
            batch_size: Number of samples encoded per call to the model
            label: Sleep stage label to keep (e.g., "W", "N1", "N2", "N3", "REM")

        Returns:
            HuggingFace Dataset with columns:
                - user_id, session_id, idx, label (metadata copied from the
                  source rows)
                - embedding (list of floats)
        """
        session_paths = self.discover_session_paths(data_root)
        print(f"[INFO] Discovered {len(session_paths)} sessions")

        all_embeddings = []
        all_user_ids = []
        all_session_ids = []
        all_idxs = []
        all_labels = []

        for user_id, session_id, session_path in session_paths:
            dataset = load_from_disk(session_path)
            # Fixed seed keeps the sample order reproducible across runs
            dataset = dataset.shuffle(seed=0)
            dataset = dataset.filter(lambda row: row["label"] == label)
            num_samples = len(dataset)
            print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")

            # Process in batches to bound memory use
            for batch_start in range(0, num_samples, batch_size):
                batch_end = min(batch_start + batch_size, num_samples)
                batch = dataset[batch_start:batch_end]
                embeddings = self.compute_embedding(batch)

                all_embeddings.extend(embeddings.tolist())
                all_user_ids.extend(str(value) for value in batch["user_id"])
                all_session_ids.extend(str(value) for value in batch["session_id"])
                all_idxs.extend(int(value) for value in batch["idx"])
                all_labels.extend(str(value) for value in batch["label"])

        result_dataset = Dataset.from_dict({
            "user_id": all_user_ids,
            "session_id": all_session_ids,
            "idx": all_idxs,
            "label": all_labels,
            "embedding": all_embeddings,
        })
        print(f"[INFO] Total samples: {len(result_dataset)}")
        return result_dataset

    def save_embeddings(
        self,
        dataset: Dataset,
        output_dir: str
    ) -> None:
        """
        Save embeddings dataset to disk in HuggingFace format.

        Args:
            dataset: HuggingFace Dataset containing embeddings and metadata
            output_dir: Directory path to save the dataset (created if missing)
        """
        os.makedirs(output_dir, exist_ok=True)
        dataset.save_to_disk(output_dir)
        print(f"[DONE] Saved embeddings dataset: {output_dir}")
        # dataset[0] raises IndexError when no sample matched the label filter
        if len(dataset) > 0:
            print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
        else:
            print("[WARN] Saved dataset is empty (no samples matched the requested label)")

    @staticmethod
    def load_embeddings(embedding_dir: str) -> Dataset:
        """
        Load saved embeddings dataset from disk.

        Args:
            embedding_dir: Directory path containing the saved HuggingFace dataset

        Returns:
            HuggingFace Dataset with embeddings and metadata
        """
        dataset = load_from_disk(embedding_dir)
        print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
        return dataset
# =============================================================================
# Visualization Functions
# =============================================================================
def reduce_to_2d_tsne(
    embeddings: np.ndarray,
    perplexity: float = 30.0
) -> np.ndarray:
    """
    Reduce high-dimensional embeddings to 2D using t-SNE.

    scikit-learn requires perplexity to be strictly smaller than the number
    of samples. Small, heavily filtered datasets would otherwise fail inside
    TSNE, so the value is clamped to num_samples - 1 with a warning.

    Args:
        embeddings: Array of shape (num_samples, embedding_dim)
        perplexity: Requested t-SNE perplexity (typically 5-50)

    Returns:
        2D coordinates of shape (num_samples, 2)

    Raises:
        ValueError: If fewer than 2 samples are provided
    """
    num_samples = len(embeddings)
    if num_samples < 2:
        raise ValueError(
            f"t-SNE needs at least 2 samples, got {num_samples}"
        )

    if perplexity >= num_samples:
        clamped = float(num_samples - 1)
        print(
            f"[WARN] perplexity={perplexity} must be below the number of samples "
            f"({num_samples}); using {clamped} instead"
        )
        perplexity = clamped

    print(f"[INFO] Running t-SNE with perplexity={perplexity}...")
    tsne = TSNE(
        n_components=2,
        random_state=0,        # Fixed seed for reproducible layouts
        perplexity=perplexity,
        max_iter=1000,
        init='random',
        learning_rate='auto',
    )
    return tsne.fit_transform(embeddings)
def create_scatter_plot(
    coordinates: np.ndarray,
    labels: np.ndarray,
    title: str,
    output_path: str
) -> None:
    """
    Draw a 2D scatter plot with one color per category and save it as PDF.

    Args:
        coordinates: 2D array of shape (num_points, 2)
        labels: Category label for each point; compared element-wise with
            `==`, so it is expected to be a NumPy array
        title: Plot title
        output_path: Destination file path (written in PDF format)
    """
    categories = sorted(set(labels))

    # tab10 covers up to 10 categories; beyond that switch to tab20
    palette = plt.cm.tab20 if len(categories) > 10 else plt.cm.tab10

    fig, ax = plt.subplots(figsize=(10, 8))

    # One scatter call per category so each gets its own legend entry
    for position, category in enumerate(categories):
        selected = labels == category
        xs = coordinates[selected, 0]
        ys = coordinates[selected, 1]
        point_color = palette(position % 20)
        ax.scatter(
            xs,
            ys,
            c=[point_color],
            s=15,
            label=category,
            alpha=0.7,
        )

    ax.set_title(title)
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")
    ax.legend(loc='best', markerscale=2)

    plt.tight_layout()
    plt.savefig(output_path, format='pdf', bbox_inches='tight')
    plt.close()
    print(f"[DONE] Saved plot: {output_path}")
# =============================================================================
# Command Line Interface
# =============================================================================
class CLI:
    """
    Command-line interface for SBERT embedding extraction and visualization.

    Commands:
        - extract: Generate embeddings from time series features
        - plot: Visualize embeddings with t-SNE
    """

    @staticmethod
    def _split_csv(value: Any) -> List[str]:
        """
        Normalize a comma-separated CLI value into a list of stripped strings.

        Fire parses command-line values as Python literals when it can, so
        `--users 1,2` may arrive as the tuple (1, 2) and `--users 5` as the
        int 5 instead of a string. All of these shapes are accepted here.
        Empty items are dropped; None yields an empty list.
        """
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            items = value
        else:
            items = str(value).split(",")
        cleaned = (str(item).strip() for item in items)
        return [item for item in cleaned if item]

    def extract(
        self,
        data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF",
        out_dir: str = "./embeddings/W",
        batch_size: int = 32,
        label: str = "W"
    ) -> None:
        """
        Extract embeddings from time series data.

        Args:
            data_root: Root directory containing user/session data folders
            out_dir: Output directory for HuggingFace dataset
            batch_size: Batch size for inference (default: 32)
            label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM')
        """
        embedder = SBERT()
        dataset = embedder.extract_embeddings(data_root, batch_size, label)
        embedder.save_embeddings(dataset, out_dir)

    def plot(
        self,
        emb_dir: str = "./embeddings/W",
        out_dir: str = "./plots/W",
        perplexity: float = 10.0,
        users: str = None,
        num_users: int = 0,
        labels: str = None,
    ) -> None:
        """
        Visualize embeddings with t-SNE.

        Writes two PDFs into out_dir: one colored by sleep stage
        (tsne_by_label.pdf) and one colored by user (tsne_by_user.pdf).

        Args:
            emb_dir: Directory containing the HuggingFace embeddings dataset
            out_dir: Output directory for visualization plots (PDF)
            perplexity: t-SNE perplexity parameter (default: 10.0)
            users: Comma-separated user IDs to include (e.g., '00,01,02').
                Takes precedence over num_users. Zero-padded IDs must reach
                this method as strings to match; values Fire has already
                converted to numbers lose their padding.
            num_users: Include only first N users, 0 = all (default: 0)
            labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')

        Raises:
            ValueError: If no samples remain after filtering
        """
        os.makedirs(out_dir, exist_ok=True)

        dataset = SBERT.load_embeddings(emb_dir)

        # User filtering: an explicit list wins over "first N users"
        user_list = self._split_csv(users)
        if user_list:
            allowed_users = set(user_list)
            dataset = dataset.filter(lambda row: row["user_id"] in allowed_users)
            print(f"[INFO] Filtered to users: {user_list}")
        elif num_users > 0:
            all_users = sorted(set(dataset["user_id"]))
            selected_users = all_users[:num_users]
            allowed_users = set(selected_users)
            dataset = dataset.filter(lambda row: row["user_id"] in allowed_users)
            print(f"[INFO] Selected first {num_users} users: {selected_users}")

        # Sleep stage filtering
        label_list = self._split_csv(labels)
        if label_list:
            allowed_labels = set(label_list)
            dataset = dataset.filter(lambda row: row["label"] in allowed_labels)
            print(f"[INFO] Filtered to labels: {label_list}")

        print(f"[INFO] Total samples: {len(dataset)}")
        if len(dataset) == 0:
            # Fail here with a clear message instead of deep inside t-SNE
            raise ValueError(
                "No samples left to plot after filtering; "
                "check the --users / --num_users / --labels values"
            )

        embeddings = np.array(dataset["embedding"])
        coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)

        create_scatter_plot(
            coordinates_2d,
            np.array(dataset["label"]),
            "t-SNE Visualization (Colored by Sleep Stage)",
            os.path.join(out_dir, "tsne_by_label.pdf")
        )
        create_scatter_plot(
            coordinates_2d,
            np.array(dataset["user_id"]),
            "t-SNE Visualization (Colored by User ID)",
            os.path.join(out_dir, "tsne_by_user.pdf")
        )
# Script entry point: expose the CLI class's commands (extract / plot) via Fire.
if __name__ == "__main__":
    Fire(CLI)

View File

@@ -0,0 +1,527 @@
"""
User Classification from SBERT Embeddings
This module evaluates how well SBERT embeddings capture user-specific patterns
by training a Random Forest classifier to predict user identity from embeddings.
Motivation:
If embeddings contain user-distinguishing information, a classifier should be able
to predict which user a time series belongs to. High accuracy suggests that the
embeddings capture individual characteristics.
Experimental Design:
- Task: Multi-class classification
- Model: Random Forest with ensemble of decision trees
- Captures non-linear relationships in embedding space
- Provides feature importance scores for interpretability
- Robust to overfitting through bagging and random feature selection
- Split Strategy:
1. Session-based: Train on session 1, test on session 2
2. Random: Standard train/test split
Session-based split is more challenging and realistic because:
- Tests whether user patterns are stable across different recording sessions
- Avoids data leakage from same-session samples in train and test
Random Forest Advantages over Logistic Regression:
- Handles non-linear decision boundaries
- Feature importance reveals which embedding dimensions matter most
- No assumption about data distribution
- Naturally handles multi-class classification
Output:
- Classification metrics
- Confusion matrix visualization
Usage:
python simple_user_classifier.py \\
--embeddings ./embeddings \\
--out_dir ./results
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from typing import Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_from_disk, Dataset
from fire import Fire
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
accuracy_score,
f1_score,
classification_report,
confusion_matrix,
silhouette_score,
)
from sklearn.model_selection import train_test_split
# =============================================================================
# Data Loading
# =============================================================================
def load_embeddings_with_metadata(embedding_path: str) -> Dataset:
    """
    Load a saved HuggingFace dataset of embeddings from a directory.

    Args:
        embedding_path: Directory previously written by Dataset.save_to_disk

    Returns:
        The dataset returned by load_from_disk

    Raises:
        FileNotFoundError: If embedding_path is not an existing directory
    """
    if os.path.isdir(embedding_path):
        return load_from_disk(embedding_path)

    raise FileNotFoundError(
        f"Dataset directory not found: {embedding_path}. "
        "Ensure this is a valid HuggingFace dataset directory."
    )
# =============================================================================
# Data Splitting
# =============================================================================
def split_by_session(
    features: np.ndarray,
    labels: np.ndarray,
    session_ids: np.ndarray,
    train_session: str = "1",
    test_session: str = "2",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
    """
    Split samples into train/test sets according to their recording session.

    Samples whose session ID equals train_session go to the training set,
    those equal to test_session go to the test set; all others are dropped.

    Args:
        features: Feature matrix of shape (num_samples, num_features)
        labels: Label array of shape (num_samples,)
        session_ids: Session identifier for each sample
        train_session: Session ID to use for training (default: "1")
        test_session: Session ID to use for testing (default: "2")

    Returns:
        Tuple of (X_train, X_test, y_train, y_test, split_description)
    """
    in_train = session_ids == train_session
    in_test = session_ids == test_session
    description = f"session({train_session}->train, {test_session}->test)"
    return (
        features[in_train],
        features[in_test],
        labels[in_train],
        labels[in_test],
        description,
    )
def split_random(
    features: np.ndarray,
    labels: np.ndarray,
    test_size: float = 0.2,
    random_state: int = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
    """
    Randomly split samples into train/test sets, stratified by label.

    Args:
        features: Feature matrix of shape (num_samples, num_features)
        labels: Label array of shape (num_samples,); also used for stratification
        test_size: Fraction of data to use for testing (default: 0.2)
        random_state: Random seed for reproducibility

    Returns:
        Tuple of (X_train, X_test, y_train, y_test, split_description),
        where split_description is always "random"
    """
    split_kwargs = {
        "test_size": test_size,
        "random_state": random_state,
        "stratify": labels,  # keep class proportions equal in both parts
    }
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, **split_kwargs
    )
    return X_train, X_test, y_train, y_test, "random"
# =============================================================================
# Model Training and Evaluation
# =============================================================================
def create_classifier_pipeline(
    n_estimators: int = 200,
    max_depth: int = None,
    min_samples_split: int = 2,
    min_samples_leaf: int = 1,
    random_state: int = 0,
) -> Pipeline:
    """
    Build the scikit-learn pipeline used for user classification.

    Steps:
        1. "scaler": StandardScaler with centering and unit-variance scaling
        2. "classifier": RandomForestClassifier with balanced class weights,
           out-of-bag scoring enabled, and all CPU cores (n_jobs=-1)

    Args:
        n_estimators: Number of trees (default: 200)
        max_depth: Maximum depth of trees (default: None, no limit)
        min_samples_split: Min samples for splitting (default: 2)
        min_samples_leaf: Min samples at leaf (default: 1)
        random_state: Random seed passed to the forest

    Returns:
        Unfitted sklearn Pipeline
    """
    # Step 1: z-score normalization of every embedding dimension
    scaler = StandardScaler(with_mean=True, with_std=True)

    # Step 2: the forest itself
    forest = RandomForestClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        n_jobs=-1,
        random_state=random_state,
        class_weight="balanced",
        oob_score=True,  # main() reads oob_score_ from the fitted forest
    )

    steps = [("scaler", scaler), ("classifier", forest)]
    return Pipeline(steps)
def evaluate_classifier(
    y_true: np.ndarray,
    y_pred: np.ndarray,
) -> Tuple[float, float, str, np.ndarray, list]:
    """
    Compute accuracy, macro F1, a per-class report and the confusion matrix.

    Args:
        y_true: Ground truth labels
        y_pred: Predicted labels

    Returns:
        Tuple of (accuracy, f1, classification_report, confusion_matrix,
        class_labels), where class_labels is the sorted union of the labels
        seen in y_true and y_pred and fixes the confusion matrix row/column
        order
    """
    # Classes are taken from both arrays so predictions of unseen users
    # still get a row/column in the matrix
    class_labels = sorted(set(y_true) | set(y_pred))
    matrix = confusion_matrix(y_true, y_pred, labels=class_labels)

    per_class_report = classification_report(y_true, y_pred, digits=4)
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    overall_accuracy = accuracy_score(y_true, y_pred)

    return overall_accuracy, macro_f1, per_class_report, matrix, class_labels
# =============================================================================
# Visualization
# =============================================================================
def save_confusion_matrix_plot(
    confusion_mat: np.ndarray,
    class_labels: list,
    output_path: str,
) -> None:
    """
    Create and save a confusion matrix heatmap as a PDF.

    Rows are true classes and columns are predicted classes, in the order
    given by class_labels. Both axes are labelled with the class names
    (previously class_labels was accepted but never drawn).

    Args:
        confusion_mat: Square matrix of shape (num_classes, num_classes)
        class_labels: Class names, one per matrix row/column
        output_path: File path to save the plot (written in PDF format)
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    im = ax.imshow(confusion_mat, aspect="auto", cmap="Blues")
    plt.colorbar(im, ax=ax)

    # Put the class names on both axes so cells can be attributed to users
    tick_positions = np.arange(len(class_labels))
    tick_names = [str(name) for name in class_labels]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_names, rotation=90, fontsize=8)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_names, fontsize=8)

    ax.set_title("Confusion Matrix (User Classification)")
    ax.set_xlabel("Predicted User")
    ax.set_ylabel("True User")

    plt.tight_layout()
    plt.savefig(output_path, format="pdf", bbox_inches="tight")
    plt.close(fig)
    print(f"[DONE] Saved confusion matrix: {output_path}")
def save_metrics_report(
    output_path: str,
    split_description: str,
    train_size: int,
    test_size: int,
    accuracy: float,
    macro_f1: float,
    classification_report_str: str,
    oob_score: float = None,
    silhouette: float = None,
) -> None:
    """
    Write the classification summary and per-class report to a text file.

    The file holds one "name : value" line per summary item, a blank line,
    and then the classification report verbatim. The silhouette and OOB
    lines are only written when the corresponding value is not None.

    Args:
        output_path: File path to save the metrics report
        split_description: Description of train/test split strategy
        train_size: Number of training samples
        test_size: Number of test samples
        accuracy: Overall classification accuracy
        macro_f1: Macro-averaged F1 score
        classification_report_str: Detailed per-class report string
        oob_score: Out-of-bag score from Random Forest (optional)
        silhouette: Silhouette score of the embeddings (optional)
    """
    with open(output_path, "w", encoding="utf-8") as report_file:
        chunks = [
            f"split_used : {split_description}\n",
            f"train_size : {train_size}\n",
            f"test_size : {test_size}\n",
        ]
        if silhouette is not None:
            chunks.append(f"silhouette : {silhouette:.4f}\n")
        chunks.append(f"accuracy : {accuracy:.4f}\n")
        chunks.append(f"macro_f1 : {macro_f1:.4f}\n")
        if oob_score is not None:
            chunks.append(f"oob_score : {oob_score:.4f}\n")
        # Blank separator line, then the detailed report as-is
        chunks.append("\n")
        chunks.append(classification_report_str)
        report_file.write("".join(chunks))

    print(f"[DONE] Saved metrics report: {output_path}")
def save_feature_importance_plot(
    feature_importances: np.ndarray,
    output_path: str,
    top_k: int = 50,
) -> None:
    """
    Save a horizontal bar chart of the top_k largest feature importances.

    Bars are ordered with the most important dimension at the top and are
    labelled "emb_<index>" after the embedding dimension they belong to.

    Args:
        feature_importances: Array of feature importance scores from Random Forest
        output_path: File path to save the plot (written in PDF format)
        top_k: Number of top features to display (default: 50)
    """
    # Indices sorted by descending importance, truncated to the first top_k
    ranked_indices = np.argsort(feature_importances)[::-1]
    top_indices = ranked_indices[:top_k]
    bar_lengths = feature_importances[top_indices]
    bar_names = [f"emb_{i}" for i in top_indices]
    bar_slots = np.arange(len(top_indices))

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.barh(bar_slots, bar_lengths, color="steelblue", alpha=0.8)

    ax.set_yticks(bar_slots)
    ax.set_yticklabels(bar_names, fontsize=8)
    ax.invert_yaxis()  # slot 0 (largest importance) is drawn at the top
    ax.set_xlabel("Feature Importance (Mean Decrease in Impurity)")
    ax.set_ylabel("Embedding Dimension")
    ax.set_title(f"Top {top_k} Most Important Embedding Dimensions")

    plt.tight_layout()
    plt.savefig(output_path, format="pdf", bbox_inches="tight")
    plt.close()
    print(f"[DONE] Saved feature importance plot: {output_path}")
def save_feature_importance_csv(
    feature_importances: np.ndarray,
    output_path: str,
) -> None:
    """
    Write all feature importances to a CSV file, ranked from most important.

    The CSV has the columns rank, feature, importance; rank starts at 1 and
    feature is "emb_<index>" for the original position in the input array.

    Args:
        feature_importances: Array of feature importance scores from Random Forest
        output_path: File path to save the CSV file
    """
    feature_names = [f"emb_{i}" for i in range(len(feature_importances))]
    ranked = (
        pd.DataFrame({"feature": feature_names, "importance": feature_importances})
        .sort_values("importance", ascending=False)
        .reset_index(drop=True)
    )
    # Rank goes in as the first column: rank, feature, importance
    ranked.insert(0, "rank", np.arange(1, len(ranked) + 1))
    ranked.to_csv(output_path, index=False)
    print(f"[DONE] Saved feature importance CSV: {output_path}")
# =============================================================================
# Command Line Interface
# =============================================================================
def main(
    embeddings: str = "./embeddings/REM",
    out_dir: str = "./results/REM",
    test_size: float = 0.2,
    n_estimators: int = 200,
    max_depth: int = None,
) -> None:
    """
    Train and evaluate a Random Forest that predicts user ID from embeddings.

    Steps:
        1. Load the embeddings dataset (columns: embedding, user_id, session_id)
        2. Split by session ("1" -> train, "2" -> test) when both sessions
           exist, otherwise fall back to a stratified random split
        3. Compute the silhouette score of the user clusters when it is defined
        4. Fit the classifier pipeline and evaluate it on the test split
        5. Save metrics, confusion matrix and feature importances to out_dir

    Args:
        embeddings: Directory containing the HuggingFace embeddings dataset
        out_dir: Output directory for reports and plots (created if missing)
        test_size: Test fraction used only by the random fallback split
        n_estimators: Number of trees in the Random Forest
        max_depth: Maximum tree depth (None = no limit)

    Raises:
        ValueError: If the loaded dataset contains no samples
    """
    os.makedirs(out_dir, exist_ok=True)

    print(f"[INFO] Loading embeddings from: {embeddings}")
    dataset = load_embeddings_with_metadata(embeddings)
    if len(dataset) == 0:
        # Without this check the failure surfaces as an IndexError on X.shape[1]
        raise ValueError(f"Embeddings dataset is empty: {embeddings}")

    X = np.array(dataset["embedding"], dtype=np.float32)
    y = np.array([str(uid) for uid in dataset["user_id"]])
    session_ids = np.array([str(sid) for sid in dataset["session_id"]])
    num_users = len(np.unique(y))
    print(f"[INFO] Loaded {len(y)} samples with {X.shape[1]}-dimensional embeddings")
    print(f"[INFO] Number of unique users: {num_users}")

    has_session_1 = (session_ids == "1").any()
    has_session_2 = (session_ids == "2").any()
    if has_session_1 and has_session_2:
        X_train, X_test, y_train, y_test, split_desc = split_by_session(
            X, y, session_ids
        )
    else:
        print("[WARN] Missing session data, falling back to random split")
        X_train, X_test, y_train, y_test, split_desc = split_random(
            X, y, test_size
        )
        split_desc = "random(fallback)"
    print(f"[INFO] Split: {split_desc}")
    print(f"[INFO] Train size: {len(y_train)}, Test size: {len(y_test)}")

    # silhouette_score only accepts 2 <= n_labels <= n_samples - 1; outside
    # that range it raises, so report the score as unavailable instead.
    if 2 <= num_users <= len(y) - 1:
        silhouette_avg = silhouette_score(X, y)
        print(f"[INFO] Silhouette Score (user clusters): {silhouette_avg:.4f}")
    else:
        silhouette_avg = None
        print(
            f"[WARN] Silhouette score is undefined for {num_users} user(s) "
            f"over {len(y)} sample(s); skipping"
        )

    print("[INFO] Training Random Forest classifier...")
    print(f"[INFO] Hyperparameters: n_estimators={n_estimators}, max_depth={max_depth}")
    classifier = create_classifier_pipeline(
        n_estimators=n_estimators,
        max_depth=max_depth,
    )
    classifier.fit(X_train, y_train)
    y_pred = classifier.predict(X_test)

    rf_model = classifier.named_steps["classifier"]
    oob_score = getattr(rf_model, "oob_score_", None)
    feature_importances = rf_model.feature_importances_

    print("[INFO] Evaluating classifier performance...")
    accuracy, macro_f1, report, cm, classes = evaluate_classifier(y_test, y_pred)

    metrics_path = os.path.join(out_dir, "user_cls_metrics.txt")
    save_metrics_report(
        output_path=metrics_path,
        split_description=split_desc,
        train_size=len(y_train),
        test_size=len(y_test),
        accuracy=accuracy,
        macro_f1=macro_f1,
        classification_report_str=report,
        oob_score=oob_score,
        silhouette=silhouette_avg,
    )
    confusion_path = os.path.join(out_dir, "user_cls_confusion.pdf")
    save_confusion_matrix_plot(cm, classes, confusion_path)
    importance_plot_path = os.path.join(out_dir, "feature_importance.pdf")
    save_feature_importance_plot(feature_importances, importance_plot_path)
    importance_csv_path = os.path.join(out_dir, "feature_importance.csv")
    save_feature_importance_csv(feature_importances, importance_csv_path)

    print("\n" + "=" * 50)
    print("RANDOM FOREST CLASSIFICATION RESULTS")
    print("=" * 50)
    print(f"Split Strategy : {split_desc}")
    if silhouette_avg is not None:
        print(f"Silhouette : {silhouette_avg:.4f}")
    print(f"Accuracy : {accuracy:.4f}")
    print(f"Macro F1 : {macro_f1:.4f}")
    if oob_score is not None:
        print(f"OOB Score : {oob_score:.4f}")
    print("Top 5 Important Features:")
    top5_idx = np.argsort(feature_importances)[::-1][:5]
    for rank, idx in enumerate(top5_idx, 1):
        print(f" {rank}. emb_{idx}: {feature_importances[idx]:.4f}")
    print("=" * 50)
# Script entry point: Fire maps command-line flags onto main()'s parameters.
if __name__ == "__main__":
    Fire(main)