Merge branch 'diamond264:user_similarity_analysis' into user_similarity_analysis
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -1,3 +1,9 @@
|
||||
.pdf
|
||||
.txt
|
||||
.csv
|
||||
.arrow
|
||||
.json
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[codz]
|
||||
|
||||
@@ -265,6 +265,8 @@ class Chronos_2_Embedder:
|
||||
for user_id, session_id, session_path in session_paths:
|
||||
# Load HuggingFace dataset from disk
|
||||
dataset = load_from_disk(session_path)
|
||||
# shuffle dataset
|
||||
dataset = dataset.shuffle(seed=0)
|
||||
num_samples = len(dataset)
|
||||
print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")
|
||||
|
||||
|
||||
@@ -0,0 +1,504 @@
|
||||
"""
|
||||
SBERT Time Series Embedding Extraction and Visualization Pipeline
|
||||
|
||||
This module provides functionality to:
|
||||
1. Extract embeddings from time series features using SBERT (Sentence-BERT)
|
||||
2. Visualize embeddings using dimensionality reduction (t-SNE)
|
||||
|
||||
SBERT Overview:
|
||||
SBERT (Sentence-BERT) is a modification of the BERT model that uses siamese
|
||||
and triplet network structures to derive semantically meaningful sentence
|
||||
embeddings. It's designed for semantic similarity tasks and produces
|
||||
fixed-size dense vector representations of text.
|
||||
|
||||
Key Features:
|
||||
- Semantic understanding: Captures meaning rather than just word presence
|
||||
- Fixed-size embeddings: Outputs consistent vector dimensions (384 for all-MiniLM-L6-v2)
|
||||
- Efficient: Optimized for sentence-level tasks
|
||||
|
||||
Embedding Strategy:
|
||||
We convert time series features into textual descriptions, then use SBERT
|
||||
to generate embeddings. This approach treats feature vectors as "sentences"
|
||||
where each feature-value pair is a "word" in the description.
|
||||
|
||||
Processing Pipeline:
|
||||
1. Feature extraction from time series (statistical features)
|
||||
2. Textualization: Convert features to natural language description
|
||||
3. SBERT encoding: Generate 384-dimensional embeddings
|
||||
4. Aggregation: Store embeddings with metadata (user_id, session_id, label)
|
||||
|
||||
Usage:
|
||||
# Extract embeddings
|
||||
python gen_plot.py extract --data_root /path/to/data --out_dir ./embeddings
|
||||
|
||||
# Visualize with t-SNE
|
||||
python gen_plot.py plot --emb_dir ./embeddings --out_dir ./plots
|
||||
|
||||
Author: Sumin Im (NMSL Undergraduate Researcher)
|
||||
Date: 2026-01-09
|
||||
"""
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
from typing import Dict, Any, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from datasets import load_from_disk, Dataset
|
||||
from fire import Fire
|
||||
from sklearn.manifold import TSNE
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# =============================================================================
|
||||
# Constants
|
||||
# =============================================================================
|
||||
|
||||
# Names of the two EEG derivations recorded in the Sleep-EDF dataset:
#   - Fpz-Cz: frontal-central electrode pair
#   - Pz-Oz:  parietal-occipital electrode pair
EEG_CHANNEL_1, EEG_CHANNEL_2 = "EEG Fpz-Cz", "EEG Pz-Oz"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Embedding Extractor Class
|
||||
# =============================================================================
|
||||
|
||||
class SBERT:
    """
    Extract fixed-size embeddings from time series features using SBERT.

    Each sample's statistical features are rendered as a short natural
    language description, which a Sentence-BERT model then encodes into a
    dense vector.

    Pipeline:
        Time Series -> Feature Extraction -> Textualization -> SBERT -> Embedding

    Model: all-MiniLM-L6-v2 (lightweight sentence-embedding model with
    384-dimensional output).

    Attributes:
        model: SentenceTransformer instance used to encode textualized features.
    """

    def __init__(self):
        """Load the pre-trained "all-MiniLM-L6-v2" sentence-embedding model."""
        self.model = SentenceTransformer("all-MiniLM-L6-v2")

    @staticmethod
    def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
        """
        Discover all user/session directories under ``data_root``.

        Expected layout: ``data_root/<user_id>/<session_id>/``. Non-directory
        entries at the session level are ignored.

        Args:
            data_root: Root directory containing user/session subfolders.

        Returns:
            List of (user_id, session_id, session_path) tuples, sorted by path.
        """
        discovered_paths = []

        # Two wildcard levels: <user_id>/<session_id>
        for session_path in sorted(glob(os.path.join(data_root, "*", "*"))):
            if not os.path.isdir(session_path):
                continue

            session_id = os.path.basename(session_path)
            user_id = os.path.basename(os.path.dirname(session_path))
            discovered_paths.append((user_id, session_id, session_path))

        return discovered_paths

    def textualize_sample(self, sample: Dict[str, Any]) -> str:
        """
        Convert a feature dictionary into a natural language description.

        SBERT expects text input, so every feature is listed as a
        ``" - name: value"`` line below a fixed header describing the
        recording context.

        Args:
            sample: Mapping of feature names to values,
                e.g. {"mean": 0.5, "std": 0.2, "max": 1.0}.

        Returns:
            The description with leading/trailing whitespace removed.
        """
        header = (
            "While sleeping (one of the stages: W, N1, N2, N3, REM), "
            "I have sensor data features measured from two channels: EEG Fpz-Cz and EEG Pz-Oz.\n"
            "The features are as follows: \n"
        )
        feature_lines = "".join(
            f" - {name}: {self.format_feature(value)}\n"
            for name, value in sample.items()
        )
        # str.strip() returns a new string; the result must be returned
        # (the previous code called it and discarded the result).
        return (header + feature_lines).strip()

    def format_feature(self, value: Any) -> str:
        """
        Format a feature value for textual representation.

        Floats (Python or NumPy floating scalars) are rendered with two
        decimals; every other type is converted with ``str()``.

        Args:
            value: Feature value of any type.

        Returns:
            Formatted string representation.
        """
        # np.float32 and friends are not subclasses of float, so check both.
        if isinstance(value, (float, np.floating)):
            return f"{value:.2f}"
        return str(value)

    def compute_embedding(self, batch: Dict[str, Any]) -> np.ndarray:
        """
        Generate embedding vectors for one batch of feature dictionaries.

        Args:
            batch: HuggingFace dataset batch obtained by slicing
                (``dataset[start:end]``); its "features" column must hold
                one feature dictionary per sample.

        Returns:
            Whatever ``self.model.encode`` returns for the textualized
            samples; documented by this module as an array of shape
            (batch_size, 384) for all-MiniLM-L6-v2.
        """
        text_samples = [
            self.textualize_sample(sample) for sample in batch["features"]
        ]
        return self.model.encode(text_samples)

    def extract_embeddings(
        self,
        data_root: str,
        batch_size: int = 32,
        label: str = "N1"
    ) -> Dataset:
        """
        Extract embeddings from all sessions under the data root directory.

        Every session dataset is shuffled (seed=0), filtered to the requested
        sleep stage, and embedded batch by batch. Results from all sessions
        are concatenated together with their metadata.

        Args:
            data_root: Root directory containing user/session data folders.
            batch_size: Number of samples encoded per call (larger is faster
                but uses more memory).
            label: Sleep stage label to keep (e.g. "W", "N1", "N2", "N3", "REM").

        Returns:
            HuggingFace Dataset with columns user_id, session_id, idx, label
            (metadata) and embedding (one vector per sample).
        """
        session_paths = self.discover_session_paths(data_root)
        print(f"[INFO] Discovered {len(session_paths)} sessions")

        all_embeddings = []
        all_user_ids = []
        all_session_ids = []
        all_idxs = []
        all_labels = []

        for user_id, session_id, session_path in session_paths:
            dataset = load_from_disk(session_path)
            # Shuffle first, then filter, so sample order matches earlier runs.
            dataset = dataset.shuffle(seed=0)
            dataset = dataset.filter(lambda x: x["label"] == label)
            num_samples = len(dataset)
            print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")

            # Process in batches to bound memory use.
            for batch_start in range(0, num_samples, batch_size):
                batch_end = min(batch_start + batch_size, num_samples)
                batch = dataset[batch_start:batch_end]

                embeddings = self.compute_embedding(batch)
                count = embeddings.shape[0]

                all_embeddings.extend(embeddings.tolist())
                all_user_ids.extend(str(v) for v in batch["user_id"][:count])
                all_session_ids.extend(str(v) for v in batch["session_id"][:count])
                all_idxs.extend(int(v) for v in batch["idx"][:count])
                all_labels.extend(str(v) for v in batch["label"][:count])

        result_dataset = Dataset.from_dict({
            "user_id": all_user_ids,
            "session_id": all_session_ids,
            "idx": all_idxs,
            "label": all_labels,
            "embedding": all_embeddings,
        })

        print(f"[INFO] Total samples: {len(result_dataset)}")
        return result_dataset

    def save_embeddings(
        self,
        dataset: Dataset,
        output_dir: str
    ) -> None:
        """
        Save an embeddings dataset to disk in HuggingFace format.

        Args:
            dataset: HuggingFace Dataset containing embeddings and metadata.
            output_dir: Directory to save the dataset into (created if missing).
        """
        os.makedirs(output_dir, exist_ok=True)
        dataset.save_to_disk(output_dir)

        num_samples = len(dataset)
        # An empty dataset (e.g. no sample matched the label filter) has no
        # first row; report dimension 0 instead of raising IndexError.
        embedding_dim = len(dataset[0]["embedding"]) if num_samples else 0

        print(f"[DONE] Saved embeddings dataset: {output_dir}")
        print(f"[DONE] Total samples: {num_samples}, Embedding dim: {embedding_dim}")

    @staticmethod
    def load_embeddings(embedding_dir: str) -> Dataset:
        """
        Load a saved embeddings dataset from disk.

        Args:
            embedding_dir: Directory containing the saved HuggingFace dataset.

        Returns:
            HuggingFace Dataset with embeddings and metadata.
        """
        dataset = load_from_disk(embedding_dir)
        print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
        return dataset
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Visualization Functions
|
||||
# =============================================================================
|
||||
|
||||
def reduce_to_2d_tsne(
    embeddings: np.ndarray,
    perplexity: float = 30.0
) -> np.ndarray:
    """
    Reduce high-dimensional embeddings to 2D using t-SNE.

    t-SNE (t-distributed Stochastic Neighbor Embedding) is a nonlinear
    dimensionality reduction technique that preserves local structure.

    Args:
        embeddings: Array of shape (num_samples, embedding_dim).
        perplexity: t-SNE perplexity parameter (typically 5-50). It must be
            smaller than the number of samples; larger values are clamped to
            ``num_samples - 1`` so that small filtered subsets can still be
            plotted.

    Returns:
        2D coordinates of shape (num_samples, 2).

    Raises:
        ValueError: If fewer than two samples are provided.
    """
    num_samples = len(embeddings)
    if num_samples < 2:
        raise ValueError(
            f"t-SNE needs at least 2 samples, got {num_samples}. "
            "Check that the user/label filters did not remove all data."
        )

    # scikit-learn rejects perplexity >= n_samples; clamp instead of crashing.
    if perplexity >= num_samples:
        clamped = float(num_samples - 1)
        print(
            f"[WARN] perplexity={perplexity} is not smaller than the number of "
            f"samples ({num_samples}); using {clamped} instead"
        )
        perplexity = clamped

    print(f"[INFO] Running t-SNE with perplexity={perplexity}...")

    tsne = TSNE(
        n_components=2,
        random_state=0,        # For reproducibility
        perplexity=perplexity,
        max_iter=1000,
        init='random',
        learning_rate='auto',  # Let sklearn choose the learning rate
    )

    return tsne.fit_transform(embeddings)
|
||||
|
||||
|
||||
def create_scatter_plot(
    coordinates: np.ndarray,
    labels: np.ndarray,
    title: str,
    output_path: str
) -> None:
    """
    Draw a 2D scatter plot coloured by category and save it as a PDF.

    Args:
        coordinates: 2D array of shape (num_points, 2).
        labels: Category label for each point (array, compared element-wise).
        title: Plot title.
        output_path: Destination file path; the figure is written in PDF format.
    """
    categories = sorted(set(labels))

    # tab10 offers 10 distinct colours, tab20 offers 20.
    palette = plt.cm.tab20 if len(categories) > 10 else plt.cm.tab10

    fig, ax = plt.subplots(figsize=(10, 8))

    # One scatter call per category so that each gets its own legend entry.
    for position, category in enumerate(categories):
        selected = labels == category
        xs = coordinates[selected, 0]
        ys = coordinates[selected, 1]
        colour = palette(position % 20)
        ax.scatter(xs, ys, c=[colour], s=15, label=category, alpha=0.7)

    ax.set_title(title)
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")
    ax.legend(loc='best', markerscale=2)

    # Vector PDF output scales cleanly for publications.
    fig.tight_layout()
    fig.savefig(output_path, format='pdf', bbox_inches='tight')
    plt.close(fig)

    print(f"[DONE] Saved plot: {output_path}")
|
||||
|
||||
# =============================================================================
|
||||
# Command Line Interface
|
||||
# =============================================================================
|
||||
|
||||
class CLI:
    """
    Command-line interface for SBERT embedding extraction and visualization.

    Provides two main commands:
        - extract: Generate embeddings from time series features
        - plot: Visualize embeddings with t-SNE
    """

    @staticmethod
    def _split_csv_arg(value) -> list:
        """
        Normalize a comma-separated CLI argument into a list of strings.

        Python Fire parses command-line values as Python literals, so
        ``--labels W,N1`` arrives as the tuple ``('W', 'N1')`` and
        ``--users 1,2`` as ``(1, 2)`` rather than as a string. This helper
        accepts strings, sequences and scalars alike. Empty items are dropped.

        Note: a bare numeric value such as ``--users 00`` is converted to the
        integer 0 by Fire before it reaches this code; quote the value
        (``--users '"00"'``) to keep leading zeros.
        """
        if value is None:
            return []
        if isinstance(value, (list, tuple, set)):
            raw_items = value
        else:
            raw_items = str(value).split(",")
        cleaned = (str(item).strip() for item in raw_items)
        return [item for item in cleaned if item]

    def extract(
        self,
        data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF",
        out_dir: str = "./embeddings/W",
        batch_size: int = 32,
        label: str = "W"
    ) -> None:
        """
        Extract embeddings from time series data.

        Args:
            data_root: Root directory containing user/session data folders
            out_dir: Output directory for HuggingFace dataset
            batch_size: Batch size for inference (default: 32)
            label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM')
        """
        embedder = SBERT()
        dataset = embedder.extract_embeddings(data_root, batch_size, label)
        embedder.save_embeddings(dataset, out_dir)

    def plot(
        self,
        emb_dir: str = "./embeddings/W",
        out_dir: str = "./plots/W",
        perplexity: float = 10.0,
        users: str = None,
        num_users: int = 0,
        labels: str = None,
    ) -> None:
        """
        Visualize embeddings with t-SNE.

        Writes two PDFs into ``out_dir``: one coloured by sleep stage and one
        coloured by user ID.

        Args:
            emb_dir: Directory containing the HuggingFace embeddings dataset
            out_dir: Output directory for visualization plots (PDF)
            perplexity: t-SNE perplexity parameter (default: 10.0)
            users: Comma-separated user IDs to include (e.g., '00,01,02')
            num_users: Include only first N users, 0 = all (default: 0);
                ignored when ``users`` is given
            labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')

        Raises:
            ValueError: If no samples remain after filtering.
        """
        os.makedirs(out_dir, exist_ok=True)

        dataset = SBERT.load_embeddings(emb_dir)

        # Apply user filtering (explicit list takes precedence over num_users).
        user_list = self._split_csv_arg(users)
        if user_list:
            wanted_users = set(user_list)
            dataset = dataset.filter(lambda x: x["user_id"] in wanted_users)
            print(f"[INFO] Filtered to users: {user_list}")
        elif num_users > 0:
            all_users = sorted(set(dataset["user_id"]))
            selected_users = all_users[:num_users]
            selected_set = set(selected_users)
            dataset = dataset.filter(lambda x: x["user_id"] in selected_set)
            print(f"[INFO] Selected first {num_users} users: {selected_users}")

        # Filter by sleep stage labels.
        label_list = self._split_csv_arg(labels)
        if label_list:
            wanted_labels = set(label_list)
            dataset = dataset.filter(lambda x: x["label"] in wanted_labels)
            print(f"[INFO] Filtered to labels: {label_list}")

        print(f"[INFO] Total samples: {len(dataset)}")

        # Fail early with an actionable message instead of an opaque t-SNE error.
        if len(dataset) == 0:
            raise ValueError(
                "No samples left to plot after filtering "
                f"(users={user_list or None}, labels={label_list or None})."
            )

        embeddings = np.array(dataset["embedding"])
        coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)

        create_scatter_plot(
            coordinates_2d,
            np.array(dataset["label"]),
            "t-SNE Visualization (Colored by Sleep Stage)",
            os.path.join(out_dir, "tsne_by_label.pdf")
        )

        create_scatter_plot(
            coordinates_2d,
            np.array(dataset["user_id"]),
            "t-SNE Visualization (Colored by User ID)",
            os.path.join(out_dir, "tsne_by_user.pdf")
        )
|
||||
|
||||
|
||||
# Script entry point: expose the CLI class's methods as sub-commands via Fire.
if __name__ == "__main__":
    Fire(CLI)
|
||||
|
||||
@@ -0,0 +1,527 @@
|
||||
"""
|
||||
User Classification from SBERT Embeddings
|
||||
|
||||
This module evaluates how well SBERT embeddings capture user-specific patterns
by training a Random Forest classifier to predict user identity from embeddings.
|
||||
|
||||
Motivation:
|
||||
If embeddings contain user-distinguishing information, a classifier should be able
|
||||
to predict which user a time series belongs to. High accuracy suggests that the
|
||||
embeddings capture individual characteristics.
|
||||
|
||||
Experimental Design:
|
||||
- Task: Multi-class classification
|
||||
- Model: Random Forest with ensemble of decision trees
|
||||
- Captures non-linear relationships in embedding space
|
||||
- Provides feature importance scores for interpretability
|
||||
- Robust to overfitting through bagging and random feature selection
|
||||
- Split Strategy:
|
||||
1. Session-based: Train on session 1, test on session 2
|
||||
2. Random: Standard train/test split
|
||||
|
||||
Session-based split is more challenging and realistic because:
|
||||
- Tests whether user patterns are stable across different recording sessions
|
||||
- Avoids data leakage from same-session samples in train and test
|
||||
|
||||
Random Forest Advantages over Logistic Regression:
|
||||
- Handles non-linear decision boundaries
|
||||
- Feature importance reveals which embedding dimensions matter most
|
||||
- No assumption about data distribution
|
||||
- Naturally handles multi-class classification
|
||||
|
||||
Output:
|
||||
- Classification metrics
|
||||
- Confusion matrix visualization
|
||||
|
||||
Usage:
|
||||
python simple_user_classifier.py \\
|
||||
--embeddings ./embeddings \\
|
||||
--out_dir ./results
|
||||
|
||||
Author: Sumin Im (NMSL Undergraduate Researcher)
|
||||
Date: 2026-01-09
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from datasets import load_from_disk, Dataset
|
||||
from fire import Fire
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import (
|
||||
accuracy_score,
|
||||
f1_score,
|
||||
classification_report,
|
||||
confusion_matrix,
|
||||
silhouette_score,
|
||||
)
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data Loading
|
||||
# =============================================================================
|
||||
|
||||
def load_embeddings_with_metadata(embedding_path: str) -> Dataset:
    """
    Load a saved HuggingFace dataset of embeddings from a directory.

    Args:
        embedding_path: Directory produced by ``Dataset.save_to_disk``.

    Returns:
        The dataset loaded with ``load_from_disk``.

    Raises:
        FileNotFoundError: If ``embedding_path`` is not an existing directory.
    """
    if os.path.isdir(embedding_path):
        return load_from_disk(embedding_path)

    raise FileNotFoundError(
        f"Dataset directory not found: {embedding_path}. "
        "Ensure this is a valid HuggingFace dataset directory."
    )
|
||||
|
||||
# =============================================================================
|
||||
# Data Splitting
|
||||
# =============================================================================
|
||||
|
||||
def split_by_session(
    features: np.ndarray,
    labels: np.ndarray,
    session_ids: np.ndarray,
    train_session: str = "1",
    test_session: str = "2",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
    """
    Split data by recording session for cross-session evaluation.

    Samples whose session id equals ``train_session`` form the training set
    and those equal to ``test_session`` form the test set; samples from any
    other session are dropped.

    Session ids and the two selectors are compared as strings, so integer
    ids (or selectors that a CLI parser turned into integers) match their
    string counterparts instead of silently producing empty splits.

    Args:
        features: Feature matrix of shape (num_samples, num_features)
        labels: Label array of shape (num_samples,)
        session_ids: Session identifier for each sample
        train_session: Session ID to use for training (default: "1")
        test_session: Session ID to use for testing (default: "2")

    Returns:
        Tuple of (X_train, X_test, y_train, y_test, split_description).
        Either split may be empty if no sample belongs to that session.
    """
    features = np.asarray(features)
    labels = np.asarray(labels)

    # Normalise both sides to str so that 1 and "1" are the same session.
    session_keys = np.asarray(session_ids).astype(str)
    train_mask = session_keys == str(train_session)
    test_mask = session_keys == str(test_session)

    X_train = features[train_mask]
    X_test = features[test_mask]
    y_train = labels[train_mask]
    y_test = labels[test_mask]

    split_description = f"session({train_session}->train, {test_session}->test)"

    return X_train, X_test, y_train, y_test, split_description
|
||||
|
||||
|
||||
def split_random(
    features: np.ndarray,
    labels: np.ndarray,
    test_size: float = 0.2,
    random_state: int = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
    """
    Randomly split the data into train and test sets, stratified by label.

    Stratification keeps each class's proportion the same in both parts.

    Args:
        features: Feature matrix of shape (num_samples, num_features)
        labels: Label array of shape (num_samples,); also used for stratification
        test_size: Fraction of data to hold out for testing (default: 0.2)
        random_state: Random seed for reproducibility

    Returns:
        Tuple of (X_train, X_test, y_train, y_test, split_description),
        where split_description is always "random".
    """
    split_options = {
        "test_size": test_size,
        "random_state": random_state,
        "stratify": labels,
    }
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, **split_options
    )
    return X_train, X_test, y_train, y_test, "random"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model Training and Evaluation
|
||||
# =============================================================================
|
||||
|
||||
def create_classifier_pipeline(
    n_estimators: int = 200,
    max_depth: int = None,
    min_samples_split: int = 2,
    min_samples_leaf: int = 1,
    random_state: int = 0,
) -> Pipeline:
    """
    Build the scikit-learn pipeline used for user classification.

    Steps:
        1. "scaler": StandardScaler with centering and unit-variance scaling.
        2. "classifier": RandomForestClassifier with balanced class weights,
           out-of-bag scoring enabled, and all CPU cores used (n_jobs=-1).

    Args:
        n_estimators: Number of trees in the forest (default: 200)
        max_depth: Maximum tree depth (default: None, i.e. unlimited)
        min_samples_split: Minimum samples needed to split an internal node
        min_samples_leaf: Minimum samples required at a leaf node
        random_state: Random seed for reproducibility

    Returns:
        Unfitted sklearn Pipeline ready for .fit() and .predict()
    """
    # Z-score normalisation: subtract the mean, divide by the standard deviation.
    scaler = StandardScaler(with_mean=True, with_std=True)

    forest = RandomForestClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        n_jobs=-1,                # Build trees in parallel on all cores
        random_state=random_state,
        class_weight="balanced",  # Re-weight classes by inverse frequency
        oob_score=True,           # Keep an out-of-bag accuracy estimate
    )

    steps = [("scaler", scaler), ("classifier", forest)]
    return Pipeline(steps)
|
||||
|
||||
|
||||
def evaluate_classifier(
    y_true: np.ndarray,
    y_pred: np.ndarray,
) -> Tuple[float, float, str, np.ndarray, list]:
    """
    Compute classification metrics and the confusion matrix.

    Computed values:
        - accuracy: fraction of correct predictions
        - macro F1: unweighted mean of per-class F1 scores
        - text report: per-class precision/recall/F1 with 4 digits
        - confusion matrix over every class seen in y_true or y_pred

    Args:
        y_true: Ground truth labels
        y_pred: Predicted labels

    Returns:
        Tuple of (accuracy, f1, classification_report, confusion_matrix, class_labels)
    """
    # Union of classes so that predicted-only and true-only classes both
    # get a row/column in the confusion matrix.
    class_labels = sorted(set(y_true).union(set(y_pred)))

    accuracy = accuracy_score(y_true, y_pred)
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    report = classification_report(y_true, y_pred, digits=4)
    cm = confusion_matrix(y_true, y_pred, labels=class_labels)

    return accuracy, macro_f1, report, cm, class_labels
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Visualization
|
||||
# =============================================================================
|
||||
|
||||
def save_confusion_matrix_plot(
    confusion_mat: np.ndarray,
    class_labels: list,
    output_path: str,
) -> None:
    """
    Create and save a confusion matrix heatmap.

    Rows are true classes, columns are predicted classes, and each cell is
    coloured by the number of samples with that (true, predicted) pair.
    Both axes are labelled with ``class_labels``.

    Args:
        confusion_mat: Square matrix of shape (num_classes, num_classes)
        class_labels: Class names, in the same order as the matrix rows/columns
        output_path: Destination file path; the figure is written in PDF format
    """
    confusion_mat = np.asarray(confusion_mat)

    fig, ax = plt.subplots(figsize=(10, 8))

    im = ax.imshow(confusion_mat, aspect="auto", cmap="Blues")
    fig.colorbar(im, ax=ax)

    # Label the ticks with the class names (previously the argument was
    # ignored and the axes showed bare indices). Only applied when the label
    # count matches the matrix, to avoid mislabelling.
    tick_names = [str(name) for name in class_labels]
    if confusion_mat.ndim == 2 and len(tick_names) == confusion_mat.shape[0]:
        positions = np.arange(len(tick_names))
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_names, rotation=90, fontsize=8)
        ax.set_yticks(positions)
        ax.set_yticklabels(tick_names, fontsize=8)

    ax.set_title("Confusion Matrix (User Classification)")
    ax.set_xlabel("Predicted User")
    ax.set_ylabel("True User")

    fig.tight_layout()
    fig.savefig(output_path, format="pdf", bbox_inches="tight")
    plt.close(fig)

    print(f"[DONE] Saved confusion matrix: {output_path}")
|
||||
|
||||
|
||||
def save_metrics_report(
    output_path: str,
    split_description: str,
    train_size: int,
    test_size: int,
    accuracy: float,
    macro_f1: float,
    classification_report_str: str,
    oob_score: float = None,
    silhouette: float = None,
) -> None:
    """
    Write the classification metrics to a UTF-8 text file.

    The file starts with one "name : value" line per summary statistic,
    followed by a blank line and the per-class classification report.
    Optional scores are only written when they are not None.

    Args:
        output_path: File path to save the metrics report
        split_description: Description of train/test split strategy
        train_size: Number of training samples
        test_size: Number of test samples
        accuracy: Overall classification accuracy
        macro_f1: Macro-averaged F1 score
        classification_report_str: Detailed per-class report string
        oob_score: Out-of-bag score from Random Forest (optional)
        silhouette: Silhouette score of the embeddings (optional)
    """
    summary = [
        f"split_used : {split_description}\n",
        f"train_size : {train_size}\n",
        f"test_size : {test_size}\n",
    ]

    # The silhouette line sits between the sizes and the accuracy.
    if silhouette is not None:
        summary.append(f"silhouette : {silhouette:.4f}\n")

    summary.append(f"accuracy : {accuracy:.4f}\n")
    summary.append(f"macro_f1 : {macro_f1:.4f}\n")

    # The OOB score only exists for Random Forest models.
    if oob_score is not None:
        summary.append(f"oob_score : {oob_score:.4f}\n")

    summary.append("\n")
    summary.append(classification_report_str)

    with open(output_path, "w", encoding="utf-8") as f:
        f.write("".join(summary))

    print(f"[DONE] Saved metrics report: {output_path}")
|
||||
|
||||
|
||||
def save_feature_importance_plot(
    feature_importances: np.ndarray,
    output_path: str,
    top_k: int = 50,
) -> None:
    """
    Create and save a horizontal bar chart of the top-k most important features.

    Bars are ordered with the highest importance at the top and labelled
    ``emb_<index>``, where ``<index>`` is the feature's position in
    ``feature_importances``. When fewer than ``top_k`` features exist, all of
    them are drawn and the title reports the number actually shown.

    Args:
        feature_importances: 1-D array-like of importance scores, one per
            embedding dimension
        output_path: File path to save the plot (always written as PDF)
        top_k: Maximum number of top features to display (default: 50)
    """
    # Accept any array-like (e.g. a plain list): the fancy indexing below
    # requires an ndarray.
    importances = np.asarray(feature_importances)

    # Indices of the highest-scoring features, largest first.
    top_indices = np.argsort(importances)[::-1][:top_k]
    top_importances = importances[top_indices]
    shown = len(top_indices)

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        # Horizontal bars, one row per selected feature.
        y_positions = np.arange(shown)
        ax.barh(y_positions, top_importances, color="steelblue", alpha=0.8)

        ax.set_yticks(y_positions)
        ax.set_yticklabels([f"emb_{i}" for i in top_indices], fontsize=8)
        ax.invert_yaxis()  # Highest importance at top

        ax.set_xlabel("Feature Importance (Mean Decrease in Impurity)")
        ax.set_ylabel("Embedding Dimension")
        # Use the number of bars actually drawn, which is smaller than top_k
        # when there are fewer features than requested.
        ax.set_title(f"Top {shown} Most Important Embedding Dimensions")

        # Save as vector PDF
        fig.tight_layout()
        fig.savefig(output_path, format="pdf", bbox_inches="tight")
    finally:
        # Release this specific figure even if drawing or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"[DONE] Saved feature importance plot: {output_path}")
|
||||
|
||||
|
||||
def save_feature_importance_csv(
    feature_importances: np.ndarray,
    output_path: str,
) -> None:
    """
    Write ranked feature importance scores to a CSV file.

    The CSV has the columns ``rank``, ``feature`` and ``importance``. Rows are
    ordered by descending importance, ``rank`` starts at 1, and each feature
    is named ``emb_<index>`` after its position in ``feature_importances``.

    Args:
        feature_importances: Array of importance scores, one per embedding
            dimension
        output_path: File path to save the CSV file
    """
    feature_names = [f"emb_{dim}" for dim in range(len(feature_importances))]

    # Build the table, order it from most to least important, and discard the
    # original row labels so the written rows follow the ranking.
    ranked = (
        pd.DataFrame({"feature": feature_names, "importance": feature_importances})
        .sort_values("importance", ascending=False)
        .reset_index(drop=True)
    )

    # Put the 1-based rank in front of the other two columns.
    ranked.insert(0, "rank", range(1, len(ranked) + 1))
    ranked.to_csv(output_path, index=False)

    print(f"[DONE] Saved feature importance CSV: {output_path}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Command Line Interface
|
||||
# =============================================================================
|
||||
|
||||
def main(
    embeddings: str = "./embeddings/REM",
    out_dir: str = "./results/REM",
    test_size: float = 0.2,
    n_estimators: int = 200,
    max_depth: int = None,
) -> None:
    """
    Train and evaluate a user classifier on stored embeddings and save reports.

    Steps:
        1. Load embeddings with their ``user_id`` / ``session_id`` metadata.
        2. Split via ``split_by_session`` when both session "1" and "2" are
           present; otherwise fall back to ``split_random``.
        3. Compute the silhouette score of the embeddings grouped by user
           (skipped when it is undefined for the data).
        4. Fit the classifier pipeline and predict on the test split.
        5. Write the metrics report, confusion matrix, and feature importance
           plot and CSV into ``out_dir``, then print a summary.

    Args:
        embeddings: Location passed to ``load_embeddings_with_metadata``
        out_dir: Output directory (created if missing)
        test_size: Test fraction handed to ``split_random``; only used in the
            random-split fallback
        n_estimators: Forwarded to ``create_classifier_pipeline``
        max_depth: Forwarded to ``create_classifier_pipeline`` (None by default)
    """
    # Create output directory
    os.makedirs(out_dir, exist_ok=True)
    print(f"[INFO] Loading embeddings from: {embeddings}")
    dataset = load_embeddings_with_metadata(embeddings)

    X = np.array(dataset["embedding"], dtype=np.float32)
    # IDs are normalized to strings so the session checks below compare
    # against "1" / "2" regardless of how they were stored.
    y = np.array([str(uid) for uid in dataset["user_id"]])
    session_ids = np.array([str(sid) for sid in dataset["session_id"]])

    n_samples = len(y)
    n_users = len(np.unique(y))
    print(f"[INFO] Loaded {n_samples} samples with {X.shape[1]}-dimensional embeddings")
    print(f"[INFO] Number of unique users: {n_users}")

    has_session_1 = (session_ids == "1").sum() > 0
    has_session_2 = (session_ids == "2").sum() > 0

    if has_session_1 and has_session_2:
        X_train, X_test, y_train, y_test, split_desc = split_by_session(
            X, y, session_ids
        )
    else:
        print("[WARN] Missing session data, falling back to random split")
        # The description returned by split_random is deliberately replaced
        # so the report makes clear this was a fallback.
        X_train, X_test, y_train, y_test, _ = split_random(X, y, test_size)
        split_desc = "random(fallback)"

    print(f"[INFO] Split: {split_desc}")
    print(f"[INFO] Train size: {len(y_train)}, Test size: {len(y_test)}")

    # silhouette_score raises ValueError unless the number of labels lies in
    # [2, n_samples - 1]; skip the metric instead of aborting the whole run.
    if 2 <= n_users <= n_samples - 1:
        silhouette_avg = silhouette_score(X, y)
        print(f"[INFO] Silhouette Score (user clusters): {silhouette_avg:.4f}")
    else:
        silhouette_avg = None
        print(
            f"[WARN] Silhouette score undefined for {n_users} user(s) "
            f"across {n_samples} sample(s), skipping"
        )

    print("[INFO] Training Random Forest classifier...")
    print(f"[INFO] Hyperparameters: n_estimators={n_estimators}, max_depth={max_depth}")

    classifier = create_classifier_pipeline(
        n_estimators=n_estimators,
        max_depth=max_depth,
    )
    classifier.fit(X_train, y_train)

    y_pred = classifier.predict(X_test)
    rf_model = classifier.named_steps["classifier"]
    # oob_score_ may be absent on the fitted model; report None in that case.
    oob_score = getattr(rf_model, "oob_score_", None)
    feature_importances = rf_model.feature_importances_

    print("[INFO] Evaluating classifier performance...")
    accuracy, macro_f1, report, cm, classes = evaluate_classifier(y_test, y_pred)

    metrics_path = os.path.join(out_dir, "user_cls_metrics.txt")
    save_metrics_report(
        output_path=metrics_path,
        split_description=split_desc,
        train_size=len(y_train),
        test_size=len(y_test),
        accuracy=accuracy,
        macro_f1=macro_f1,
        classification_report_str=report,
        oob_score=oob_score,
        silhouette=silhouette_avg,
    )

    confusion_path = os.path.join(out_dir, "user_cls_confusion.pdf")
    save_confusion_matrix_plot(cm, classes, confusion_path)

    importance_plot_path = os.path.join(out_dir, "feature_importance.pdf")
    save_feature_importance_plot(feature_importances, importance_plot_path)

    importance_csv_path = os.path.join(out_dir, "feature_importance.csv")
    save_feature_importance_csv(feature_importances, importance_csv_path)

    print("\n" + "=" * 50)
    print("RANDOM FOREST CLASSIFICATION RESULTS")
    print("=" * 50)
    print(f"Split Strategy : {split_desc}")
    if silhouette_avg is not None:
        print(f"Silhouette : {silhouette_avg:.4f}")
    print(f"Accuracy : {accuracy:.4f}")
    print(f"Macro F1 : {macro_f1:.4f}")
    if oob_score is not None:
        print(f"OOB Score : {oob_score:.4f}")
    print("Top 5 Important Features:")
    top5_idx = np.argsort(feature_importances)[::-1][:5]
    for rank, idx in enumerate(top5_idx, 1):
        print(f" {rank}. emb_{idx}: {feature_importances[idx]:.4f}")
    print("=" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Hand main() to Fire so its keyword arguments can be supplied on the
    # command line when this file is run as a script.
    Fire(main)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user