8 Commits

Author SHA1 Message Date
Hyungjun Yoon
c4bbbf702d implemented globem processing 2026-03-18 21:30:42 +09:00
Hyungjun Yoon
ccb88d4eef implemented 6 random baselines 2026-03-12 20:41:34 +09:00
Hyungjun Yoon
95ddd935f7 fixed some bugs in run 2026-03-10 19:54:37 +09:00
Hyungjun Yoon
d10b70dc20 cleaned code 2026-03-09 00:04:01 +09:00
Hyungjun Yoon
3f3db31d25 working code with queue update 2026-03-08 23:55:32 +09:00
Hynugjun Yoon
31f4af7106 implemented new version of llm initialization 2026-02-12 17:42:09 +09:00
Hyungjun Yoon
8a26346b5e updated model to take hf format 2026-02-12 16:52:46 +09:00
Hyungjun Yoon
a7c8e43f89 updated model to huggingface framework 2026-02-12 14:42:25 +09:00
85 changed files with 2000 additions and 11477 deletions

1
.gitignore vendored
View File

@@ -3,6 +3,7 @@
*.csv
*.arrow
*.json
temp*/
# Byte-compiled / optimized / DLL files
__pycache__/

80
README.md Normal file
View File

@@ -0,0 +1,80 @@
## Setup
```bash
conda create -n tsllmpers python=3.12
conda activate tsllmpers
python -m pip install -r requirements.txt
```
## Preprocessing
Raw data must be preprocessed into HuggingFace `datasets` format before running.
A preprocessing script is provided for the SleepEDF dataset:
```bash
python preprocess/preprocess_SleepEDF.py \
--path /path/to/SleepEDF/raw/sleep-cassette/ \
--out_dir /path/to/output/processed_SleepEDF \
--num_workers 32
```
The output directory will have the following structure:
```
SleepEDF_new/
task_metadata.json # task description, class definitions, data/feature info
user_metadata.json # per-user metadata (age, sex)
00/ # user folder (HuggingFace Dataset saved with save_to_disk)
01/
...
```
Each user dataset contains the columns: `user_id` (str), `label` (str), `features` (dict of floats), and `data` (dict of raw signal arrays).
To preprocess a different dataset, write a similar script that produces the same output structure.
## Running
```bash
python run.py --config_path config/test.yaml
```
## Config
All configuration is in a single YAML file. Example (`config/test.yaml`):
```yaml
log_path: ./temp
data_path: /path/to/data/processed_SleepEDF
target_user: "00"
queue_size: 5
num_shot: 1
model_paths:
- ollama:url:hostname:11437/gpt-oss:20b
- ollama:url:hostname:11438/gpt-oss:20b
vocab_size: 200064
```
| Key | Description |
|---|---|
| `log_path` | Directory for logs and results (timestamped subfolder created automatically) |
| `data_path` | Path to the preprocessed dataset directory |
| `target_user` | User ID to evaluate on; all other users become source (ICL examples) |
| `queue_size` | Number of example sets to maintain in the queue |
| `num_shot` | Number of examples per class in each example set |
| `model_paths` | List of Ollama model endpoints in `ollama:url:host:port/model` format |
| `vocab_size` | Vocabulary size of the model (used for self-certainty scoring) |
## Ollama
The model backend is [Ollama](https://ollama.com). You need one or more Ollama servers running and accessible over HTTP.
Example scripts for managing multiple Ollama instances on a multi-GPU machine are provided in `utils/` for reference. These are environment-specific -- adapt the ports, model paths, and GPU assignments to your own setup:
```bash
# Reference only -- edit before using
bash utils/launch_ollamas.sh # start servers in tmux sessions
bash utils/kill_ollamas.sh # stop all servers
```
The model path format in the config is `ollama:url:<host>:<port>/<model_name>`.

File diff suppressed because one or more lines are too long

View File

@@ -1,96 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"id": "a0874e1b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Correct: 1038, Total: 2149, Accuracy: 0.4830153559795254\n"
]
}
],
"source": [
"from glob import glob\n",
"import os\n",
"\n",
"correct = 0\n",
"total = 0\n",
"summary_paths = glob(\"/mnt/sting/hjyoon/projects/tsllm_personalization_icl/logs/SleepEDF/*/*/*/summary.txt\")\n",
"for summary_path in summary_paths:\n",
" with open(summary_path, \"r\") as f:\n",
" summary = f.read()\n",
" if summary:\n",
" answer = summary.split(\"Answer: \")[-1].split(\" (Ground truth: \")[0]\n",
" ground_truth = summary.split(\" (Ground truth: \")[-1].split(\")\")[0]\n",
" if answer == ground_truth:\n",
" correct += 1\n",
" total += 1\n",
"print(f\"Correct: {correct}, Total: {total}, Accuracy: {correct/total}\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f78ffc6f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Correct: 932, Total: 2260, Accuracy: 0.41238938053097346\n"
]
}
],
"source": [
"correct = 0\n",
"total = 0\n",
"summary_paths = glob(\"/mnt/sting/hjyoon/projects/tsllm_personalization_icl/logs/SleepEDF_out_random/*/*/*/summary.txt\")\n",
"for summary_path in summary_paths:\n",
" with open(summary_path, \"r\") as f:\n",
" summary = f.read()\n",
" if summary:\n",
" answer = summary.split(\"Answer: \")[-1].split(\" (Ground truth: \")[0]\n",
" ground_truth = summary.split(\" (Ground truth: \")[-1].split(\")\")[0]\n",
" if answer == ground_truth:\n",
" correct += 1\n",
" total += 1\n",
"print(f\"Correct: {correct}, Total: {total}, Accuracy: {correct/total}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6872381f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tsllmpers",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -1,580 +0,0 @@
"""
Chronos-2 Time Series Embedding Extraction and Visualization Pipeline
This module provides functionality to
1. Extract embeddings from multivariate time series using Chronos-2 foundation model
2. Visualize embeddings using dimensionality reduction (t-SNE)
Chronos-2 Overview:
Chronos-2 is a time series foundation model developed by Amazon that converts
time series forecasting into a language modeling task. It tokenizes time series
values and generates probabilistic forecasts using a transformer architecture.
Key Features:
- Zero-shot forecasting: Works on unseen time series without fine-tuning
- Probabilistic predictions: Outputs quantile forecasts (e.g., 10%, 50%, 90%)
- Multivariate support: Can process multiple channels simultaneously
Embedding Strategy:
We use Chronos-2's internal encoder hidden states as embeddings, which is the
recommended approach for representation learning. The encoder captures rich
temporal patterns through self-attention mechanisms.
"encoder" : Uses encoder hidden states directly
- More informative representation of input characteristics
- Captures learned temporal patterns from pre-training
Usage:
# Extract embeddings
python gen_plot.py extract --data_root /path/to/data --out_dir ./embeddings
# Visualize with t-SNE
python gen_plot.py plot --emb_dir ./embeddings --out_dir ./plots
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
import sys
from glob import glob
from typing import Dict, Any, List, Tuple, Optional
import numpy as np
import matplotlib.pyplot as plt
import torch
from datasets import load_from_disk, Dataset
from chronos import BaseChronosPipeline, Chronos2Pipeline
from fire import Fire
from sklearn.manifold import TSNE
# Add parent directories to path for importing metadata loader
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../.."))
from preprocess.sleepedf_metadata import SleepEDFMetadata, DEFAULT_METADATA_PATH
# pipeline: Chronos2Pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map="cuda")
# =============================================================================
# Constants
# =============================================================================
# EEG channel names used in Sleep-EDF dataset
# Fpz-Cz: Frontal-Central electrode pair
# Pz-Oz: Parietal-Occipital electrode pair
EEG_CHANNEL_1 = "EEG Fpz-Cz"
EEG_CHANNEL_2 = "EEG Pz-Oz"
# =============================================================================
# Embedding Extractor Class
# =============================================================================
class Chronos_2_Embedder:
"""
Extracts fixed-dimensional embeddings from multivariate time series using Chronos-2.
Uses the internal encoder hidden states from Chronos-2's transformer.
This directly accesses the model's learned features for representation learning.
Architecture:
Input Time Series → Patching → Encoder (6 layers) → Hidden States → Pooling → Embedding
Processing Pipeline:
1. Instance normalization (z-score per series)
2. Patching (splits series into fixed-size patches)
3. Patch embedding (linear projection to d_model dimensions)
4. Transformer encoder (6 layers of self-attention)
5. Output: hidden states of shape (n_variates, num_patches + 2, d_model)
- +2 for [REG] token and masked output patch token
Pooling Strategies:
- mean: Average across all patch tokens (excluding special tokens)
- cls: Use the [REG] token embedding (similar to BERT's [CLS])
Attributes:
pipeline: Chronos-2 model pipeline for inference
pooling_strategy: How to pool encoder hidden states ("mean" or "cls")
"""
def __init__(
self,
model_name: str = "amazon/chronos-2",
device_map: str = None,
pooling_strategy: str = "mean",
variate_fusion: str = "concat",
):
"""
Initialize the Chronos-2 embedder.
Args:
model_name: HuggingFace model name or local path
device_map: Device placement ("cuda", "cpu", or None for auto)
pooling_strategy: "mean" (average all patches) or "cls" (use [REG] token only)
"""
self.pooling_strategy = pooling_strategy
if variate_fusion not in ["concat", "mean"]:
raise ValueError(
f"Invalid variate_fusion: {variate_fusion}. Use 'concat' (1024) or 'mean' (512)."
)
self.variate_fusion = variate_fusion
if device_map is None:
device_map = "cuda" if torch.cuda.is_available() else "cpu"
self.device_map = device_map
self.device = torch.device(
"cuda" if device_map.startswith("cuda") and torch.cuda.is_available() else "cpu"
)
# Load pre-trained Chronos-2 model
print(f"[INFO] Loading Chronos-2 model: {model_name}")
print(f"[INFO] Device: {device_map}")
print(f"[INFO] Pooling strategy: {pooling_strategy}")
print(f"[INFO] Variate fusion: {variate_fusion}")
self.pipeline: Chronos2Pipeline = Chronos2Pipeline.from_pretrained(
model_name,
device_map=device_map
)
@staticmethod
def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
"""
Discover all user/session directories under data_root.
Uses glob pattern matching for cleaner directory traversal.
Expected structure: data_root/user_id/session_id/
Returns:
List of (user_id, session_id, session_path) tuples
"""
discovered_paths = []
# Use glob to find all session directories (2 levels deep)
for session_path in sorted(glob(os.path.join(data_root, "*", "*"))):
if not os.path.isdir(session_path):
continue
# Extract user_id and session_id from path
session_id = os.path.basename(session_path)
user_id = os.path.basename(os.path.dirname(session_path))
discovered_paths.append((user_id, session_id, session_path))
return discovered_paths
@torch.no_grad()
def compute_embedding(self, batch: Dict[str, Any]) -> np.ndarray:
"""
Generate embedding vectors using Chronos-2's internal encoder hidden states.
Processing Pipeline:
1. Parse batch to extract 2-channel EEG time series
2. Format for Chronos-2 input (B, V, L)
3. Call pipeline.embed() to get encoder hidden states
4. Pool hidden states to get fixed-size embedding
Args:
batch: HuggingFace dataset batch from slicing (dataset[start:end])
Format: {"data": [{"EEG Fpz-Cz": [...], "EEG Pz-Oz": [...]}, ...]}
Returns:
Embedding array of shape (batch_size, embedding_dim)
For Chronos-2 with d_model=512:
- pooling (mean/cls) produces (B, 512) per variate
- variate_fusion='concat': (B, 2*512) = (B, 1024) for 2 channels
- variate_fusion='mean' : (B, 512) by averaging across variates
"""
# =====================================================================
# Step 1: Parse batch to multivariate time series array
# =====================================================================
samples = batch["data"]
channel_1 = np.stack([np.asarray(s[EEG_CHANNEL_1], dtype=np.float32) for s in samples])
channel_2 = np.stack([np.asarray(s[EEG_CHANNEL_2], dtype=np.float32) for s in samples])
timeseries = np.stack([channel_1, channel_2], axis=-1) # (B, 3000, 2)
# =====================================================================
# Step 2: Format for Chronos-2 input
# =====================================================================
# Chronos-2 expects (batch, n_variates, seq_length)
x_input = np.transpose(timeseries, (0, 2, 1)).astype(np.float32) # (B, 2, 3000)
# =====================================================================
# Step 3: Get encoder embeddings using pipeline.embed()
# =====================================================================
# embed() returns:
# - embeddings: list of tensors, each (n_variates, num_patches + 2, d_model)
# - loc_scale: list of tuples (loc, scale) for denormalization
embeddings_list, loc_scale_list = self.pipeline.embed(x_input)
# =====================================================================
# Step 4: Pool hidden states to get fixed-size embedding
# =====================================================================
all_embeddings = []
for emb in embeddings_list:
# emb shape: (n_variates, num_patches + 2, d_model) = (2, N+2, 512)
if self.pooling_strategy == "cls":
# Use the [REG] token (first token) as the embedding
# This is similar to BERT's [CLS] token approach
pooled = emb[:, 0, :] # (n_variates, d_model) = (2, 512)
else: # "mean" pooling (default)
# Average across all patch tokens (excluding special tokens)
# Skip first token ([REG]) and last token (masked output patch)
pooled = emb[:, 1:-1, :].mean(dim=1) # (n_variates, d_model) = (2, 512)
if self.variate_fusion == "mean":
# Fuse variates by averaging their pooled representations
# (2, 512) -> (512,)
pooled_flat = pooled.mean(dim=0)
else:
# Keep each variate's representation and concatenate
# (2, 512) -> (1024,)
pooled_flat = pooled.reshape(-1)
all_embeddings.append(pooled_flat.cpu().numpy())
return np.stack(all_embeddings, axis=0).astype(np.float32)
def extract_embeddings(
self,
data_root: str,
batch_size: int = 32,
metadata_path: str = DEFAULT_METADATA_PATH,
) -> Dataset:
"""
Extract embeddings from all sessions under the data root directory.
Iterates through all user/session combinations, processes time series
in batches, and aggregates results with metadata.
Args:
data_root: Root directory containing user/session subfolders
batch_size: Number of samples to process together.
Larger = faster but more memory.
32 is a good balance for most GPUs.
metadata_path: Path to SC-subjects.xls for gender/age info
Returns:
HuggingFace Dataset with columns:
- user_id, session_id, idx, label (metadata)
- gender, age (demographic metadata)
- embedding (vector; dim depends on variate_fusion: 1024 for concat, 512 for mean)
"""
session_paths = self.discover_session_paths(data_root)
print(f"[INFO] Discovered {len(session_paths)} sessions")
# Load metadata for gender/age information
try:
metadata = SleepEDFMetadata(metadata_path)
has_metadata = True
except FileNotFoundError:
print(f"[WARNING] Metadata file not found: {metadata_path}")
print(f"[WARNING] Gender/age information will not be available")
metadata = None
has_metadata = False
all_embeddings = []
all_user_ids = []
all_session_ids = []
all_idxs = []
all_labels = []
all_genders = []
all_ages = []
for user_id, session_id, session_path in session_paths:
# Load HuggingFace dataset from disk
dataset = load_from_disk(session_path)
# shuffle dataset
dataset = dataset.shuffle(seed=0)
num_samples = len(dataset)
print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")
# Process in batches to manage memory
for batch_start in range(0, num_samples, batch_size):
batch_end = min(batch_start + batch_size, num_samples)
# Slice dataset to get batch
batch = dataset[batch_start:batch_end]
# Compute embeddings
embeddings = self.compute_embedding(batch)
# Collect embeddings and metadata
for i in range(embeddings.shape[0]):
user_id_str = str(batch["user_id"][i])
all_embeddings.append(embeddings[i].tolist())
all_user_ids.append(user_id_str)
all_session_ids.append(str(batch["session_id"][i]))
all_idxs.append(int(batch["idx"][i]))
all_labels.append(str(batch["label"][i]))
# Add demographic metadata
if has_metadata:
info = metadata.get_info(user_id_str)
if info:
all_genders.append(info['gender'])
all_ages.append(info['age'])
else:
all_genders.append('Unknown')
all_ages.append(-1)
else:
all_genders.append('Unknown')
all_ages.append(-1)
# Create HuggingFace Dataset
result_dataset = Dataset.from_dict({
"user_id": all_user_ids,
"session_id": all_session_ids,
"idx": all_idxs,
"label": all_labels,
"gender": all_genders,
"age": all_ages,
"embedding": all_embeddings,
})
print(f"[INFO] Total samples: {len(result_dataset)}")
if has_metadata:
gender_counts = {}
for g in all_genders:
gender_counts[g] = gender_counts.get(g, 0) + 1
print(f"[INFO] Gender distribution in extracted data: {gender_counts}")
return result_dataset
def save_embeddings(
self,
dataset: Dataset,
output_dir: str
) -> None:
os.makedirs(output_dir, exist_ok=True)
dataset.save_to_disk(output_dir)
print(f"[DONE] Saved embeddings dataset: {output_dir}")
print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
@staticmethod
def load_embeddings(embedding_dir: str) -> Dataset:
dataset = load_from_disk(embedding_dir)
print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
return dataset
# =============================================================================
# Visualization Functions
# =============================================================================
def reduce_to_2d_tsne(
embeddings: np.ndarray,
perplexity: float = 30.0
) -> np.ndarray:
"""
Reduce high-dimensional embeddings to 2D using t-SNE.
t-SNE (t-distributed Stochastic Neighbor Embedding) is a nonlinear
dimensionality reduction technique that preserves local structure.
Points that are similar in high dimensions stay close in 2D.
Args:
embeddings: High-dimensional array of shape (num_samples, embedding_dim)
perplexity: t-SNE perplexity parameter (typically 5-50).
Higher values consider more neighbors, creating smoother layouts.
Rule of thumb: perplexity ~ sqrt(num_samples)
Returns:
2D coordinates of shape (num_samples, 2)
"""
print(f"[INFO] Running t-SNE with perplexity={perplexity}...")
tsne = TSNE(
n_components=2,
random_state=0, # For reproducibility
perplexity=perplexity,
max_iter=1000, # Usually sufficient for convergence
init='random',
learning_rate='auto', # Let sklearn choose optimal learning rate
)
return tsne.fit_transform(embeddings)
def create_scatter_plot(
coordinates: np.ndarray,
labels: np.ndarray,
title: str,
output_path: str
) -> None:
"""
Create and save a 2D scatter plot with categorical coloring.
Args:
coordinates: 2D array of shape (num_points, 2)
labels: Category labels for each point (string array)
title: Plot title
output_path: File path to save the plot (PDF vector format recommended)
"""
# Get unique labels for legend
unique_labels = sorted(set(labels))
# Select colormap based on number of categories
# tab10: 10 distinct colors, tab20: 20 distinct colors
colormap = plt.cm.tab10 if len(unique_labels) <= 10 else plt.cm.tab20
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Plot each category separately for proper legend
for idx, label in enumerate(unique_labels):
mask = labels == label
ax.scatter(
coordinates[mask, 0],
coordinates[mask, 1],
c=[colormap(idx % 20)],
s=15,
label=label,
alpha=0.7,
)
# Configure plot appearance
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
ax.legend(loc='best', markerscale=2)
# Save figure as vector PDF (scalable, ideal for publications)
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
# =============================================================================
# Command Line Interface
# =============================================================================
class CLI:
def extract(
self,
data_root: str,
out_dir: str,
model: str = "amazon/chronos-2",
batch_size: int = 32,
pooling: str = "mean",
variate_fusion: str = "concat",
) -> None:
"""
Extract embeddings from time series data.
Args:
data_root: Root directory containing user/session data folders
out_dir: Output directory for HuggingFace dataset
model: Chronos-2 model name or path (default: amazon/chronos-2)
batch_size: Batch size for inference (default: 32)
pooling: Pooling strategy - 'mean' or 'cls' (default: mean)
variate_fusion: How to combine channels - 'concat' (1024) or 'mean' (512)
"""
# Validate pooling argument
if pooling not in ["mean", "cls"]:
raise ValueError(f"Invalid pooling strategy: {pooling}. Use 'mean' or 'cls'.")
if variate_fusion not in ["concat", "mean"]:
raise ValueError(
f"Invalid variate_fusion: {variate_fusion}. Use 'concat' (1024) or 'mean' (512)."
)
# Initialize embedder
embedder = Chronos_2_Embedder(
model_name=model,
pooling_strategy=pooling,
variate_fusion=variate_fusion,
)
# Extract and save embeddings
dataset = embedder.extract_embeddings(data_root, batch_size)
embedder.save_embeddings(dataset, out_dir)
def plot(
self,
emb_dir: str,
out_dir: str,
perplexity: float = 30.0,
users: str = None,
num_users: int = 0,
labels: str = None,
gender: str = None,
metadata_path: str = DEFAULT_METADATA_PATH,
) -> None:
"""
Visualize embeddings with t-SNE.
Args:
emb_dir: Directory containing the HuggingFace embeddings dataset
out_dir: Output directory for visualization plots (PDF)
perplexity: t-SNE perplexity parameter (default: 30.0)
users: Comma-separated user IDs to include (e.g., '00,01,02')
num_users: Include only first N users, 0 = all (default: 0)
labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
gender: Filter by gender ('M', 'F', or None for all)
metadata_path: Path to metadata file for gender lookup (if not in dataset)
"""
os.makedirs(out_dir, exist_ok=True)
# Load saved embeddings dataset
dataset = Chronos_2_Embedder.load_embeddings(emb_dir)
# Apply user filtering
if users:
user_list = [u.strip() for u in users.split(",")]
dataset = dataset.filter(lambda x: x["user_id"] in user_list)
print(f"[INFO] Filtered to users: {user_list}")
elif num_users > 0:
all_users = sorted(set(dataset["user_id"]))
selected_users = all_users[:num_users]
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
print(f"[INFO] Selected first {num_users} users: {selected_users}")
# Filter by gender
if gender:
dataset = dataset.filter(lambda x: x.get("gender", "Unknown") == gender)
print(f"[INFO] Filtered to gender: {gender}")
# Filter by sleep stage labels
if labels:
label_list = [l.strip() for l in labels.split(",")]
dataset = dataset.filter(lambda x: x["label"] in label_list)
print(f"[INFO] Filtered to labels: {label_list}")
print(f"[INFO] Total samples: {len(dataset)}")
# Print gender distribution
if "gender" in dataset.column_names:
gender_counts = {}
for g in dataset["gender"]:
gender_counts[g] = gender_counts.get(g, 0) + 1
print(f"[INFO] Gender distribution: {gender_counts}")
# Extract embeddings as numpy array for t-SNE
embeddings = np.array(dataset["embedding"])
# Reduce to 2D with t-SNE
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
# Generate visualizations
create_scatter_plot(
coordinates_2d,
np.array(dataset["label"]),
"t-SNE Visualization (Colored by Sleep Stage)",
os.path.join(out_dir, "tsne_by_label.pdf")
)
create_scatter_plot(
coordinates_2d,
np.array(dataset["user_id"]),
"t-SNE Visualization (Colored by User ID)",
os.path.join(out_dir, "tsne_by_user.pdf")
)
if __name__ == "__main__":
Fire(CLI)

View File

@@ -1,507 +0,0 @@
"""
User Classification from Time Series Embeddings
This module evaluates how well time series embeddings capture user-specific patterns
by training a simple linear classifier to predict user identity from embeddings.
Motivation:
If embeddings contain user-distinguishing information, a classifier should be able
to predict which user a time series belongs to. High accuracy suggests that the
embeddings capture individual characteristics.
Experimental Design:
- Task: Multi-class classification
- Model: Random Forest with ensemble of decision trees
- Captures non-linear relationships in embedding space
- Provides feature importance scores for interpretability
- Robust to overfitting through bagging and random feature selection
- Split Strategy:
1. Session-based: Train on session 1, test on session 2
2. Random: Standard train/test split
Session-based split is more challenging and realistic because:
- Tests whether user patterns are stable across different recording sessions
- Avoids data leakage from same-session samples in train and test
Random Forest Advantages over Logistic Regression:
- Handles non-linear decision boundaries
- Feature importance reveals which embedding dimensions matter most
- No assumption about data distribution
- Naturally handles multi-class classification
Output:
- Classification metrics
- Confusion matrix visualization
Usage:
python simple_user_classifier.py \\
--embeddings ./embeddings \\
--out_dir ./results
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from typing import Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_from_disk, Dataset
from fire import Fire
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
accuracy_score,
f1_score,
classification_report,
confusion_matrix,
silhouette_score,
)
from sklearn.model_selection import train_test_split
# =============================================================================
# Data Loading
# =============================================================================
def load_embeddings_with_metadata(embedding_path: str) -> Dataset:
if not os.path.isdir(embedding_path):
raise FileNotFoundError(
f"Dataset directory not found: {embedding_path}. "
"Ensure this is a valid HuggingFace dataset directory."
)
# Load HuggingFace dataset from disk
dataset = load_from_disk(embedding_path)
return dataset
# =============================================================================
# Data Splitting
# =============================================================================
def split_by_session(
features: np.ndarray,
labels: np.ndarray,
session_ids: np.ndarray,
train_session: str = "1",
test_session: str = "2",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
"""
Split data by recording session for temporal generalization evaluation.
This split strategy tests whether learned patterns generalize across time.
Training on session 1 and testing on session 2 simulates real-world deployment
where models must work on future recordings.
Args:
features: Feature matrix of shape (num_samples, num_features)
labels: Label array of shape (num_samples,)
session_ids: Session identifier for each sample
train_session: Session ID to use for training (default: "1")
test_session: Session ID to use for testing (default: "2")
Returns:
Tuple of (X_train, X_test, y_train, y_test, split_description)
"""
# Create boolean masks for train and test sets
train_mask = session_ids == train_session
test_mask = session_ids == test_session
# Apply masks to create train/test splits
X_train = features[train_mask]
X_test = features[test_mask]
y_train = labels[train_mask]
y_test = labels[test_mask]
split_description = f"session({train_session}->train, {test_session}->test)"
return X_train, X_test, y_train, y_test, split_description
def split_random(
features: np.ndarray,
labels: np.ndarray,
test_size: float = 0.2,
random_state: int = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
"""
Split data randomly with stratification for in-distribution evaluation.
Stratification ensures each class has proportional representation in
both train and test sets, preventing class imbalance issues.
Args:
features: Feature matrix of shape (num_samples, num_features)
labels: Label array of shape (num_samples,)
test_size: Fraction of data to use for testing (default: 0.2)
random_state: Random seed for reproducibility
Returns:
Tuple of (X_train, X_test, y_train, y_test, split_description)
"""
X_train, X_test, y_train, y_test = train_test_split(
features,
labels,
test_size=test_size,
random_state=random_state,
stratify=labels,
)
split_description = "random"
return X_train, X_test, y_train, y_test, split_description
# =============================================================================
# Model Training and Evaluation
# =============================================================================
def create_classifier_pipeline(
n_estimators: int = 200,
max_depth: int = None,
min_samples_split: int = 2,
min_samples_leaf: int = 1,
random_state: int = 0,
) -> Pipeline:
"""
Create a scikit-learn pipeline for user classification using Random Forest.
Pipeline Architecture:
----------------------
1. StandardScaler: Z-score normalization of features
- Centers features (mean=0) and scales to unit variance (std=1)
- While Random Forest is scale-invariant, scaling helps with
consistent feature importance interpretation
2. RandomForestClassifier: Ensemble of decision trees
- Builds multiple decision trees on random subsets of data (bagging)
- Each tree uses random subset of features at each split
- Final prediction is majority vote across all trees
- Provides feature_importances_ for interpretability
Random Forest Hyperparameters:
------------------------------
- n_estimators: Number of trees in the forest (more = better but slower)
- max_depth: Maximum tree depth (None = expand until pure leaves)
- min_samples_split: Minimum samples to split internal node
- min_samples_leaf: Minimum samples required at leaf node
Args:
n_estimators: Number of trees (default: 200)
max_depth: Maximum depth of trees (default: None, fully grown)
min_samples_split: Min samples for splitting (default: 2)
min_samples_leaf: Min samples at leaf (default: 1)
random_state: Random seed for reproducibility
Returns:
Configured sklearn Pipeline ready for .fit() and .predict()
"""
pipeline = Pipeline([
# Step 1: Feature normalization
("scaler", StandardScaler(
with_mean=True, # Subtract mean (center the data)
with_std=True, # Divide by standard deviation (scale to unit variance)
)),
# Step 2: Random Forest classification
("classifier", RandomForestClassifier(
n_estimators=n_estimators, # Number of trees in the forest
max_depth=max_depth, # Maximum depth (None = unlimited)
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
n_jobs=-1, # Use all CPU cores for parallel tree building
random_state=random_state, # For reproducibility
class_weight="balanced", # Handle class imbalance automatically
oob_score=True, # Enable out-of-bag error estimation
)),
])
return pipeline
def evaluate_classifier(
y_true: np.ndarray,
y_pred: np.ndarray,
) -> Tuple[float, float, str, np.ndarray, list]:
"""
Compute classification metrics and confusion matrix.
Metrics Computed:
- Accuracy: Overall fraction of correct predictions
- Macro F1: Average F1 across all classes
- Per-class report: Precision, recall, F1 for each user
- Confusion matrix: Detailed breakdown of predictions vs ground truth
Args:
y_true: Ground truth labels
y_pred: Predicted labels
Returns:
Tuple of (accuracy, f1, classification_report, confusion_matrix, class_labels)
"""
# Compute scalar metrics
accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='macro') # Multiclass: use macro-averaged F1
# Generate detailed per-class report
report = classification_report(y_true, y_pred, digits=4)
# Compute confusion matrix: Get all unique classes from both true and predicted labels
class_labels = sorted(set(y_true) | set(y_pred))
cm = confusion_matrix(y_true, y_pred, labels=class_labels)
return accuracy, f1, report, cm, class_labels
# =============================================================================
# Visualization
# =============================================================================
def save_confusion_matrix_plot(
confusion_mat: np.ndarray,
class_labels: list,
output_path: str,
) -> None:
"""
Create and save a confusion matrix heatmap visualization.
The confusion matrix shows:
- Rows: True class labels
- Columns: Predicted class labels
- Cell values: Count of samples with that (true, predicted) combination
- Diagonal: Correct predictions
- Off-diagonal: Misclassifications
Args:
confusion_mat: Square matrix of shape (num_classes, num_classes)
class_labels: List of class names for axis labels
output_path: File path to save the plot (PDF format recommended)
"""
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Plot heatmap using imshow
im = ax.imshow(confusion_mat, aspect="auto", cmap="Blues")
# Add colorbar to show value scale
plt.colorbar(im, ax=ax)
# Set labels and title
ax.set_title("Confusion Matrix (User Classification)")
ax.set_xlabel("Predicted User")
ax.set_ylabel("True User")
plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.close()
print(f"[DONE] Saved confusion matrix: {output_path}")
def save_metrics_report(
output_path: str,
split_description: str,
train_size: int,
test_size: int,
accuracy: float,
macro_f1: float,
classification_report_str: str,
oob_score: float = None,
silhouette: float = None,
) -> None:
with open(output_path, "w", encoding="utf-8") as f:
# Write summary statistics
f.write(f"split_used : {split_description}\n")
f.write(f"train_size : {train_size}\n")
f.write(f"test_size : {test_size}\n")
# Silhouette score (embedding quality metric)
if silhouette is not None:
f.write(f"silhouette : {silhouette:.4f}\n")
f.write(f"accuracy : {accuracy:.4f}\n")
f.write(f"macro_f1 : {macro_f1:.4f}\n")
# Add OOB score if available (Random Forest specific)
if oob_score is not None:
f.write(f"oob_score : {oob_score:.4f}\n")
f.write("\n")
# Write detailed per-class report
f.write(classification_report_str)
print(f"[DONE] Saved metrics report: {output_path}")
def save_feature_importance_plot(
feature_importances: np.ndarray,
output_path: str,
top_k: int = 50,
) -> None:
# Get indices of top-k most important features
top_indices = np.argsort(feature_importances)[::-1][:top_k]
top_importances = feature_importances[top_indices]
# Create figure
fig, ax = plt.subplots(figsize=(12, 8))
# Create horizontal bar chart
y_positions = np.arange(len(top_indices))
ax.barh(y_positions, top_importances, color="steelblue", alpha=0.8)
# Set labels
ax.set_yticks(y_positions)
ax.set_yticklabels([f"emb_{i}" for i in top_indices], fontsize=8)
ax.invert_yaxis() # Highest importance at top
ax.set_xlabel("Feature Importance (Mean Decrease in Impurity)")
ax.set_ylabel("Embedding Dimension")
ax.set_title(f"Top {top_k} Most Important Embedding Dimensions")
# Save as vector PDF
plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.close()
print(f"[DONE] Saved feature importance plot: {output_path}")
def save_feature_importance_csv(
feature_importances: np.ndarray,
output_path: str,
) -> None:
importance_df = pd.DataFrame({
"feature": [f"emb_{i}" for i in range(len(feature_importances))],
"importance": feature_importances,
})
importance_df = importance_df.sort_values("importance", ascending=False)
importance_df = importance_df.reset_index(drop=True)
importance_df["rank"] = range(1, len(importance_df) + 1)
importance_df = importance_df[["rank", "feature", "importance"]]
importance_df.to_csv(output_path, index=False)
print(f"[DONE] Saved feature importance CSV: {output_path}")
# =============================================================================
# Command Line Interface
# =============================================================================
def main(
embeddings: str,
out_dir: str,
split_mode: str = "session",
test_size: float = 0.2,
n_estimators: int = 200,
max_depth: int = None,
labels: str = None,
) -> None:
# Validate split_mode argument
if split_mode not in ["session", "random"]:
raise ValueError(f"Invalid split_mode: {split_mode}. Use 'session' or 'random'.")
# Create output directory
os.makedirs(out_dir, exist_ok=True)
print(f"[INFO] Loading embeddings from: {embeddings}")
dataset = load_embeddings_with_metadata(embeddings)
# Filter by sleep stage labels if specified
if labels:
label_list = [l.strip() for l in labels.split(",")]
dataset = dataset.filter(lambda x: str(x["label"]) in label_list)
print(f"[INFO] Filtered to labels: {label_list}")
X = np.array(dataset["embedding"], dtype=np.float32)
y = np.array([str(uid) for uid in dataset["user_id"]])
session_ids = np.array([str(sid) for sid in dataset["session_id"]])
print(f"[INFO] Loaded {len(y)} samples with {X.shape[1]}-dimensional embeddings")
print(f"[INFO] Number of unique users: {len(np.unique(y))}")
if split_mode == "session":
has_session_1 = (session_ids == "1").sum() > 0
has_session_2 = (session_ids == "2").sum() > 0
if has_session_1 and has_session_2:
X_train, X_test, y_train, y_test, split_desc = split_by_session(
X, y, session_ids
)
else:
print("[WARN] Missing session data, falling back to random split")
X_train, X_test, y_train, y_test, split_desc = split_random(
X, y, test_size
)
split_desc = "random(fallback)"
else:
X_train, X_test, y_train, y_test, split_desc = split_random(
X, y, test_size
)
print(f"[INFO] Split: {split_desc}")
print(f"[INFO] Train size: {len(y_train)}, Test size: {len(y_test)}")
silhouette_avg = silhouette_score(X, y)
print(f"[INFO] Silhouette Score (user clusters): {silhouette_avg:.4f}")
print("[INFO] Training Random Forest classifier...")
print(f"[INFO] Hyperparameters: n_estimators={n_estimators}, max_depth={max_depth}")
classifier = create_classifier_pipeline(
n_estimators=n_estimators,
max_depth=max_depth,
)
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)
rf_model = classifier.named_steps["classifier"]
oob_score = rf_model.oob_score_ if hasattr(rf_model, "oob_score_") else None
feature_importances = rf_model.feature_importances_
print("[INFO] Evaluating classifier performance...")
accuracy, macro_f1, report, cm, classes = evaluate_classifier(y_test, y_pred)
metrics_path = os.path.join(out_dir, "user_cls_metrics.txt")
save_metrics_report(
output_path=metrics_path,
split_description=split_desc,
train_size=len(y_train),
test_size=len(y_test),
accuracy=accuracy,
macro_f1=macro_f1,
classification_report_str=report,
oob_score=oob_score,
silhouette=silhouette_avg,
)
confusion_path = os.path.join(out_dir, "user_cls_confusion.pdf")
save_confusion_matrix_plot(cm, classes, confusion_path)
importance_plot_path = os.path.join(out_dir, "feature_importance.pdf")
save_feature_importance_plot(feature_importances, importance_plot_path)
importance_csv_path = os.path.join(out_dir, "feature_importance.csv")
save_feature_importance_csv(feature_importances, importance_csv_path)
print("\n" + "=" * 50)
print("RANDOM FOREST CLASSIFICATION RESULTS")
print("=" * 50)
print(f"Split Strategy : {split_desc}")
print(f"Silhouette : {silhouette_avg:.4f}")
print(f"Accuracy : {accuracy:.4f}")
print(f"Macro F1 : {macro_f1:.4f}")
if oob_score is not None:
print(f"OOB Score : {oob_score:.4f}")
print(f"Top 5 Important Features:")
top5_idx = np.argsort(feature_importances)[::-1][:5]
for rank, idx in enumerate(top5_idx, 1):
print(f" {rank}. emb_{idx}: {feature_importances[idx]:.4f}")
print("=" * 50)
if __name__ == "__main__":
Fire(main)

View File

@@ -1,504 +0,0 @@
"""
SBERT Time Series Embedding Extraction and Visualization Pipeline
This module provides functionality to
1. Extract embeddings from time series features using SBERT (Sentence-BERT)
2. Visualize embeddings using dimensionality reduction (t-SNE)
SBERT Overview:
SBERT (Sentence-BERT) is a modification of the BERT model that uses siamese
and triplet network structures to derive semantically meaningful sentence
embeddings. It's designed for semantic similarity tasks and produces
fixed-size dense vector representations of text.
Key Features:
- Semantic understanding: Captures meaning rather than just word presence
- Fixed-size embeddings: Outputs consistent vector dimensions (384 for all-MiniLM-L6-v2)
- Efficient: Optimized for sentence-level tasks
Embedding Strategy:
We convert time series features into textual descriptions, then use SBERT
to generate embeddings. This approach treats feature vectors as "sentences"
where each feature-value pair is a "word" in the description.
Processing Pipeline:
1. Feature extraction from time series (statistical features)
2. Textualization: Convert features to natural language description
3. SBERT encoding: Generate 384-dimensional embeddings
4. Aggregation: Store embeddings with metadata (user_id, session_id, label)
Usage:
# Extract embeddings
python gen_plot.py extract --data_root /path/to/data --out_dir ./embeddings
# Visualize with t-SNE
python gen_plot.py plot --emb_dir ./embeddings --out_dir ./plots
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from glob import glob
from typing import Dict, Any, List, Tuple
import numpy as np
import matplotlib.pyplot as plt
import torch
from datasets import load_from_disk, Dataset
from fire import Fire
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
# =============================================================================
# Constants
# =============================================================================
# EEG channel names used in Sleep-EDF dataset
# Fpz-Cz: Frontal-Central electrode pair
# Pz-Oz: Parietal-Occipital electrode pair
EEG_CHANNEL_1 = "EEG Fpz-Cz"
EEG_CHANNEL_2 = "EEG Pz-Oz"
# =============================================================================
# Embedding Extractor Class
# =============================================================================
class SBERT:
"""
Extracts fixed-dimensional embeddings from time series features using SBERT.
Uses Sentence-BERT to convert textualized feature descriptions into dense
vector representations. The textualization process converts statistical
features into natural language, which SBERT then encodes semantically.
Architecture:
Time Series → Feature Extraction → Textualization → SBERT Encoder → Embedding
Processing Pipeline:
1. Extract statistical features from time series (mean, std, etc.)
2. Textualize: Convert features to natural language description
3. SBERT encoding: Generate 384-dimensional semantic embeddings
4. Output: Fixed-size embedding vector per sample
Model: all-MiniLM-L6-v2
- Lightweight BERT variant optimized for sentence embeddings
- 384-dimensional output embeddings
- Fast inference with good semantic understanding
Attributes:
model: SentenceTransformer instance for encoding textualized features
"""
def __init__(self):
"""
Initialize the SBERT embedder with pre-trained model.
Uses "all-MiniLM-L6-v2" which is a lightweight, fast model
optimized for sentence similarity tasks.
"""
self.model = SentenceTransformer("all-MiniLM-L6-v2")
@staticmethod
def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
"""
Discover all user/session directories under data_root.
Uses glob pattern matching for cleaner directory traversal.
Expected structure: data_root/user_id/session_id/
Args:
data_root: Root directory containing user/session subfolders
Returns:
List of (user_id, session_id, session_path) tuples
"""
discovered_paths = []
# Use glob to find all session directories (2 levels deep)
for session_path in sorted(glob(os.path.join(data_root, "*", "*"))):
if not os.path.isdir(session_path):
continue
# Extract user_id and session_id from path
session_id = os.path.basename(session_path)
user_id = os.path.basename(os.path.dirname(session_path))
discovered_paths.append((user_id, session_id, session_path))
return discovered_paths
def textualize_sample(self, sample: Dict[str, Any]) -> str:
"""
Convert a feature dictionary into a natural language description.
This textualization step is crucial for SBERT, which expects text input.
The description provides context about the sleep stage and lists all
extracted features in a structured format.
Args:
sample: Dictionary mapping feature names to their values
Example: {"mean": 0.5, "std": 0.2, "max": 1.0, ...}
Returns:
Natural language string describing the features
"""
sentence = (
"While sleeping (one of the stages: W, N1, N2, N3, REM), "
"I have sensor data features measured from two channels: EEG Fpz-Cz and EEG Pz-Oz.\n"
"The features are as follows: \n"
)
for k, v in sample.items():
sentence += f" - {k}: {self.format_feature(v)}\n"
sentence.strip()
return sentence
def format_feature(self, value: Any) -> str:
"""
Format a feature value for textual representation.
Floats are rounded to 2 decimal places for readability,
other types are converted to strings.
Args:
value: Feature value (float, int, or other type)
Returns:
Formatted string representation
"""
if isinstance(value, float):
return f"{value:.2f}"
return str(value)
def compute_embedding(self, batch: Dict[str, Any]) -> np.ndarray:
"""
Generate embedding vectors from feature batches using SBERT.
Processing Pipeline:
1. Extract feature dictionaries from batch
2. Textualize each sample's features into natural language
3. Encode textual descriptions using SBERT
4. Return fixed-size embedding vectors
Args:
batch: HuggingFace dataset batch from slicing (dataset[start:end])
Format: {"features": [{"mean": 0.5, "std": 0.2, ...}, ...]}
Returns:
Embedding array of shape (batch_size, embedding_dim)
For all-MiniLM-L6-v2: (batch_size, 384)
"""
# Extract feature dictionaries from batch
samples = batch["features"]
# Convert each sample's features to text
text_samples = []
for sample in samples:
text_samples.append(self.textualize_sample(sample))
# Encode text descriptions using SBERT
# Returns numpy array of shape (batch_size, 384)
embeddings = self.model.encode(text_samples)
return embeddings
def extract_embeddings(
self,
data_root: str,
batch_size: int = 32,
label: str = "N1"
) -> Dataset:
"""
Extract embeddings from all sessions under the data root directory.
Iterates through all user/session combinations, processes features
in batches, and aggregates results with metadata. Filters by sleep
stage label to focus on specific sleep stages.
Args:
data_root: Root directory containing user/session data folders
batch_size: Number of samples to process together.
Larger = faster but more memory.
32 is a good balance for most systems.
label: Sleep stage label to filter (e.g., "W", "N1", "N2", "N3", "REM")
Returns:
HuggingFace Dataset with columns:
- user_id, session_id, idx, label (metadata)
- embedding (384-dim vector from all-MiniLM-L6-v2)
"""
session_paths = self.discover_session_paths(data_root)
print(f"[INFO] Discovered {len(session_paths)} sessions")
all_embeddings = []
all_user_ids = []
all_session_ids = []
all_idxs = []
all_labels = []
for user_id, session_id, session_path in session_paths:
# Load HuggingFace dataset from disk
dataset = load_from_disk(session_path)
# Shuffle dataset for randomness
dataset = dataset.shuffle(seed=0)
# Filter by sleep stage label
dataset = dataset.filter(lambda x: x["label"] == label)
num_samples = len(dataset)
print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")
# Process in batches to manage memory
for batch_start in range(0, num_samples, batch_size):
batch_end = min(batch_start + batch_size, num_samples)
# Slice dataset to get batch
batch = dataset[batch_start:batch_end]
# Compute embeddings
embeddings = self.compute_embedding(batch)
# Collect embeddings and metadata
for i in range(embeddings.shape[0]):
all_embeddings.append(embeddings[i].tolist())
all_user_ids.append(str(batch["user_id"][i]))
all_session_ids.append(str(batch["session_id"][i]))
all_idxs.append(int(batch["idx"][i]))
all_labels.append(str(batch["label"][i]))
# Create HuggingFace Dataset
result_dataset = Dataset.from_dict({
"user_id": all_user_ids,
"session_id": all_session_ids,
"idx": all_idxs,
"label": all_labels,
"embedding": all_embeddings,
})
print(f"[INFO] Total samples: {len(result_dataset)}")
return result_dataset
def save_embeddings(
self,
dataset: Dataset,
output_dir: str
) -> None:
"""
Save embeddings dataset to disk in HuggingFace format.
Args:
dataset: HuggingFace Dataset containing embeddings and metadata
output_dir: Directory path to save the dataset
"""
os.makedirs(output_dir, exist_ok=True)
dataset.save_to_disk(output_dir)
print(f"[DONE] Saved embeddings dataset: {output_dir}")
print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
@staticmethod
def load_embeddings(embedding_dir: str) -> Dataset:
"""
Load saved embeddings dataset from disk.
Args:
embedding_dir: Directory path containing the saved HuggingFace dataset
Returns:
HuggingFace Dataset with embeddings and metadata
"""
dataset = load_from_disk(embedding_dir)
print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
return dataset
# =============================================================================
# Visualization Functions
# =============================================================================
def reduce_to_2d_tsne(
embeddings: np.ndarray,
perplexity: float = 30.0
) -> np.ndarray:
"""
Reduce high-dimensional embeddings to 2D using t-SNE.
t-SNE (t-distributed Stochastic Neighbor Embedding) is a nonlinear
dimensionality reduction technique that preserves local structure.
Points that are similar in high dimensions stay close in 2D.
Args:
embeddings: High-dimensional array of shape (num_samples, embedding_dim)
perplexity: t-SNE perplexity parameter (typically 5-50).
Higher values consider more neighbors, creating smoother layouts.
Rule of thumb: perplexity ~ sqrt(num_samples)
Returns:
2D coordinates of shape (num_samples, 2)
"""
print(f"[INFO] Running t-SNE with perplexity={perplexity}...")
tsne = TSNE(
n_components=2,
random_state=0, # For reproducibility
perplexity=perplexity,
max_iter=1000, # Usually sufficient for convergence
init='random',
learning_rate='auto', # Let sklearn choose optimal learning rate
)
return tsne.fit_transform(embeddings)
def create_scatter_plot(
coordinates: np.ndarray,
labels: np.ndarray,
title: str,
output_path: str
) -> None:
"""
Create and save a 2D scatter plot with categorical coloring.
Args:
coordinates: 2D array of shape (num_points, 2)
labels: Category labels for each point (string array)
title: Plot title
output_path: File path to save the plot (PDF vector format recommended)
"""
# Get unique labels for legend
unique_labels = sorted(set(labels))
# Select colormap based on number of categories
# tab10: 10 distinct colors, tab20: 20 distinct colors
colormap = plt.cm.tab10 if len(unique_labels) <= 10 else plt.cm.tab20
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Plot each category separately for proper legend
for idx, label in enumerate(unique_labels):
mask = labels == label
ax.scatter(
coordinates[mask, 0],
coordinates[mask, 1],
c=[colormap(idx % 20)],
s=15,
label=label,
alpha=0.7,
)
# Configure plot appearance
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
ax.legend(loc='best', markerscale=2)
# Save figure as vector PDF (scalable, ideal for publications)
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
# =============================================================================
# Command Line Interface
# =============================================================================
class CLI:
"""
Command-line interface for SBERT embedding extraction and visualization.
Provides two main commands:
- extract: Generate embeddings from time series features
- plot: Visualize embeddings with t-SNE
"""
def extract(
self,
data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF",
out_dir: str = "./embeddings/REM",
batch_size: int = 32,
label: str = "REM"
) -> None:
"""
Extract embeddings from time series data.
Args:
data_root: Root directory containing user/session data folders
out_dir: Output directory for HuggingFace dataset
batch_size: Batch size for inference (default: 32)
label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM')
"""
embedder = SBERT()
dataset = embedder.extract_embeddings(data_root, batch_size, label)
embedder.save_embeddings(dataset, out_dir)
def plot(
self,
emb_dir: str = "./embeddings/W",
out_dir: str = "./plots/W",
perplexity: float = 10.0,
users: str = None,
num_users: int = 0,
labels: str = None,
) -> None:
"""
Visualize embeddings with t-SNE.
Args:
emb_dir: Directory containing the HuggingFace embeddings dataset
out_dir: Output directory for visualization plots (PDF)
perplexity: t-SNE perplexity parameter (default: 10.0)
users: Comma-separated user IDs to include (e.g., '00,01,02')
num_users: Include only first N users, 0 = all (default: 0)
labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
"""
os.makedirs(out_dir, exist_ok=True)
# Load saved embeddings dataset
dataset = SBERT.load_embeddings(emb_dir)
# Apply user filtering
if users:
user_list = [u.strip() for u in users.split(",")]
dataset = dataset.filter(lambda x: x["user_id"] in user_list)
print(f"[INFO] Filtered to users: {user_list}")
elif num_users > 0:
all_users = sorted(set(dataset["user_id"]))
selected_users = all_users[:num_users]
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
print(f"[INFO] Selected first {num_users} users: {selected_users}")
# Filter by sleep stage labels
if labels:
label_list = [l.strip() for l in labels.split(",")]
dataset = dataset.filter(lambda x: x["label"] in label_list)
print(f"[INFO] Filtered to labels: {label_list}")
print(f"[INFO] Total samples: {len(dataset)}")
# Extract embeddings as numpy array for t-SNE
embeddings = np.array(dataset["embedding"])
# Reduce to 2D with t-SNE
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
# Generate visualizations
create_scatter_plot(
coordinates_2d,
np.array(dataset["label"]),
"t-SNE Visualization (Colored by Sleep Stage)",
os.path.join(out_dir, "tsne_by_label.pdf")
)
create_scatter_plot(
coordinates_2d,
np.array(dataset["user_id"]),
"t-SNE Visualization (Colored by User ID)",
os.path.join(out_dir, "tsne_by_user.pdf")
)
if __name__ == "__main__":
Fire(CLI)

View File

@@ -1,529 +0,0 @@
"""
SBERT Embedding Extraction and Visualization with Metadata (Age and Sex)
This module provides functionality to:
1. Extract SBERT embeddings from time series features
2. Visualize embeddings colored by subject metadata (age and sex) instead of user IDs
Features:
1. Extract embeddings from time series features using SBERT
2. Load embeddings from HuggingFace dataset
3. Load subject metadata from XLS file
4. Map user IDs to age and sex information
5. Visualize embeddings with t-SNE colored by age (continuous) and sex (categorical)
Usage:
# Extract embeddings for all labels
python gen_plot_metadata.py extract --data_root /path/to/data --out_dir ./embeddings/all_labels
# Extract embeddings for a single label
python gen_plot_metadata.py extract --data_root /path/to/data --out_dir ./embeddings/W --label W
# Visualize by age from single label directory
python gen_plot_metadata.py plot --emb_dir ./embeddings/W --out_dir ./plots/W --color_by age
# Visualize by age from all label directories (W, REM, N1, N2, N3)
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by age
# Visualize by sex from all labels
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by sex
# Create both age and sex plots from all labels
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by both
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from glob import glob
from typing import Dict, Any, List, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_from_disk, Dataset, concatenate_datasets
from fire import Fire
from sklearn.manifold import TSNE
from gen_plot import SBERT, reduce_to_2d_tsne
# =============================================================================
# Constants
# =============================================================================
SUBJECT_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/SC-subjects.xls"
# =============================================================================
# Metadata Loading
# =============================================================================
def load_subject_metadata(subject_path: str = SUBJECT_PATH) -> Dict[str, Dict[str, Any]]:
"""
Load subject metadata from XLS file.
The XLS file contains subject information including:
- subject: Subject ID (e.g., "SC4001", "SC4002")
- age: Age of the subject
- sex (F=1): Sex (1 = Female, 0 = Male)
Args:
subject_path: Path to the SC-subjects.xls file
Returns:
Dictionary mapping subject IDs to metadata dictionaries
Format: {"SC4001": {"age": 25, "sex": 1}, ...}
"""
df = pd.read_excel(subject_path)
subject_info = {}
for index, row in df.iterrows():
subject_id = str(row["subject"]).strip()
subject_info[subject_id] = {
"age": int(row["age"]) if pd.notna(row["age"]) else None,
"sex": int(row["sex (F=1)"]) if pd.notna(row["sex (F=1)"]) else None,
}
return subject_info
def map_user_ids_to_metadata(
user_ids: np.ndarray,
subject_metadata: Dict[str, Dict[str, Any]]
) -> tuple:
"""
Map user IDs from dataset to age and sex metadata.
User IDs in the dataset are typically 2-digit codes (e.g., "40", "41").
Subject IDs in the metadata file are typically 4-character codes (e.g., "SC40", "SC41").
We need to match them appropriately.
Args:
user_ids: Array of user IDs from the dataset
subject_metadata: Dictionary of subject metadata loaded from XLS file
Returns:
Tuple of (ages, sexes) as numpy arrays
Missing values are set to None
"""
ages = []
sexes = []
for user_id in user_ids:
age = subject_metadata[user_id]["age"]
sex = subject_metadata[user_id]["sex"]
ages.append(age)
sexes.append(sex)
return np.array(ages), np.array(sexes)
# =============================================================================
# Visualization Functions
# =============================================================================
def create_scatter_plot_by_age(
coordinates: np.ndarray,
ages: np.ndarray,
title: str,
output_path: str
) -> None:
"""
Create and save a 2D scatter plot colored by age (continuous colormap).
Uses a continuous colormap (viridis) to show age distribution.
Ages are mapped to colors on a gradient scale.
Args:
coordinates: 2D array of shape (num_points, 2)
ages: Age values for each point (array, may contain None values)
title: Plot title
output_path: File path to save the plot (PDF vector format recommended)
"""
# Filter out points with missing age data
# Convert None values to NaN for proper numpy handling
ages_float = np.array([float(a) if a is not None else np.nan for a in ages])
valid_mask = ~np.isnan(ages_float)
valid_coords = coordinates[valid_mask]
valid_ages = ages_float[valid_mask]
if len(valid_ages) == 0:
print(f"[WARN] No valid age data found. Skipping plot: {output_path}")
return
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Create scatter plot with continuous colormap
scatter = ax.scatter(
valid_coords[:, 0],
valid_coords[:, 1],
c=valid_ages,
cmap='viridis', # Continuous colormap for granular age visualization
s=15,
alpha=0.7,
edgecolors='none',
)
# Add colorbar
cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label('Age (years)', rotation=270, labelpad=20)
# Configure plot appearance
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
# Save figure as vector PDF
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
print(f"[INFO] Age range: {valid_ages.min():.0f} - {valid_ages.max():.0f} years")
print(f"[INFO] Points with valid age: {len(valid_ages)}/{len(ages)}")
def create_scatter_plot_by_sex(
coordinates: np.ndarray,
sexes: np.ndarray,
title: str,
output_path: str
) -> None:
"""
Create and save a 2D scatter plot colored by sex (categorical).
Uses discrete colors for different sex categories.
Sex encoding: 1 = Female, 0 = Male
Args:
coordinates: 2D array of shape (num_points, 2)
sexes: Sex values for each point (array, may contain None values)
title: Plot title
output_path: File path to save the plot (PDF vector format recommended)
"""
# Filter out points with missing sex data
# Convert None values to NaN for proper numpy handling
sexes_float = np.array([float(s) if s is not None else np.nan for s in sexes])
valid_mask = ~np.isnan(sexes_float)
valid_coords = coordinates[valid_mask]
valid_sexes = sexes_float[valid_mask].astype(int)
if len(valid_sexes) == 0:
print(f"[WARN] No valid sex data found. Skipping plot: {output_path}")
return
# Map sex codes to labels
sex_labels = {0: "Male", 1: "Female"}
unique_sexes = sorted(set(valid_sexes))
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Plot each category separately for proper legend
colors = ['steelblue', 'coral'] # Blue for Male, Coral for Female
for idx, sex_code in enumerate(unique_sexes):
mask = valid_sexes == sex_code
ax.scatter(
valid_coords[mask, 0],
valid_coords[mask, 1],
c=colors[sex_code % len(colors)],
s=15,
label=sex_labels.get(sex_code, f"Sex {sex_code}"),
alpha=0.7,
)
# Configure plot appearance
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
ax.legend(loc='best', markerscale=2)
# Save figure as vector PDF
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
sex_counts = {sex_labels.get(s, f"Sex {s}"): (valid_sexes == s).sum() for s in unique_sexes}
print(f"[INFO] Sex distribution: {sex_counts}")
print(f"[INFO] Points with valid sex: {len(valid_sexes)}/{len(sexes)}")
# =============================================================================
# Embedding Extraction Utilities
# =============================================================================
def extract_embeddings_for_all_labels(
data_root: str,
out_dir: str,
batch_size: int = 32
) -> Dataset:
"""
Extract embeddings for all sleep stage labels and combine them.
Extracts embeddings for each label (W, REM, N1, N2, N3) separately,
then concatenates them into a single dataset.
Args:
data_root: Root directory containing user/session data folders
out_dir: Output directory for the combined HuggingFace dataset
batch_size: Batch size for inference (default: 32)
Returns:
Combined HuggingFace Dataset with all labels
"""
embedder = SBERT()
all_labels = ["W", "REM", "N1", "N2", "N3"]
datasets = []
for label in all_labels:
print(f"\n[INFO] Extracting embeddings for label: {label}")
dataset = embedder.extract_embeddings(data_root, batch_size, label)
print(f"[INFO] Extracted {len(dataset)} samples for label {label}")
datasets.append(dataset)
# Concatenate all datasets
if len(datasets) == 1:
combined_dataset = datasets[0]
else:
combined_dataset = concatenate_datasets(datasets)
print(f"\n[INFO] Combined dataset: {len(combined_dataset)} total samples")
# Print label distribution
if "label" in combined_dataset.column_names:
label_counts = {}
for label in combined_dataset["label"]:
label_counts[label] = label_counts.get(label, 0) + 1
print(f"[INFO] Label distribution: {label_counts}")
# Save combined dataset
embedder.save_embeddings(combined_dataset, out_dir)
return combined_dataset
# =============================================================================
# Data Loading Utilities
# =============================================================================
def load_embeddings_from_all_labels(embeddings_root: str) -> Dataset:
"""
Load embeddings from all label subdirectories and concatenate them.
Discovers all subdirectories in embeddings_root (e.g., W, REM, N1, N2, N3)
and loads embeddings from each, then concatenates them into a single dataset.
Args:
embeddings_root: Root directory containing label subdirectories
(e.g., "./embeddings" containing W/, REM/, N1/, etc.)
Returns:
Concatenated HuggingFace Dataset with all labels combined
"""
# Discover all label subdirectories
label_dirs = []
for item in os.listdir(embeddings_root):
item_path = os.path.join(embeddings_root, item)
if os.path.isdir(item_path):
# Check if it's a valid HuggingFace dataset directory
if os.path.exists(os.path.join(item_path, "dataset_info.json")):
label_dirs.append((item, item_path))
if len(label_dirs) == 0:
raise ValueError(
f"No valid HuggingFace dataset directories found in: {embeddings_root}"
)
label_dirs.sort() # Sort for consistent ordering
print(f"[INFO] Discovered {len(label_dirs)} label directories: {[ld[0] for ld in label_dirs]}")
# Load datasets from each label directory
datasets = []
for label_name, label_path in label_dirs:
print(f"[INFO] Loading embeddings from: {label_path}")
dataset = load_from_disk(label_path)
print(f"[INFO] Label: {label_name}, Samples: {len(dataset)}")
datasets.append(dataset)
# Concatenate all datasets
if len(datasets) == 1:
combined_dataset = datasets[0]
else:
combined_dataset = concatenate_datasets(datasets)
print(f"[INFO] Combined dataset: {len(combined_dataset)} total samples")
# Print label distribution
if "label" in combined_dataset.column_names:
label_counts = {}
for label in combined_dataset["label"]:
label_counts[label] = label_counts.get(label, 0) + 1
print(f"[INFO] Label distribution: {label_counts}")
return combined_dataset
# =============================================================================
# Command Line Interface
# =============================================================================
class CLI:
"""
Command-line interface for SBERT embedding extraction and visualization with metadata.
Provides:
- extract: Generate embeddings from time series features
- plot: Visualize embeddings colored by age or sex instead of user ID
"""
def extract(
self,
data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_full",
out_dir: str = "./embeddings/all_labels",
batch_size: int = 32,
label: Optional[str] = None
) -> None:
"""
Extract embeddings from time series features using SBERT.
Can extract for a single label or all labels at once.
Args:
data_root: Root directory containing user/session data folders
out_dir: Output directory for HuggingFace dataset
batch_size: Batch size for inference (default: 32)
label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM').
If None or 'all', extracts for all labels (default: None for all labels)
"""
embedder = SBERT()
if label is None or label == "all":
# Extract for all labels
print(f"[INFO] Extracting embeddings for all labels")
extract_embeddings_for_all_labels(data_root, out_dir, batch_size)
else:
# Extract for single label
print(f"[INFO] Extracting embeddings for label: {label}")
dataset = embedder.extract_embeddings(data_root, batch_size, label)
embedder.save_embeddings(dataset, out_dir)
def plot(
self,
emb_dir: str = None,
embeddings_root: str = "./embeddings",
out_dir: str = "./plots/all_labels",
subject_path: str = SUBJECT_PATH,
perplexity: float = 10.0,
color_by: str = "age",
users: str = None,
num_users: int = 0,
labels: str = None,
) -> None:
"""
Visualize embeddings with t-SNE, colored by age or sex.
Can load from either a single label directory or all label directories.
Args:
emb_dir: Single directory containing the HuggingFace embeddings dataset
(e.g., "./embeddings/W"). If provided, only this directory is used.
embeddings_root: Root directory containing label subdirectories
(e.g., "./embeddings" containing W/, REM/, N1/, etc.).
Used only if emb_dir is not provided.
out_dir: Output directory for visualization plots (PDF)
subject_path: Path to SC-subjects.xls file with metadata
perplexity: t-SNE perplexity parameter (default: 10.0)
color_by: What to color by - 'age', 'sex', or 'both' (default: 'age')
users: Comma-separated user IDs to include (e.g., '00,01,02')
num_users: Include only first N users, 0 = all (default: 0)
labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
This filters the already-loaded data, not which directories to load.
"""
# Validate color_by argument
if color_by not in ["age", "sex", "both", "all"]:
raise ValueError(f"Invalid color_by: {color_by}. Use 'age', 'sex', or 'both'.")
os.makedirs(out_dir, exist_ok=True)
# Load embeddings: either from single directory or all label directories
if emb_dir is not None:
# Load from single directory
print(f"[INFO] Loading embeddings from single directory: {emb_dir}")
dataset = SBERT.load_embeddings(emb_dir)
else:
# Load from all label directories
print(f"[INFO] Loading embeddings from all label directories in: {embeddings_root}")
dataset = load_embeddings_from_all_labels(embeddings_root)
# Apply user filtering
if users:
user_list = [u.strip() for u in users.split(",")]
dataset = dataset.filter(lambda x: x["user_id"] in user_list)
print(f"[INFO] Filtered to users: {user_list}")
elif num_users > 0:
all_users = sorted(set(dataset["user_id"]))
selected_users = all_users[:num_users]
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
print(f"[INFO] Selected first {num_users} users: {selected_users}")
# Filter by sleep stage labels
if labels:
label_list = [l.strip() for l in labels.split(",")]
dataset = dataset.filter(lambda x: x["label"] in label_list)
print(f"[INFO] Filtered to labels: {label_list}")
print(f"[INFO] Total samples: {len(dataset)}")
# Extract embeddings as numpy array for t-SNE
embeddings = np.array(dataset["embedding"])
# Load subject metadata
print(f"[INFO] Loading subject metadata from: {subject_path}")
subject_metadata = load_subject_metadata(subject_path)
print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")
# Map user IDs to metadata
user_ids = np.array([str(uid) for uid in dataset["user_id"]])
user_ids_ = [str(int(uid)) for uid in user_ids]
ages, sexes = map_user_ids_to_metadata(user_ids_, subject_metadata)
# Reduce to 2D with t-SNE
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
# Generate visualizations based on color_by parameter
if color_by == "age":
create_scatter_plot_by_age(
coordinates_2d,
ages,
"t-SNE Visualization (Colored by Age)",
os.path.join(out_dir, "tsne_by_age.pdf")
)
elif color_by == "sex":
create_scatter_plot_by_sex(
coordinates_2d,
sexes,
"t-SNE Visualization (Colored by Sex)",
os.path.join(out_dir, "tsne_by_sex.pdf")
)
elif color_by == "both" or color_by == "all":
# Create both plots
create_scatter_plot_by_age(
coordinates_2d,
ages,
"t-SNE Visualization (Colored by Age)",
os.path.join(out_dir, "tsne_by_age.pdf")
)
create_scatter_plot_by_sex(
coordinates_2d,
sexes,
"t-SNE Visualization (Colored by Sex)",
os.path.join(out_dir, "tsne_by_sex.pdf")
)
if __name__ == "__main__":
Fire(CLI)

View File

@@ -1,527 +0,0 @@
"""
User Classification from SBERT Embeddings
This module evaluates how well SBERT embeddings capture user-specific patterns
by training a simple linear classifier to predict user identity from embeddings.
Motivation:
If embeddings contain user-distinguishing information, a classifier should be able
to predict which user a time series belongs to. High accuracy suggests that the
embeddings capture individual characteristics.
Experimental Design:
- Task: Multi-class classification
- Model: Random Forest with ensemble of decision trees
- Captures non-linear relationships in embedding space
- Provides feature importance scores for interpretability
- Robust to overfitting through bagging and random feature selection
- Split Strategy:
1. Session-based: Train on session 1, test on session 2
2. Random: Standard train/test split
Session-based split is more challenging and realistic because:
- Tests whether user patterns are stable across different recording sessions
- Avoids data leakage from same-session samples in train and test
Random Forest Advantages over Logistic Regression:
- Handles non-linear decision boundaries
- Feature importance reveals which embedding dimensions matter most
- No assumption about data distribution
- Naturally handles multi-class classification
Output:
- Classification metrics
- Confusion matrix visualization
Usage:
python simple_user_classifier.py \\
--embeddings ./embeddings \\
--out_dir ./results
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from typing import Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_from_disk, Dataset
from fire import Fire
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
accuracy_score,
f1_score,
classification_report,
confusion_matrix,
silhouette_score,
)
from sklearn.model_selection import train_test_split
# =============================================================================
# Data Loading
# =============================================================================
def load_embeddings_with_metadata(embedding_path: str) -> Dataset:
if not os.path.isdir(embedding_path):
raise FileNotFoundError(
f"Dataset directory not found: {embedding_path}. "
"Ensure this is a valid HuggingFace dataset directory."
)
# Load HuggingFace dataset from disk
dataset = load_from_disk(embedding_path)
return dataset
# =============================================================================
# Data Splitting
# =============================================================================
def split_by_session(
features: np.ndarray,
labels: np.ndarray,
session_ids: np.ndarray,
train_session: str = "1",
test_session: str = "2",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
"""
Split data by recording session for temporal generalization evaluation.
This split strategy tests whether learned patterns generalize across time.
Training on session 1 and testing on session 2 simulates real-world deployment
where models must work on future recordings.
Args:
features: Feature matrix of shape (num_samples, num_features)
labels: Label array of shape (num_samples,)
session_ids: Session identifier for each sample
train_session: Session ID to use for training (default: "1")
test_session: Session ID to use for testing (default: "2")
Returns:
Tuple of (X_train, X_test, y_train, y_test, split_description)
"""
# Create boolean masks for train and test sets
train_mask = session_ids == train_session
test_mask = session_ids == test_session
# Apply masks to create train/test splits
X_train = features[train_mask]
X_test = features[test_mask]
y_train = labels[train_mask]
y_test = labels[test_mask]
split_description = f"session({train_session}->train, {test_session}->test)"
return X_train, X_test, y_train, y_test, split_description
def split_random(
features: np.ndarray,
labels: np.ndarray,
test_size: float = 0.2,
random_state: int = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
"""
Split data randomly with stratification for in-distribution evaluation.
Stratification ensures each class has proportional representation in
both train and test sets, preventing class imbalance issues.
Args:
features: Feature matrix of shape (num_samples, num_features)
labels: Label array of shape (num_samples,)
test_size: Fraction of data to use for testing (default: 0.2)
random_state: Random seed for reproducibility
Returns:
Tuple of (X_train, X_test, y_train, y_test, split_description)
"""
X_train, X_test, y_train, y_test = train_test_split(
features,
labels,
test_size=test_size,
random_state=random_state,
stratify=labels,
)
split_description = "random"
return X_train, X_test, y_train, y_test, split_description
# =============================================================================
# Model Training and Evaluation
# =============================================================================
def create_classifier_pipeline(
n_estimators: int = 200,
max_depth: int = None,
min_samples_split: int = 2,
min_samples_leaf: int = 1,
random_state: int = 0,
) -> Pipeline:
"""
Create a scikit-learn pipeline for user classification using Random Forest.
Pipeline Architecture:
----------------------
1. StandardScaler: Z-score normalization of features
- Centers features (mean=0) and scales to unit variance (std=1)
- While Random Forest is scale-invariant, scaling helps with
consistent feature importance interpretation
2. RandomForestClassifier: Ensemble of decision trees
- Builds multiple decision trees on random subsets of data (bagging)
- Each tree uses random subset of features at each split
- Final prediction is majority vote across all trees
- Provides feature_importances_ for interpretability
Random Forest Hyperparameters:
------------------------------
- n_estimators: Number of trees in the forest (more = better but slower)
- max_depth: Maximum tree depth (None = expand until pure leaves)
- min_samples_split: Minimum samples to split internal node
- min_samples_leaf: Minimum samples required at leaf node
Args:
n_estimators: Number of trees (default: 200)
max_depth: Maximum depth of trees (default: None, fully grown)
min_samples_split: Min samples for splitting (default: 2)
min_samples_leaf: Min samples at leaf (default: 1)
random_state: Random seed for reproducibility
Returns:
Configured sklearn Pipeline ready for .fit() and .predict()
"""
pipeline = Pipeline([
# Step 1: Feature normalization
("scaler", StandardScaler(
with_mean=True, # Subtract mean (center the data)
with_std=True, # Divide by standard deviation (scale to unit variance)
)),
# Step 2: Random Forest classification
("classifier", RandomForestClassifier(
n_estimators=n_estimators, # Number of trees in the forest
max_depth=max_depth, # Maximum depth (None = unlimited)
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
n_jobs=-1, # Use all CPU cores for parallel tree building
random_state=random_state, # For reproducibility
class_weight="balanced", # Handle class imbalance automatically
oob_score=True, # Enable out-of-bag error estimation
)),
])
return pipeline
def evaluate_classifier(
y_true: np.ndarray,
y_pred: np.ndarray,
) -> Tuple[float, float, str, np.ndarray, list]:
"""
Compute classification metrics and confusion matrix.
Metrics Computed:
- Accuracy: Overall fraction of correct predictions
- Macro F1: Average F1 across all classes
- Per-class report: Precision, recall, F1 for each user
- Confusion matrix: Detailed breakdown of predictions vs ground truth
Args:
y_true: Ground truth labels
y_pred: Predicted labels
Returns:
Tuple of (accuracy, f1, classification_report, confusion_matrix, class_labels)
"""
# Compute scalar metrics
accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='macro') # Multiclass: use macro-averaged F1
# Generate detailed per-class report
report = classification_report(y_true, y_pred, digits=4)
# Compute confusion matrix: Get all unique classes from both true and predicted labels
class_labels = sorted(set(y_true) | set(y_pred))
cm = confusion_matrix(y_true, y_pred, labels=class_labels)
return accuracy, f1, report, cm, class_labels
# =============================================================================
# Visualization
# =============================================================================
def save_confusion_matrix_plot(
confusion_mat: np.ndarray,
class_labels: list,
output_path: str,
) -> None:
"""
Create and save a confusion matrix heatmap visualization.
The confusion matrix shows:
- Rows: True class labels
- Columns: Predicted class labels
- Cell values: Count of samples with that (true, predicted) combination
- Diagonal: Correct predictions
- Off-diagonal: Misclassifications
Args:
confusion_mat: Square matrix of shape (num_classes, num_classes)
class_labels: List of class names for axis labels
output_path: File path to save the plot (PDF format recommended)
"""
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Plot heatmap using imshow
im = ax.imshow(confusion_mat, aspect="auto", cmap="Blues")
# Add colorbar to show value scale
plt.colorbar(im, ax=ax)
# Set labels and title
ax.set_title("Confusion Matrix (User Classification)")
ax.set_xlabel("Predicted User")
ax.set_ylabel("True User")
plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.close()
print(f"[DONE] Saved confusion matrix: {output_path}")
def save_metrics_report(
output_path: str,
split_description: str,
train_size: int,
test_size: int,
accuracy: float,
macro_f1: float,
classification_report_str: str,
oob_score: float = None,
silhouette: float = None,
) -> None:
"""
Save classification metrics to a text file.
Writes summary statistics (split strategy, sizes, scores) and detailed
per-class classification report to a text file.
Args:
output_path: File path to save the metrics report
split_description: Description of train/test split strategy
train_size: Number of training samples
test_size: Number of test samples
accuracy: Overall classification accuracy
macro_f1: Macro-averaged F1 score
classification_report_str: Detailed per-class report string
oob_score: Out-of-bag score from Random Forest (optional)
silhouette: Silhouette score for embedding quality (optional)
"""
with open(output_path, "w", encoding="utf-8") as f:
# Write summary statistics
f.write(f"split_used : {split_description}\n")
f.write(f"train_size : {train_size}\n")
f.write(f"test_size : {test_size}\n")
# Silhouette score (embedding quality metric)
if silhouette is not None:
f.write(f"silhouette : {silhouette:.4f}\n")
f.write(f"accuracy : {accuracy:.4f}\n")
f.write(f"macro_f1 : {macro_f1:.4f}\n")
# Add OOB score if available (Random Forest specific)
if oob_score is not None:
f.write(f"oob_score : {oob_score:.4f}\n")
f.write("\n")
# Write detailed per-class report
f.write(classification_report_str)
print(f"[DONE] Saved metrics report: {output_path}")
def save_feature_importance_plot(
feature_importances: np.ndarray,
output_path: str,
top_k: int = 50,
) -> None:
"""
Create and save a horizontal bar chart of top-k most important features.
Visualizes which embedding dimensions contribute most to user classification.
Higher importance indicates that dimension better distinguishes between users.
Args:
feature_importances: Array of feature importance scores from Random Forest
output_path: File path to save the plot (PDF format recommended)
top_k: Number of top features to display (default: 50)
"""
# Get indices of top-k most important features
top_indices = np.argsort(feature_importances)[::-1][:top_k]
top_importances = feature_importances[top_indices]
# Create figure
fig, ax = plt.subplots(figsize=(12, 8))
# Create horizontal bar chart
y_positions = np.arange(len(top_indices))
ax.barh(y_positions, top_importances, color="steelblue", alpha=0.8)
# Set labels
ax.set_yticks(y_positions)
ax.set_yticklabels([f"emb_{i}" for i in top_indices], fontsize=8)
ax.invert_yaxis() # Highest importance at top
ax.set_xlabel("Feature Importance (Mean Decrease in Impurity)")
ax.set_ylabel("Embedding Dimension")
ax.set_title(f"Top {top_k} Most Important Embedding Dimensions")
# Save as vector PDF
plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.close()
print(f"[DONE] Saved feature importance plot: {output_path}")
def save_feature_importance_csv(
feature_importances: np.ndarray,
output_path: str,
) -> None:
"""
Save feature importance scores to CSV file with ranking.
Creates a CSV with columns: rank, feature, importance
Features are sorted by importance in descending order.
Args:
feature_importances: Array of feature importance scores from Random Forest
output_path: File path to save the CSV file
"""
importance_df = pd.DataFrame({
"feature": [f"emb_{i}" for i in range(len(feature_importances))],
"importance": feature_importances,
})
importance_df = importance_df.sort_values("importance", ascending=False)
importance_df = importance_df.reset_index(drop=True)
importance_df["rank"] = range(1, len(importance_df) + 1)
importance_df = importance_df[["rank", "feature", "importance"]]
importance_df.to_csv(output_path, index=False)
print(f"[DONE] Saved feature importance CSV: {output_path}")
# =============================================================================
# Command Line Interface
# =============================================================================
def main(
embeddings: str = "./embeddings/REM",
out_dir: str = "./results/REM",
test_size: float = 0.2,
n_estimators: int = 200,
max_depth: int = None,
) -> None:
# Create output directory
os.makedirs(out_dir, exist_ok=True)
print(f"[INFO] Loading embeddings from: {embeddings}")
dataset = load_embeddings_with_metadata(embeddings)
X = np.array(dataset["embedding"], dtype=np.float32)
y = np.array([str(uid) for uid in dataset["user_id"]])
session_ids = np.array([str(sid) for sid in dataset["session_id"]])
print(f"[INFO] Loaded {len(y)} samples with {X.shape[1]}-dimensional embeddings")
print(f"[INFO] Number of unique users: {len(np.unique(y))}")
has_session_1 = (session_ids == "1").sum() > 0
has_session_2 = (session_ids == "2").sum() > 0
if has_session_1 and has_session_2:
X_train, X_test, y_train, y_test, split_desc = split_by_session(
X, y, session_ids
)
else:
print("[WARN] Missing session data, falling back to random split")
X_train, X_test, y_train, y_test, split_desc = split_random(
X, y, test_size
)
split_desc = "random(fallback)"
print(f"[INFO] Split: {split_desc}")
print(f"[INFO] Train size: {len(y_train)}, Test size: {len(y_test)}")
silhouette_avg = silhouette_score(X, y)
print(f"[INFO] Silhouette Score (user clusters): {silhouette_avg:.4f}")
print("[INFO] Training Random Forest classifier...")
print(f"[INFO] Hyperparameters: n_estimators={n_estimators}, max_depth={max_depth}")
classifier = create_classifier_pipeline(
n_estimators=n_estimators,
max_depth=max_depth,
)
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)
rf_model = classifier.named_steps["classifier"]
oob_score = rf_model.oob_score_ if hasattr(rf_model, "oob_score_") else None
feature_importances = rf_model.feature_importances_
print("[INFO] Evaluating classifier performance...")
accuracy, macro_f1, report, cm, classes = evaluate_classifier(y_test, y_pred)
metrics_path = os.path.join(out_dir, "user_cls_metrics.txt")
save_metrics_report(
output_path=metrics_path,
split_description=split_desc,
train_size=len(y_train),
test_size=len(y_test),
accuracy=accuracy,
macro_f1=macro_f1,
classification_report_str=report,
oob_score=oob_score,
silhouette=silhouette_avg,
)
confusion_path = os.path.join(out_dir, "user_cls_confusion.pdf")
save_confusion_matrix_plot(cm, classes, confusion_path)
importance_plot_path = os.path.join(out_dir, "feature_importance.pdf")
save_feature_importance_plot(feature_importances, importance_plot_path)
importance_csv_path = os.path.join(out_dir, "feature_importance.csv")
save_feature_importance_csv(feature_importances, importance_csv_path)
print("\n" + "=" * 50)
print("RANDOM FOREST CLASSIFICATION RESULTS")
print("=" * 50)
print(f"Split Strategy : {split_desc}")
print(f"Silhouette : {silhouette_avg:.4f}")
print(f"Accuracy : {accuracy:.4f}")
print(f"Macro F1 : {macro_f1:.4f}")
if oob_score is not None:
print(f"OOB Score : {oob_score:.4f}")
print(f"Top 5 Important Features:")
top5_idx = np.argsort(feature_importances)[::-1][:5]
for rank, idx in enumerate(top5_idx, 1):
print(f" {rank}. emb_{idx}: {feature_importances[idx]:.4f}")
print("=" * 50)
if __name__ == "__main__":
Fire(main)

View File

@@ -1,756 +0,0 @@
"""
SBERT Metadata-Based Embedding Extraction and Visualization Pipeline
This module provides functionality to
1. Extract embeddings from user metadata (age and sex) using SBERT (Sentence-BERT)
2. Visualize embeddings using dimensionality reduction (t-SNE) colored by age
SBERT Overview:
SBERT (Sentence-BERT) is a modification of the BERT model that uses siamese
and triplet network structures to derive semantically meaningful sentence
embeddings. It's designed for semantic similarity tasks and produces
fixed-size dense vector representations of text.
Key Features:
- Semantic understanding: Captures meaning rather than just word presence
- Fixed-size embeddings: Outputs consistent vector dimensions (384 for all-MiniLM-L6-v2)
- Efficient: Optimized for sentence-level tasks
Embedding Strategy:
Instead of using time series features, we create textual descriptions based on
user metadata (age and sex). This approach allows us to capture user-level
characteristics in the embedding space.
Processing Pipeline:
1. Load user metadata from XLS file (age, sex)
2. Textualization: Convert metadata to natural language description
3. SBERT encoding: Generate 384-dimensional embeddings
4. Visualization: t-SNE with continuous age coloring
Usage:
# Extract embeddings from metadata for all labels
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/all_labels
# Extract embeddings for a single label
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/REM --label REM
# Visualize with t-SNE from all label directories (colored by age)
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by age
# Visualize by sex from all labels
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by sex
# Create both age and sex plots from all labels
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by both
# Visualize from a single label directory
python gen_plot.py plot --emb_dir ./embeddings/REM --out_dir ./plots/REM --color_by age
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from glob import glob
from typing import Dict, Any, List, Tuple, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_from_disk, Dataset, concatenate_datasets
from fire import Fire
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
# =============================================================================
# Constants
# =============================================================================
SUBJECT_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/SC-subjects.xls"
# =============================================================================
# Metadata Loading
# =============================================================================
def load_subject_metadata(subject_path: str = SUBJECT_PATH) -> Dict[str, Dict[str, Any]]:
"""
Load subject metadata from XLS file.
The XLS file contains subject information including:
- subject: Subject ID (e.g., "SC4001", "SC4002")
- age: Age of the subject
- sex (F=1): Sex (1 = Female, 0 = Male)
Args:
subject_path: Path to the SC-subjects.xls file
Returns:
Dictionary mapping subject IDs to metadata dictionaries
Format: {"SC4001": {"age": 25, "sex": 1}, ...}
"""
df = pd.read_excel(subject_path)
subject_info = {}
for index, row in df.iterrows():
subject_id = str(row["subject"]).strip()
subject_info[subject_id] = {
"age": int(row["age"]) if pd.notna(row["age"]) else None,
"sex": int(row["sex (F=1)"]) if pd.notna(row["sex (F=1)"]) else None,
}
return subject_info
# =============================================================================
# Embedding Extractor Class
# =============================================================================
class SBERT_Metadata:
"""
Extracts fixed-dimensional embeddings from user metadata using SBERT.
Uses Sentence-BERT to convert textualized metadata descriptions into dense
vector representations. The textualization process converts age and sex
information into natural language, which SBERT then encodes semantically.
Architecture:
User Metadata → Textualization → SBERT Encoder → Embedding
Processing Pipeline:
1. Load user metadata (age, sex) from XLS file
2. Textualize: Convert metadata to natural language description
3. SBERT encoding: Generate 384-dimensional semantic embeddings
4. Output: Fixed-size embedding vector per user
Model: all-MiniLM-L6-v2
- Lightweight BERT variant optimized for sentence embeddings
- 384-dimensional output embeddings
- Fast inference with good semantic understanding
Attributes:
model: SentenceTransformer instance for encoding textualized metadata
"""
def __init__(self):
"""
Initialize the SBERT embedder with pre-trained model.
Uses "all-MiniLM-L6-v2" which is a lightweight, fast model
optimized for sentence similarity tasks.
"""
self.model = SentenceTransformer("all-MiniLM-L6-v2")
@staticmethod
def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
"""
Discover all user/session directories under data_root.
Uses glob pattern matching for cleaner directory traversal.
Expected structure: data_root/user_id/session_id/
Args:
data_root: Root directory containing user/session subfolders
Returns:
List of (user_id, session_id, session_path) tuples
"""
discovered_paths = []
# Use glob to find all session directories (2 levels deep)
for session_path in sorted(glob(os.path.join(data_root, "*", "*"))):
if not os.path.isdir(session_path):
continue
# Extract user_id and session_id from path
session_id = os.path.basename(session_path)
user_id = os.path.basename(os.path.dirname(session_path))
discovered_paths.append((user_id, session_id, session_path))
return discovered_paths
def textualize_metadata(self, age: Optional[int], sex: Optional[int]) -> str:
"""
Convert user metadata (age and sex) into a natural language description.
This textualization step is crucial for SBERT, which expects text input.
The description provides user demographic information in a structured format.
Args:
age: Age of the user (integer, may be None)
sex: Sex of the user (0 = Male, 1 = Female, may be None)
Returns:
Natural language string describing the user metadata
"""
# Map sex code to text
if sex is not None:
sex_text = "Female" if sex == 1 else "Male"
else:
sex_text = "Unknown"
# Create sentence from metadata
if age is not None:
sentence = f"This is the information of the user, age: {age}, sex: {sex_text}."
else:
sentence = f"This is the information of the user, age: unknown, sex: {sex_text}."
return sentence
def compute_embedding_from_metadata(
self,
ages: List[Optional[int]],
sexes: List[Optional[int]]
) -> np.ndarray:
"""
Generate embedding vectors from metadata using SBERT.
Processing Pipeline:
1. Textualize each user's metadata into natural language
2. Encode textual descriptions using SBERT
3. Return fixed-size embedding vectors
Args:
ages: List of age values (may contain None)
sexes: List of sex values (0 = Male, 1 = Female, may contain None)
Returns:
Embedding array of shape (batch_size, embedding_dim)
For all-MiniLM-L6-v2: (batch_size, 384)
"""
# Convert metadata to text sentences
text_samples = []
for age, sex in zip(ages, sexes):
text_samples.append(self.textualize_metadata(age, sex))
# Encode text descriptions using SBERT
# Returns numpy array of shape (batch_size, 384)
embeddings = self.model.encode(text_samples)
return embeddings
def extract_embeddings(
self,
data_root: str,
subject_path: str = SUBJECT_PATH,
batch_size: int = 32,
label: Optional[str] = None
) -> Dataset:
"""
Extract embeddings from user metadata for all sessions.
Iterates through all user/session combinations, loads metadata for each user,
generates embeddings from metadata sentences, and aggregates results.
Can process a single label or all labels.
Args:
data_root: Root directory containing user/session data folders
subject_path: Path to SC-subjects.xls file with metadata
batch_size: Number of samples to process together (for batching embeddings)
Larger = faster but more memory.
32 is a good balance for most systems.
label: Sleep stage label to filter (e.g., "W", "N1", "N2", "N3", "REM").
If None or "all", processes all labels.
Returns:
HuggingFace Dataset with columns:
- user_id, session_id, idx, label (metadata from original data)
- embedding (384-dim vector from all-MiniLM-L6-v2 based on metadata)
"""
# Load subject metadata
print(f"[INFO] Loading subject metadata from: {subject_path}")
subject_metadata = load_subject_metadata(subject_path)
print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")
session_paths = self.discover_session_paths(data_root)
print(f"[INFO] Discovered {len(session_paths)} sessions")
all_embeddings = []
all_user_ids = []
all_session_ids = []
all_idxs = []
all_labels = []
all_ages = []
all_sexes = []
# Collect metadata for all samples first
for user_id, session_id, session_path in session_paths:
# Load HuggingFace dataset from disk
dataset = load_from_disk(session_path)
# Shuffle dataset for randomness
dataset = dataset.shuffle(seed=0)
# Filter by sleep stage label if specified
if label is not None and label != "all":
dataset = dataset.filter(lambda x: x["label"] == label)
num_samples = len(dataset)
if num_samples == 0:
continue
print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")
# Get metadata for this user
# Convert user_id to string format that matches metadata keys
user_id_str = str(int(user_id))
try:
age = subject_metadata[user_id_str]["age"]
sex = subject_metadata[user_id_str]["sex"]
except KeyError:
age = None
sex = None
print(f"[WARN] No metadata found for user_id: {user_id_str}")
# Collect all samples for this user/session
for i in range(num_samples):
all_user_ids.append(str(dataset["user_id"][i]))
all_session_ids.append(str(dataset["session_id"][i]))
all_idxs.append(int(dataset["idx"][i]))
all_labels.append(str(dataset["label"][i]))
all_ages.append(age)
all_sexes.append(sex)
# Generate embeddings from metadata in batches
print(f"[INFO] Generating embeddings from metadata for {len(all_ages)} samples...")
for batch_start in range(0, len(all_ages), batch_size):
batch_end = min(batch_start + batch_size, len(all_ages))
batch_ages = all_ages[batch_start:batch_end]
batch_sexes = all_sexes[batch_start:batch_end]
# Compute embeddings from metadata
embeddings = self.compute_embedding_from_metadata(batch_ages, batch_sexes)
# Collect embeddings
for i in range(embeddings.shape[0]):
all_embeddings.append(embeddings[i].tolist())
# Create HuggingFace Dataset
result_dataset = Dataset.from_dict({
"user_id": all_user_ids,
"session_id": all_session_ids,
"idx": all_idxs,
"label": all_labels,
"age": all_ages,
"sex": all_sexes,
"embedding": all_embeddings,
})
print(f"[INFO] Total samples: {len(result_dataset)}")
# Print label distribution
if "label" in result_dataset.column_names:
label_counts = {}
for lbl in result_dataset["label"]:
label_counts[lbl] = label_counts.get(lbl, 0) + 1
print(f"[INFO] Label distribution: {label_counts}")
return result_dataset
def save_embeddings(
self,
dataset: Dataset,
output_dir: str
) -> None:
"""
Save embeddings dataset to disk in HuggingFace format.
Args:
dataset: HuggingFace Dataset containing embeddings and metadata
output_dir: Directory path to save the dataset
"""
os.makedirs(output_dir, exist_ok=True)
dataset.save_to_disk(output_dir)
print(f"[DONE] Saved embeddings dataset: {output_dir}")
print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
@staticmethod
def load_embeddings(embedding_dir: str) -> Dataset:
"""
Load saved embeddings dataset from disk.
Args:
embedding_dir: Directory path containing the saved HuggingFace dataset
Returns:
HuggingFace Dataset with embeddings and metadata
"""
dataset = load_from_disk(embedding_dir)
print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
return dataset
# =============================================================================
# Data Loading Utilities
# =============================================================================
def load_embeddings_from_all_labels(embeddings_root: str) -> Dataset:
"""
Load embeddings from all label subdirectories and concatenate them.
Discovers all subdirectories in embeddings_root (e.g., W, REM, N1, N2, N3)
and loads embeddings from each, then concatenates them into a single dataset.
Args:
embeddings_root: Root directory containing label subdirectories
(e.g., "./embeddings" containing W/, REM/, N1/, etc.)
Returns:
Concatenated HuggingFace Dataset with all labels combined
"""
# Discover all label subdirectories
label_dirs = []
for item in os.listdir(embeddings_root):
item_path = os.path.join(embeddings_root, item)
if os.path.isdir(item_path):
# Check if it's a valid HuggingFace dataset directory
if os.path.exists(os.path.join(item_path, "dataset_info.json")):
label_dirs.append((item, item_path))
if len(label_dirs) == 0:
raise ValueError(
f"No valid HuggingFace dataset directories found in: {embeddings_root}"
)
label_dirs.sort() # Sort for consistent ordering
print(f"[INFO] Discovered {len(label_dirs)} label directories: {[ld[0] for ld in label_dirs]}")
# Load datasets from each label directory
datasets = []
for label_name, label_path in label_dirs:
print(f"[INFO] Loading embeddings from: {label_path}")
dataset = load_from_disk(label_path)
print(f"[INFO] Label: {label_name}, Samples: {len(dataset)}")
datasets.append(dataset)
# Concatenate all datasets
if len(datasets) == 1:
combined_dataset = datasets[0]
else:
combined_dataset = concatenate_datasets(datasets)
print(f"[INFO] Combined dataset: {len(combined_dataset)} total samples")
# Print label distribution
if "label" in combined_dataset.column_names:
label_counts = {}
for label in combined_dataset["label"]:
label_counts[label] = label_counts.get(label, 0) + 1
print(f"[INFO] Label distribution: {label_counts}")
return combined_dataset
# =============================================================================
# Visualization Functions
# =============================================================================
def reduce_to_2d_tsne(
embeddings: np.ndarray,
perplexity: float = 30.0
) -> np.ndarray:
"""
Reduce high-dimensional embeddings to 2D using t-SNE.
t-SNE (t-distributed Stochastic Neighbor Embedding) is a nonlinear
dimensionality reduction technique that preserves local structure.
Points that are similar in high dimensions stay close in 2D.
Args:
embeddings: High-dimensional array of shape (num_samples, embedding_dim)
perplexity: t-SNE perplexity parameter (typically 5-50).
Higher values consider more neighbors, creating smoother layouts.
Rule of thumb: perplexity ~ sqrt(num_samples)
Returns:
2D coordinates of shape (num_samples, 2)
"""
print(f"[INFO] Running t-SNE with perplexity={perplexity}...")
tsne = TSNE(
n_components=2,
random_state=0, # For reproducibility
perplexity=perplexity,
max_iter=1000, # Usually sufficient for convergence
init='random',
learning_rate='auto', # Let sklearn choose optimal learning rate
)
return tsne.fit_transform(embeddings)
def create_scatter_plot_by_age(
coordinates: np.ndarray,
ages: np.ndarray,
title: str,
output_path: str
) -> None:
"""
Create and save a 2D scatter plot colored by age (continuous colormap).
Uses a continuous colormap (viridis) to show age distribution.
Ages are mapped to colors on a gradient scale.
Args:
coordinates: 2D array of shape (num_points, 2)
ages: Age values for each point (array, may contain None values)
title: Plot title
output_path: File path to save the plot (PDF vector format recommended)
"""
# Filter out points with missing age data
# Convert None values to NaN for proper numpy handling
ages_float = np.array([float(a) if a is not None else np.nan for a in ages])
valid_mask = ~np.isnan(ages_float)
valid_coords = coordinates[valid_mask]
valid_ages = ages_float[valid_mask]
if len(valid_ages) == 0:
print(f"[WARN] No valid age data found. Skipping plot: {output_path}")
return
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Create scatter plot with continuous colormap
scatter = ax.scatter(
valid_coords[:, 0],
valid_coords[:, 1],
c=valid_ages,
cmap='viridis', # Continuous colormap for granular age visualization
s=15,
alpha=0.7,
edgecolors='none',
)
# Add colorbar
cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label('Age (years)', rotation=270, labelpad=20)
# Configure plot appearance
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
# Save figure as vector PDF
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
print(f"[INFO] Age range: {valid_ages.min():.0f} - {valid_ages.max():.0f} years")
print(f"[INFO] Points with valid age: {len(valid_ages)}/{len(ages)}")
def create_scatter_plot_by_sex(
coordinates: np.ndarray,
sexes: np.ndarray,
title: str,
output_path: str
) -> None:
"""
Create and save a 2D scatter plot colored by sex (categorical).
Uses discrete colors for different sex categories.
Sex encoding: 1 = Female, 0 = Male
Args:
coordinates: 2D array of shape (num_points, 2)
sexes: Sex values for each point (array, may contain None values)
title: Plot title
output_path: File path to save the plot (PDF vector format recommended)
"""
# Filter out points with missing sex data
# Convert None values to NaN for proper numpy handling
sexes_float = np.array([float(s) if s is not None else np.nan for s in sexes])
valid_mask = ~np.isnan(sexes_float)
valid_coords = coordinates[valid_mask]
valid_sexes = sexes_float[valid_mask].astype(int)
if len(valid_sexes) == 0:
print(f"[WARN] No valid sex data found. Skipping plot: {output_path}")
return
# Map sex codes to labels
sex_labels = {0: "Male", 1: "Female"}
unique_sexes = sorted(set(valid_sexes))
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Plot each category separately for proper legend
colors = ['steelblue', 'coral'] # Blue for Male, Coral for Female
for idx, sex_code in enumerate(unique_sexes):
mask = valid_sexes == sex_code
ax.scatter(
valid_coords[mask, 0],
valid_coords[mask, 1],
c=colors[sex_code % len(colors)],
s=15,
label=sex_labels.get(sex_code, f"Sex {sex_code}"),
alpha=0.7,
)
# Configure plot appearance
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
ax.legend(loc='best', markerscale=2)
# Save figure as vector PDF
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
sex_counts = {sex_labels.get(s, f"Sex {s}"): (valid_sexes == s).sum() for s in unique_sexes}
print(f"[INFO] Sex distribution: {sex_counts}")
print(f"[INFO] Points with valid sex: {len(valid_sexes)}/{len(sexes)}")
# =============================================================================
# Command Line Interface
# =============================================================================
class CLI:
"""
Command-line interface for SBERT metadata-based embedding extraction and visualization.
Provides two main commands:
- extract: Generate embeddings from user metadata (age, sex)
- plot: Visualize embeddings with t-SNE colored by age or sex
"""
def extract(
self,
data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_full",
subject_path: str = SUBJECT_PATH,
out_dir: str = "./embeddings/all_labels",
batch_size: int = 32,
label: str = None
) -> None:
"""
Extract embeddings from user metadata.
Args:
data_root: Root directory containing user/session data folders
subject_path: Path to SC-subjects.xls file with metadata
out_dir: Output directory for HuggingFace dataset
batch_size: Batch size for inference (default: 32)
label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM').
If None or 'all', processes all labels (default: None for all labels)
"""
embedder = SBERT_Metadata()
dataset = embedder.extract_embeddings(data_root, subject_path, batch_size, label)
embedder.save_embeddings(dataset, out_dir)
def plot(
self,
emb_dir: str = None,
embeddings_root: str = "./embeddings",
out_dir: str = "./plots/all_labels",
perplexity: float = 10.0,
color_by: str = "age",
users: str = None,
num_users: int = 0,
labels: str = None,
) -> None:
"""
Visualize embeddings with t-SNE, colored by age or sex.
Can load from either a single label directory or all label directories.
Args:
emb_dir: Single directory containing the HuggingFace embeddings dataset
(e.g., "./embeddings/REM"). If provided, only this directory is used.
embeddings_root: Root directory containing label subdirectories
(e.g., "./embeddings" containing W/, REM/, N1/, etc.).
Used only if emb_dir is not provided.
out_dir: Output directory for visualization plots (PDF)
perplexity: t-SNE perplexity parameter (default: 10.0)
color_by: What to color by - 'age', 'sex', or 'both' (default: 'age')
users: Comma-separated user IDs to include (e.g., '00,01,02')
num_users: Include only first N users, 0 = all (default: 0)
labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
This filters the already-loaded data, not which directories to load.
"""
# Validate color_by argument
if color_by not in ["age", "sex", "both", "all"]:
raise ValueError(f"Invalid color_by: {color_by}. Use 'age', 'sex', or 'both'.")
os.makedirs(out_dir, exist_ok=True)
# Load embeddings: either from single directory or all label directories
if emb_dir is not None:
# Load from single directory
print(f"[INFO] Loading embeddings from single directory: {emb_dir}")
dataset = SBERT_Metadata.load_embeddings(emb_dir)
else:
# Load from all label directories
print(f"[INFO] Loading embeddings from all label directories in: {embeddings_root}")
dataset = load_embeddings_from_all_labels(embeddings_root)
# Apply user filtering
if users:
user_list = [u.strip() for u in users.split(",")]
dataset = dataset.filter(lambda x: x["user_id"] in user_list)
print(f"[INFO] Filtered to users: {user_list}")
elif num_users > 0:
all_users = sorted(set(dataset["user_id"]))
selected_users = all_users[:num_users]
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
print(f"[INFO] Selected first {num_users} users: {selected_users}")
# Filter by sleep stage labels
if labels:
label_list = [l.strip() for l in labels.split(",")]
dataset = dataset.filter(lambda x: x["label"] in label_list)
print(f"[INFO] Filtered to labels: {label_list}")
print(f"[INFO] Total samples: {len(dataset)}")
# Extract embeddings as numpy array for t-SNE
embeddings = np.array(dataset["embedding"])
# Extract ages and sexes for coloring
ages = np.array(dataset["age"])
sexes = np.array(dataset["sex"])
# Reduce to 2D with t-SNE
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
# Generate visualizations based on color_by parameter
if color_by == "age":
create_scatter_plot_by_age(
coordinates_2d,
ages,
"t-SNE Visualization (Colored by Age)",
os.path.join(out_dir, "tsne_by_age.pdf")
)
elif color_by == "sex":
create_scatter_plot_by_sex(
coordinates_2d,
sexes,
"t-SNE Visualization (Colored by Sex)",
os.path.join(out_dir, "tsne_by_sex.pdf")
)
elif color_by == "both" or color_by == "all":
# Create both plots
create_scatter_plot_by_age(
coordinates_2d,
ages,
"t-SNE Visualization (Colored by Age)",
os.path.join(out_dir, "tsne_by_age.pdf")
)
create_scatter_plot_by_sex(
coordinates_2d,
sexes,
"t-SNE Visualization (Colored by Sex)",
os.path.join(out_dir, "tsne_by_sex.pdf")
)
if __name__ == "__main__":
Fire(CLI)

View File

@@ -1,748 +0,0 @@
"""
SBERT Metadata-Based Embedding Extraction and Visualization Pipeline
This module provides functionality to
1. Extract embeddings from user metadata (age and sex) using SBERT (Sentence-BERT)
2. Visualize embeddings using dimensionality reduction (t-SNE) colored by age
SBERT Overview:
SBERT (Sentence-BERT) is a modification of the BERT model that uses siamese
and triplet network structures to derive semantically meaningful sentence
embeddings. It's designed for semantic similarity tasks and produces
fixed-size dense vector representations of text.
Key Features:
- Semantic understanding: Captures meaning rather than just word presence
- Fixed-size embeddings: Outputs consistent vector dimensions (384 for all-MiniLM-L6-v2)
- Efficient: Optimized for sentence-level tasks
Embedding Strategy:
Instead of using time series features, we create textual descriptions based on
user metadata (age and sex). This approach allows us to capture user-level
characteristics in the embedding space.
Processing Pipeline:
1. Load user metadata from XLS file (age, sex)
2. Textualization: Convert metadata to natural language description
3. SBERT encoding: Generate 384-dimensional embeddings
4. Visualization: t-SNE with continuous age coloring
Usage:
# Extract embeddings from metadata for all labels
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/all_labels
# Extract embeddings for a single label
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/REM --label REM
# Visualize with t-SNE from all label directories (colored by age)
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels
# Visualize from a single label directory
python gen_plot.py plot --emb_dir ./embeddings/REM --out_dir ./plots/REM
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from glob import glob
from typing import Dict, Any, List, Tuple, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_from_disk, Dataset, concatenate_datasets
from fire import Fire
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
# =============================================================================
# Constants
# =============================================================================
SUBJECT_PATH = "LLM_Health/tsllm_personalization_icl/analysis/user_similarity/sbert_metadata_ppgbp/PPGBP_metadata.xlsx"
# =============================================================================
# Metadata Loading
# =============================================================================
def load_subject_metadata(subject_path: str = SUBJECT_PATH) -> Dict[str, Dict[str, Any]]:
"""
Load subject metadata from XLS file.
The XLS file contains subject information including:
- subject: Subject ID (e.g., "SC4001", "SC4002")
- age: Age of the subject
- sex (F=1): Sex (1 = Female, 0 = Male)
Args:
subject_path: Path to the SC-subjects.xls file
Returns:
Dictionary mapping subject IDs to metadata dictionaries
Format: {"SC4001": {"age": 25, "sex": 1}, ...}
"""
df = pd.read_excel(subject_path, header=1)
subject_info = {}
for index, row in df.iterrows():
subject_id = str(row["subject_ID"]).strip()
subject_info[subject_id] = {
"sex": str(row["Sex(M/F)"]).strip() if pd.notna(row["Sex(M/F)"]) else None,
"age": int(row["Age(year)"]) if pd.notna(row["Age(year)"]) else None,
"height": int(row["Height(cm)"]) if pd.notna(row["Height(cm)"]) else None,
"weight": int(row["Weight(kg)"]) if pd.notna(row["Weight(kg)"]) else None,
"sbp": int(row["Systolic Blood Pressure(mmHg)"])
if pd.notna(row["Systolic Blood Pressure(mmHg)"]) else None,
"dbp": int(row["Diastolic Blood Pressure(mmHg)"])
if pd.notna(row["Diastolic Blood Pressure(mmHg)"]) else None,
"hr": int(row["Heart Rate(b/m)"])
if pd.notna(row["Heart Rate(b/m)"]) else None,
"bmi": float(row["BMI(kg/m^2)"]) if pd.notna(row["BMI(kg/m^2)"]) else None,
"hypertension": str(row["Hypertension"]) if pd.notna(row["Hypertension"]) else None,
}
return subject_info
# =============================================================================
# Embedding Extractor Class
# =============================================================================
class SBERT_Metadata:
"""
Extracts fixed-dimensional embeddings from user metadata using SBERT.
Uses Sentence-BERT to convert textualized metadata descriptions into dense
vector representations. The textualization process converts sex, age, height,
weight, sbp, dbp, heart rate, bmi, hypertension information into natural language,
which SBERT then encodes semantically.
Architecture:
User Metadata → Textualization → SBERT Encoder → Embedding
Processing Pipeline:
1. Load user metadata (sex, age, height, weight, sbp, dbp, heart rate, bmi,
hypertension) from XLS file
2. Textualize: Convert metadata to natural language description
3. SBERT encoding: Generate 384-dimensional semantic embeddings
4. Output: Fixed-size embedding vector per user
Model: all-MiniLM-L6-v2
- Lightweight BERT variant optimized for sentence embeddings
- 384-dimensional output embeddings
- Fast inference with good semantic understanding
Attributes:
model: SentenceTransformer instance for encoding textualized metadata
"""
def __init__(self):
"""
Initialize the SBERT embedder with pre-trained model.
Uses "all-MiniLM-L6-v2" which is a lightweight, fast model
optimized for sentence similarity tasks.
"""
self.model = SentenceTransformer("all-MiniLM-L6-v2")
def textualize_metadata_ppg_bp(self,
sex: Optional[str],
age: Optional[int],
height: Optional[int],
weight: Optional[int],
sbp: Optional[int],
dbp: Optional[int],
heart_rate: Optional[int],
bmi: Optional[float],
hypertension: Optional[str]) -> str:
"""
Convert user metadata (age, sex, height, weight, sbp, dbp, heart rate, bmi, hypertension) into a natural language description.
This textualization step is crucial for SBERT, which expects text input.
The description provides physiological and demographic information in a structured format.
Args:
sex: Sex of the user (0 = Male, 1 = Female, may be None)
age: Age of the user (integer, may be None)
height_cm: Height in centimeters, (integer, may be None)
weight_kg: Weight in kilograms, (integer, may be None)
sbp_mmHg: Systolic blood pressure in mmHg, (integer, may be None)
dbp_mmHg: Diastolic blood pressure in mmHg, (integer, may be None)
heart_rate_bpm: Heart rate in beats per minute, (integer, may be None)
bmi: Body mass index (kg/m^2), (float, may be None)
hypertension: Hypertension status (0 = Normal, 1 = Prehypertension, 2 = Stage 1 hypertension, 3 = Stage 2 hypertension, may be None)
Returns:
Natural language string describing the user metadata
"""
# Map sex code to text
if sex is not None:
sex_text = "Female" if sex == 1 else "Male"
else:
sex_text = "Unknown"
# Map age code to text
if age is not None:
age_text = f"{age}"
else:
age_text = "unknown"
# Map height code to text
if height is not None:
height_text = f"{height} cm"
else:
height_text = "unknown"
# Map weight code to text
if weight is not None:
weight_text = f"{weight} kg"
else:
weight_text = "unknown"
# Map sbp code to text
if sbp is not None:
sbp_text = f"{sbp} mmHg"
else:
sbp_text = "unknown"
# Map dbp code to text
if dbp is not None:
dbp_text = f"{dbp} mmHg"
else:
dbp_text = "unknown"
# Map heart rate code to text
if heart_rate is not None:
heart_rate_text = f"{heart_rate} bpm"
else:
heart_rate_text = "unknown"
# Map bmi code to text
if bmi is not None:
bmi_text = f"{bmi} kg/m^2"
else:
bmi_text = "unknown"
# Map hypertension code to text
if hypertension is not None:
hypertension_text = f"{hypertension}"
else:
hypertension_text = "unknown"
# Create sentence from metadata
if age is not None:
sentence = f"This is the information of the user, sex: {sex_text}, age: {age_text}, height: {height_text}, weight: {weight_text}, sbp: {sbp_text}, dbp: {dbp_text}, heart rate: {heart_rate_text} bpm, bmi: {bmi_text}, hypertension: {hypertension_text}."
else:
sentence = f"This is the information of the user, sex: {sex_text}, age: unknown, height: {height_text}, weight: {weight_text}, sbp: {sbp_text}, dbp: {dbp_text}, heart rate: {heart_rate_text} bpm, bmi: {bmi_text}, hypertension: {hypertension_text}."
return sentence
def compute_embedding_from_metadata(self,
sexes: List[Optional[int]],
ages: List[Optional[int]],
heights: List[Optional[float]],
weights: List[Optional[float]],
systolics: List[Optional[float]],
diastolics: List[Optional[float]],
heart_rates: List[Optional[float]],
bmis: List[Optional[float]],
hypertensions: List[Optional[int]]) -> np.ndarray:
"""
Generate embedding vectors from metadata using SBERT.
Processing Pipeline:
1. Textualize each user's metadata into natural language
2. Encode textual descriptions using SBERT
3. Return fixed-size embedding vectors
Args:
sexes: List of sex values (0 = Male, 1 = Female, may contain None)
ages: List of age values (may contain None)
heights: List of height values (may contain None)
weights: List of weight values (may contain None)
systolics: List of systolic blood pressure values (may contain None)
diastolics: List of diastolic blood pressure values (may contain None)
heart_rates: List of heart rate values (may contain None)
bmis: List of body mass index values (may contain None)
hypertensions: List of hypertension values (0 = Normal, 1 = Prehypertension,
2 = Stage 1 hypertension, 3 = Stage 2 hypertension, may contain None)
Returns:
Embedding array of shape (batch_size, embedding_dim)
For all-MiniLM-L6-v2: (batch_size, 384)
"""
# Convert metadata to text sentences
text_samples = []
for sex, age, height, weight, systolic, diastolic, heart_rate, bmi, hypertension in zip(sexes, ages, heights, weights, systolics, diastolics, heart_rates, bmis, hypertensions):
text_samples.append(self.textualize_metadata_ppg_bp(sex, age, height, weight, systolic, diastolic, heart_rate, bmi, hypertension))
# Encode text descriptions using SBERT
# Returns numpy array of shape (batch_size, 384)
embeddings = self.model.encode(text_samples)
return embeddings
def extract_embeddings(
self,
data_root: str,
subject_path: str = SUBJECT_PATH,
batch_size: int = 32,
label: Optional[str] = None
) -> Dataset:
"""
Extract embeddings from user metadata for all sessions.
Iterates through all user/session combinations, loads metadata for each user,
generates embeddings from metadata sentences, and aggregates results.
Can process a single label or all labels.
Args:
data_root: Root directory containing user/session data folders
subject_path: Path to SC-subjects.xls file with metadata
batch_size: Number of samples to process together (for batching embeddings)
Larger = faster but more memory.
32 is a good balance for most systems.
label: Sleep stage label to filter (e.g., "W", "N1", "N2", "N3", "REM").
If None or "all", processes all labels.
Returns:
HuggingFace Dataset with columns:
- user_id, idx, label (metadata from original data)
- embedding (384-dim vector from all-MiniLM-L6-v2 based on metadata)
"""
# Load subject metadata
print(f"[INFO] Loading subject metadata from: {subject_path}")
subject_metadata = load_subject_metadata(subject_path)
print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")
all_embeddings = []
all_user_ids = []
# all_session_ids = []
all_idxs = []
all_labels = []
all_sexes = []
all_ages = []
all_heights = []
all_weights = []
all_systolics = []
all_diastolics = []
all_heart_rates = []
all_bmis = []
all_hypertensions = []
# Collect metadata for all samples first
for user_id, session_path in session_paths:
# Load HuggingFace dataset from disk
dataset = load_from_disk(session_path)
# Shuffle dataset for randomness
dataset = dataset.shuffle(seed=0)
# Filter by sleep stage label if specified
if label is not None and label != "all":
dataset = dataset.filter(lambda x: x["label"] == label)
num_samples = len(dataset)
if num_samples == 0:
continue
print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")
# Get metadata for this user
# Convert user_id to string format that matches metadata keys
user_id_str = str(int(user_id))
try:
sex = subject_metadata[user_id_str]["sex"]
age = subject_metadata[user_id_str]["age"]
height = subject_metadata[user_id_str]["height"]
weight = subject_metadata[user_id_str]["weight"]
sbp = subject_metadata[user_id_str]["sbp"]
dbp = subject_metadata[user_id_str]["dbp"]
heart_rate = subject_metadata[user_id_str]["heart_rate"]
bmi = subject_metadata[user_id_str]["bmi"]
hypertension = subject_metadata[user_id_str]["hypertension"]
except KeyError:
sex = None
age = None
height = None
weight = None
sbp = None
dbp = None
heart_rate = None
bmi = None
hypertension = None
print(f"[WARN] No metadata found for user_id: {user_id_str}")
# Collect all samples for this user/session
for i in range(num_samples):
all_user_ids.append(str(dataset["user_id"][i]))
all_session_ids.append(str(dataset["session_id"][i]))
all_idxs.append(int(dataset["idx"][i]))
all_labels.append(str(dataset["label"][i]))
all_ages.append(age)
all_sexes.append(sex)
# Generate embeddings from metadata in batches
print(f"[INFO] Generating embeddings from metadata for {len(all_ages)} samples...")
for batch_start in range(0, len(all_ages), batch_size):
batch_end = min(batch_start + batch_size, len(all_ages))
batch_sexes = all_sexes[batch_start:batch_end]
batch_ages = all_ages[batch_start:batch_end]
# Compute embeddings from metadata
embeddings = self.compute_embedding_from_metadata(batch_ages, batch_sexes)
# Collect embeddings
for i in range(embeddings.shape[0]):
all_embeddings.append(embeddings[i].tolist())
# Create HuggingFace Dataset
result_dataset = Dataset.from_dict({
"user_id": all_user_ids,
# "session_id": all_session_ids,
"idx": all_idxs,
"label": all_labels,
"sex": all_sexes,
"age": all_ages,
"height": all_heights,
"weight": all_weights,
"systolic": all_systolics,
"diastolic": all_diastolics,
"heart_rate": all_heart_rates,
"bmi": all_bmis,
"hypertension": all_hypertensions,
"embedding": all_embeddings,
})
print(f"[INFO] Total samples: {len(result_dataset)}")
# Print label distribution
if "label" in result_dataset.column_names:
label_counts = {}
for lbl in result_dataset["label"]:
label_counts[lbl] = label_counts.get(lbl, 0) + 1
print(f"[INFO] Label distribution: {label_counts}")
return result_dataset
def save_embeddings(
self,
dataset: Dataset,
output_dir: str
) -> None:
"""
Save embeddings dataset to disk in HuggingFace format.
Args:
dataset: HuggingFace Dataset containing embeddings and metadata
output_dir: Directory path to save the dataset
"""
os.makedirs(output_dir, exist_ok=True)
dataset.save_to_disk(output_dir)
print(f"[DONE] Saved embeddings dataset: {output_dir}")
print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
@staticmethod
def load_embeddings(embedding_dir: str) -> Dataset:
"""
Load saved embeddings dataset from disk.
Args:
embedding_dir: Directory path containing the saved HuggingFace dataset
Returns:
HuggingFace Dataset with embeddings and metadata
"""
dataset = load_from_disk(embedding_dir)
print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
return dataset
# =============================================================================
# Data Loading Utilities
# =============================================================================
def load_embeddings_from_all_labels(embeddings_root: str) -> Dataset:
"""
Load embeddings from all label subdirectories and concatenate them.
Discovers all subdirectories in embeddings_root (e.g., W, REM, N1, N2, N3)
and loads embeddings from each, then concatenates them into a single dataset.
Args:
embeddings_root: Root directory containing label subdirectories
(e.g., "./embeddings" containing W/, REM/, N1/, etc.)
Returns:
Concatenated HuggingFace Dataset with all labels combined
"""
# Discover all label subdirectories
label_dirs = []
for item in os.listdir(embeddings_root):
item_path = os.path.join(embeddings_root, item)
if os.path.isdir(item_path):
# Check if it's a valid HuggingFace dataset directory
if os.path.exists(os.path.join(item_path, "dataset_info.json")):
label_dirs.append((item, item_path))
if len(label_dirs) == 0:
raise ValueError(
f"No valid HuggingFace dataset directories found in: {embeddings_root}"
)
label_dirs.sort() # Sort for consistent ordering
print(f"[INFO] Discovered {len(label_dirs)} label directories: {[ld[0] for ld in label_dirs]}")
# Load datasets from each label directory
datasets = []
for label_name, label_path in label_dirs:
print(f"[INFO] Loading embeddings from: {label_path}")
dataset = load_from_disk(label_path)
print(f"[INFO] Label: {label_name}, Samples: {len(dataset)}")
datasets.append(dataset)
# Concatenate all datasets
if len(datasets) == 1:
combined_dataset = datasets[0]
else:
combined_dataset = concatenate_datasets(datasets)
print(f"[INFO] Combined dataset: {len(combined_dataset)} total samples")
# Print label distribution
if "label" in combined_dataset.column_names:
label_counts = {}
for label in combined_dataset["label"]:
label_counts[label] = label_counts.get(label, 0) + 1
print(f"[INFO] Label distribution: {label_counts}")
return combined_dataset
# =============================================================================
# Visualization Functions
# =============================================================================
def reduce_to_2d_tsne(
embeddings: np.ndarray,
perplexity: float = 30.0
) -> np.ndarray:
"""
Reduce high-dimensional embeddings to 2D using t-SNE.
t-SNE (t-distributed Stochastic Neighbor Embedding) is a nonlinear
dimensionality reduction technique that preserves local structure.
Points that are similar in high dimensions stay close in 2D.
Args:
embeddings: High-dimensional array of shape (num_samples, embedding_dim)
perplexity: t-SNE perplexity parameter (typically 5-50).
Higher values consider more neighbors, creating smoother layouts.
Rule of thumb: perplexity ~ sqrt(num_samples)
Returns:
2D coordinates of shape (num_samples, 2)
"""
print(f"[INFO] Running t-SNE with perplexity={perplexity}...")
tsne = TSNE(
n_components=2,
random_state=0, # For reproducibility
perplexity=perplexity,
max_iter=1000, # Usually sufficient for convergence
init='random',
learning_rate='auto', # Let sklearn choose optimal learning rate
)
return tsne.fit_transform(embeddings)
def create_scatter_plot_by_age(
coordinates: np.ndarray,
ages: np.ndarray,
title: str,
output_path: str
) -> None:
"""
Create and save a 2D scatter plot colored by age (continuous colormap).
Uses a continuous colormap (viridis) to show age distribution.
Ages are mapped to colors on a gradient scale.
Args:
coordinates: 2D array of shape (num_points, 2)
ages: Age values for each point (array, may contain None values)
title: Plot title
output_path: File path to save the plot (PDF vector format recommended)
"""
# Filter out points with missing age data
# Convert None values to NaN for proper numpy handling
ages_float = np.array([float(a) if a is not None else np.nan for a in ages])
valid_mask = ~np.isnan(ages_float)
valid_coords = coordinates[valid_mask]
valid_ages = ages_float[valid_mask]
if len(valid_ages) == 0:
print(f"[WARN] No valid age data found. Skipping plot: {output_path}")
return
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Create scatter plot with continuous colormap
scatter = ax.scatter(
valid_coords[:, 0],
valid_coords[:, 1],
c=valid_ages,
cmap='viridis', # Continuous colormap for granular age visualization
s=15,
alpha=0.7,
edgecolors='none',
)
# Add colorbar
cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label('Age (years)', rotation=270, labelpad=20)
# Configure plot appearance
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
# Save figure as vector PDF
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
print(f"[INFO] Age range: {valid_ages.min():.0f} - {valid_ages.max():.0f} years")
print(f"[INFO] Points with valid age: {len(valid_ages)}/{len(ages)}")
# =============================================================================
# Command Line Interface
# =============================================================================
class CLI:
"""
Command-line interface for SBERT metadata-based embedding extraction and visualization.
Provides two main commands:
- extract: Generate embeddings from user metadata (age, sex)
- plot: Visualize embeddings with t-SNE colored by age
"""
def extract(
self,
data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_full",
subject_path: str = SUBJECT_PATH,
out_dir: str = "./embeddings/all_labels",
batch_size: int = 32,
label: str = None
) -> None:
"""
Extract embeddings from user metadata.
Args:
data_root: Root directory containing user/session data folders
subject_path: Path to SC-subjects.xls file with metadata
out_dir: Output directory for HuggingFace dataset
batch_size: Batch size for inference (default: 32)
label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM').
If None or 'all', processes all labels (default: None for all labels)
"""
embedder = SBERT_Metadata()
dataset = embedder.extract_embeddings(data_root, subject_path, batch_size, label)
embedder.save_embeddings(dataset, out_dir)
def plot(
self,
emb_dir: str = None,
embeddings_root: str = "./embeddings",
out_dir: str = "./plots/all_labels",
perplexity: float = 10.0,
users: str = None,
num_users: int = 0,
labels: str = None,
) -> None:
"""
Visualize embeddings with t-SNE, colored by age.
Can load from either a single label directory or all label directories.
Args:
emb_dir: Single directory containing the HuggingFace embeddings dataset
(e.g., "./embeddings/REM"). If provided, only this directory is used.
embeddings_root: Root directory containing label subdirectories
(e.g., "./embeddings" containing W/, REM/, N1/, etc.).
Used only if emb_dir is not provided.
out_dir: Output directory for visualization plots (PDF)
perplexity: t-SNE perplexity parameter (default: 10.0)
users: Comma-separated user IDs to include (e.g., '00,01,02')
num_users: Include only first N users, 0 = all (default: 0)
labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
This filters the already-loaded data, not which directories to load.
"""
os.makedirs(out_dir, exist_ok=True)
# Load embeddings: either from single directory or all label directories
if emb_dir is not None:
# Load from single directory
print(f"[INFO] Loading embeddings from single directory: {emb_dir}")
dataset = SBERT_Metadata.load_embeddings(emb_dir)
else:
# Load from all label directories
print(f"[INFO] Loading embeddings from all label directories in: {embeddings_root}")
dataset = load_embeddings_from_all_labels(embeddings_root)
# Apply user filtering
if users:
user_list = [u.strip() for u in users.split(",")]
dataset = dataset.filter(lambda x: x["user_id"] in user_list)
print(f"[INFO] Filtered to users: {user_list}")
elif num_users > 0:
all_users = sorted(set(dataset["user_id"]))
selected_users = all_users[:num_users]
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
print(f"[INFO] Selected first {num_users} users: {selected_users}")
# Filter by sleep stage labels
if labels:
label_list = [l.strip() for l in labels.split(",")]
dataset = dataset.filter(lambda x: x["label"] in label_list)
print(f"[INFO] Filtered to labels: {label_list}")
print(f"[INFO] Total samples: {len(dataset)}")
# Extract embeddings as numpy array for t-SNE
embeddings = np.array(dataset["embedding"])
# Extract ages for coloring
ages = np.array(dataset["age"])
# Reduce to 2D with t-SNE
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
# Generate visualization colored by age
create_scatter_plot_by_age(
coordinates_2d,
ages,
"t-SNE Visualization (Colored by Age)",
os.path.join(out_dir, "tsne_by_age.pdf")
)
if __name__ == "__main__":
Fire(CLI)

58
baselines/common.py Normal file
View File

@@ -0,0 +1,58 @@
"""Shared setup for baseline experiments."""
import os
import sys
import yaml
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)
from core.data_loader import DataLoader # noqa: E402
from core.model import load_models # noqa: E402
from core.recruiter import Recruiter # noqa: E402
from core.agent import Agent # noqa: E402
from core.logger import Logger # noqa: E402
from core.prompt import gen_system_message, gen_task_message # noqa: E402
from core.json_utils import safe_parse_json # noqa: E402
from core.scores import self_certainty # noqa: E402
from core.vote import borda_vote # noqa: E402
def load_config(config_path: str) -> dict:
with open(config_path, "r", encoding="utf-8") as f:
return yaml.load(f, Loader=yaml.SafeLoader)
def setup(config: dict, temperature: float = None):
"""
Initialize all components from config.
Args:
config: Parsed YAML config dict.
temperature: Override model temperature (e.g. 0.7 for SC baseline).
If None, uses config["temperature"] or 0.0.
Returns:
(logger, dataloader, recruiter, agent)
"""
logger = Logger(config.get("log_path"))
logger.log_config(config)
dataloader = DataLoader(config.get("data_path"), config.get("target_user"))
recruiter = Recruiter(
source_dataset=dataloader.get_source_dataset(),
source_users=dataloader.get_source_users(),
num_shot=config.get("num_shot"),
classes=dataloader.get_classes(),
logger=logger,
)
temp = temperature if temperature is not None else config.get("temperature", 0.0)
model_pool = load_models(config.get("model_paths"), temperature=temp)
agent = Agent(model_pool=model_pool, logger=logger)
system_message = gen_system_message(metadata=dataloader.get_task_metadata())
agent.set_system_message(system_message)
return logger, dataloader, recruiter, agent

View File

@@ -0,0 +1,72 @@
"""
Baseline 4: Random Examples + Borda Voting
For each sample, randomly recruit queue_size fresh example sets.
Each example set produces one inference. Answers aggregated via borda voting.
"""
import asyncio
import time
from fire import Fire
from common import (
load_config, setup,
gen_task_message, safe_parse_json, self_certainty, borda_vote,
)
async def run(config_path: str):
config = load_config(config_path)
queue_size = config.get("queue_size")
logger, dataloader, recruiter, agent = setup(config)
logger.log("[Baseline] Random examples + borda voting")
async def process(sample_idx, example_idx, example_set, sample):
try:
task_message = gen_task_message(sample, example_set)
response, logprobs = await agent.solve(
task_message, sample_idx, example_idx
)
score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
parsed = safe_parse_json(response)
answer = parsed.get("ANSWER") if parsed else None
logger.log(
f"[Main] Done {sample_idx} - {example_idx}: "
f"answer={answer}, score={score:.4f}"
)
return {"answer": answer, "score": score}
except Exception as e:
logger.log(
f"[Main] Error {sample_idx} - {example_idx}: {e}",
filename="errors.txt",
)
return {"answer": None, "score": float("-inf")}
start_time = time.time()
for idx, sample in enumerate(dataloader):
logger.log(f"[Main] Sample {idx} / {len(dataloader)}")
example_sets = recruiter.recruit(queue_size)
tasks = [
process(idx, i, es, sample) for i, es in enumerate(example_sets)
]
results = await asyncio.gather(*tasks)
winner, tally = borda_vote(results, config.get("borda_p", 1.0))
ground_truth = sample["label"]
if winner is not None:
tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
logger.log(f"[Vote] votes={{ {tally_str} }}")
logger.log_result(idx, winner, ground_truth)
else:
logger.log(f"[Vote] Sample {idx}: no valid answer, skipping")
logger.report(elapsed_seconds=time.time() - start_time)
def main(config_path: str):
asyncio.run(run(config_path))
if __name__ == "__main__":
Fire(main)

View File

@@ -0,0 +1,69 @@
"""
Baseline 6: Random Dynamic Self-Consistency
For each sample, randomly recruit ONE fresh example set,
then run that same prompt queue_size times with temperature=0.7.
Answers are aggregated via borda voting.
"""
import asyncio
import time
from fire import Fire
from common import (
load_config, setup,
gen_task_message, safe_parse_json, self_certainty, borda_vote,
)
async def run(config_path: str):
config = load_config(config_path)
queue_size = config.get("queue_size")
logger, dataloader, recruiter, agent = setup(config, temperature=0.7)
logger.log("[Baseline] Random dynamic example + self-consistency (temp=0.7)")
async def run_once(sample_idx, run_idx, task_message):
try:
response, logprobs = await agent.solve(task_message, sample_idx, run_idx)
score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
parsed = safe_parse_json(response)
answer = parsed.get("ANSWER") if parsed else None
logger.log(
f"[SC] {sample_idx} - run {run_idx}: "
f"answer={answer}, score={score:.4f}"
)
return {"answer": answer, "score": score}
except Exception as e:
logger.log(
f"[SC] Error {sample_idx} - run {run_idx}: {e}",
filename="errors.txt",
)
return {"answer": None, "score": float("-inf")}
start_time = time.time()
for idx, sample in enumerate(dataloader):
logger.log(f"[Main] Sample {idx} / {len(dataloader)}")
example_set = recruiter.recruit(1)[0]
task_message = gen_task_message(sample, example_set)
tasks = [run_once(idx, i, task_message) for i in range(queue_size)]
results = await asyncio.gather(*tasks)
winner, tally = borda_vote(results, config.get("borda_p", 1.0))
ground_truth = sample["label"]
if winner is not None:
tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
logger.log(f"[Vote] votes={{ {tally_str} }}")
logger.log_result(idx, winner, ground_truth)
else:
logger.log(f"[Vote] Sample {idx}: no valid answer, skipping")
logger.report(elapsed_seconds=time.time() - start_time)
def main(config_path: str):
asyncio.run(run(config_path))
if __name__ == "__main__":
Fire(main)

View File

@@ -0,0 +1,51 @@
"""
Baseline 2: Random Dynamic Single Example
For each sample, randomly select a fresh example set.
Single inference per sample, no voting.
"""
import asyncio
import time
from fire import Fire
from common import (
load_config, setup,
gen_task_message, safe_parse_json, self_certainty,
)
async def run(config_path: str):
config = load_config(config_path)
logger, dataloader, recruiter, agent = setup(config)
start_time = time.time()
for idx, sample in enumerate(dataloader):
logger.log(f"[Main] Sample {idx} / {len(dataloader)}")
ground_truth = sample["label"]
try:
example_set = recruiter.recruit(1)[0]
task_message = gen_task_message(sample, example_set)
response, logprobs = await agent.solve(task_message, idx, 0)
score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
parsed = safe_parse_json(response)
answer = parsed.get("ANSWER") if parsed else None
logger.log(f"[Main] answer={answer}, score={score:.4f}")
except Exception as e:
logger.log(f"[Main] Error sample {idx}: {e}", filename="errors.txt")
answer = None
if answer is not None:
logger.log_result(idx, answer, ground_truth)
else:
logger.log(f"[Main] Sample {idx}: no valid answer, skipping")
logger.report(elapsed_seconds=time.time() - start_time)
def main(config_path: str):
asyncio.run(run(config_path))
if __name__ == "__main__":
Fire(main)

View File

@@ -0,0 +1,75 @@
"""
Baseline 5: Fixed Examples + Borda Voting
Recruit queue_size example sets once at the start.
For each sample, run all example sets through the LLM.
Answers aggregated via borda voting. No queue updates.
"""
import asyncio
import time
from fire import Fire
from common import (
load_config,
setup,
gen_task_message,
safe_parse_json,
self_certainty,
borda_vote,
)
async def run(config_path: str):
config = load_config(config_path)
queue_size = config.get("queue_size")
logger, dataloader, recruiter, agent = setup(config)
example_sets = recruiter.recruit(queue_size)
logger.log(f"[Baseline] Fixed {queue_size} example set(s) + borda voting")
async def process(sample_idx, example_idx, example_set, sample):
try:
task_message = gen_task_message(sample, example_set)
response, logprobs = await agent.solve(
task_message, sample_idx, example_idx
)
score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
parsed = safe_parse_json(response)
answer = parsed.get("ANSWER") if parsed else None
logger.log(
f"[Main] Done {sample_idx} - {example_idx}: "
f"answer={answer}, score={score:.4f}"
)
return {"answer": answer, "score": score}
except Exception as e:
logger.log(
f"[Main] Error {sample_idx} - {example_idx}: {e}",
filename="errors.txt",
)
return {"answer": None, "score": float("-inf")}
start_time = time.time()
for idx, sample in enumerate(dataloader):
logger.log(f"[Main] Sample {idx} / {len(dataloader)}")
tasks = [process(idx, i, es, sample) for i, es in enumerate(example_sets)]
results = await asyncio.gather(*tasks)
winner, tally = borda_vote(results, config.get("borda_p", 1.0))
ground_truth = sample["label"]
if winner is not None:
tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
logger.log(f"[Vote] votes={{ {tally_str} }}")
logger.log_result(idx, winner, ground_truth)
else:
logger.log(f"[Vote] Sample {idx}: no valid answer, skipping")
logger.report(elapsed_seconds=time.time() - start_time)
def main(config_path: str):
asyncio.run(run(config_path))
if __name__ == "__main__":
Fire(main)

View File

@@ -0,0 +1,69 @@
"""
Baseline 3: Self-Consistency with Fixed Example
One randomly selected example set, used for ALL samples.
For each sample, the SAME prompt is run queue_size times with temperature=0.7
across different LLM instances. Answers are aggregated via borda voting.
"""
import asyncio
import time
from fire import Fire
from common import (
load_config, setup,
gen_task_message, safe_parse_json, self_certainty, borda_vote,
)
async def run(config_path: str):
config = load_config(config_path)
queue_size = config.get("queue_size")
logger, dataloader, recruiter, agent = setup(config, temperature=0.7)
example_set = recruiter.recruit(1)[0]
logger.log("[Baseline] Fixed example set + self-consistency (temp=0.7)")
async def run_once(sample_idx, run_idx, task_message):
try:
response, logprobs = await agent.solve(task_message, sample_idx, run_idx)
score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
parsed = safe_parse_json(response)
answer = parsed.get("ANSWER") if parsed else None
logger.log(
f"[SC] {sample_idx} - run {run_idx}: "
f"answer={answer}, score={score:.4f}"
)
return {"answer": answer, "score": score}
except Exception as e:
logger.log(
f"[SC] Error {sample_idx} - run {run_idx}: {e}",
filename="errors.txt",
)
return {"answer": None, "score": float("-inf")}
start_time = time.time()
for idx, sample in enumerate(dataloader):
logger.log(f"[Main] Sample {idx} / {len(dataloader)}")
task_message = gen_task_message(sample, example_set)
tasks = [run_once(idx, i, task_message) for i in range(queue_size)]
results = await asyncio.gather(*tasks)
winner, tally = borda_vote(results, config.get("borda_p", 1.0))
ground_truth = sample["label"]
if winner is not None:
tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
logger.log(f"[Vote] votes={{ {tally_str} }}")
logger.log_result(idx, winner, ground_truth)
else:
logger.log(f"[Vote] Sample {idx}: no valid answer, skipping")
logger.report(elapsed_seconds=time.time() - start_time)
def main(config_path: str):
asyncio.run(run(config_path))
if __name__ == "__main__":
Fire(main)

View File

@@ -0,0 +1,53 @@
"""
Baseline 1: Fixed Single Example
One randomly selected example set, used for ALL samples.
Single inference per sample, no voting.
"""
import asyncio
import time
from fire import Fire
from common import (
load_config, setup,
gen_task_message, safe_parse_json, self_certainty,
)
async def run(config_path: str):
config = load_config(config_path)
logger, dataloader, recruiter, agent = setup(config)
example_set = recruiter.recruit(1)[0]
logger.log("[Baseline] Fixed single example set selected")
start_time = time.time()
for idx, sample in enumerate(dataloader):
logger.log(f"[Main] Sample {idx} / {len(dataloader)}")
ground_truth = sample["label"]
try:
task_message = gen_task_message(sample, example_set)
response, logprobs = await agent.solve(task_message, idx, 0)
score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
parsed = safe_parse_json(response)
answer = parsed.get("ANSWER") if parsed else None
logger.log(f"[Main] answer={answer}, score={score:.4f}")
except Exception as e:
logger.log(f"[Main] Error sample {idx}: {e}", filename="errors.txt")
answer = None
if answer is not None:
logger.log_result(idx, answer, ground_truth)
else:
logger.log(f"[Main] Sample {idx}: no valid answer, skipping")
logger.report(elapsed_seconds=time.time() - start_time)
def main(config_path: str):
asyncio.run(run(config_path))
if __name__ == "__main__":
Fire(main)

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '00'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/ours/sleepedf/00

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '01'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/ours/sleepedf/01

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '02'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/ours/sleepedf/02

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '00'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_borda/sleepedf/00

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '01'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_borda/sleepedf/01

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '02'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_borda/sleepedf/02

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '00'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_sc/sleepedf/00

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '01'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_sc/sleepedf/01

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '02'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_sc/sleepedf/02

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '00'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_single/sleepedf/00

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '01'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_single/sleepedf/01

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '02'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_single/sleepedf/02

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '00'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_borda/sleepedf/00

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '01'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_borda/sleepedf/01

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '02'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_borda/sleepedf/02

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '00'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_sc/sleepedf/00

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '01'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_sc/sleepedf/01

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '02'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_sc/sleepedf/02

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '00'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_single/sleepedf/00

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '01'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_single/sleepedf/01

View File

@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '02'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_single/sleepedf/02

View File

@@ -1,32 +0,0 @@
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
num_seeds: 10
num_examples: 1
# Selection criteria: out_random | in_random | out_similar
selection_criteria: "out_random"
# Embedding path (not required for random criteria)
# embedding_path: "/home/ssum/tsllm_personalization_icl/embeddings_20users"
models:
- ollama:url:joy.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11444/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
# chronos_base : "/mnt/sting/ssum/sleepedf_chronos_base_result"
# out_random : "/mnt/sting/ssum/sleepedf_chronos_result_outrandom"
# in_random : "/mnt/sting/ssum/sleepedf_chronos_result"
log_path: "/mnt/sting/ssum/sleepedf_chronos_result"

16
config/test.yaml Normal file
View File

@@ -0,0 +1,16 @@
log_path: ./temp/log
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: "00"
queue_size: 5
num_shot: 1
model_paths:
- ollama:url:rose.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11444/gpt-oss:20b
update_size: 1
vocab_size: 200064

View File

@@ -1,208 +1,39 @@
import os
import re
import json
import tiktoken
from langchain_ollama import ChatOllama
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from core.logger import Logger
from core.model import AsyncModelPool
class Agent:
def __init__(
self,
name,
model_pool,
log_path,
model_pool: AsyncModelPool,
logger: Logger,
):
self.name = name
self.model_pool = model_pool
self.log_path = log_path
self.root_log_path = log_path
self.agent_log_path = os.path.join(log_path, name)
os.makedirs(self.agent_log_path, exist_ok=True)
self.logger = logger
self.system_message = None
self.long_term_memory = []
self.short_term_memory = []
self.volatile_memory = []
def set_system_message(self, system_message: str):
self.system_message = system_message
log = f"[SYSTEM]\n{system_message}\n\n"
log_filename = "llm_log/system_prompt.txt"
self.logger.log(log, log_filename, print_log=False)
self.total_input_tokens = 0
self.total_output_tokens = 0
self.total_tokens = 0
self.total_calls = 0
def log(self, message, local=True):
path = os.path.join(self.root_log_path, "log.txt")
with open(path, "a", encoding="utf-8") as f:
message_type = "UNKNOWN"
if isinstance(message, SystemMessage):
message_type = "SYSTEM"
if isinstance(message, HumanMessage):
message_type = "HUMAN"
if isinstance(message, AIMessage):
message_type = "AI"
content = message.content.strip()
name = self.name
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
if local:
local_path = os.path.join(self.agent_log_path, "log.txt")
with open(local_path, "a", encoding="utf-8") as f:
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
def log_tokens(self, messages, response):
input_tokens = 0
for msg in messages:
msg_tokens = self.count_tokens(msg.content)
input_tokens += msg_tokens
self.total_tokens += msg_tokens
self.total_input_tokens += msg_tokens
output_tokens = self.count_tokens(response.content)
self.total_tokens += output_tokens
self.total_output_tokens += output_tokens
self.total_calls += 1
path = os.path.join(self.agent_log_path, "tokens.txt")
with open(path, "a", encoding="utf-8") as f:
f.write(f"Input tokens: {input_tokens}\n")
f.write(f"Output tokens: {output_tokens}\n")
f.write(f"Total input tokens: {self.total_input_tokens}\n")
f.write(f"Total output tokens: {self.total_output_tokens}\n")
f.write(f"Total tokens: {self.total_tokens}\n")
f.write(f"Total calls: {self.total_calls}\n")
f.write("\n")
def count_tokens(self, text, model="gpt-3.5-turbo"):
enc = tiktoken.encoding_for_model(model)
return len(enc.encode(text))
def update_memory(self):
self.long_term_memory.extend(self.short_term_memory)
self.clean_short_term_memory()
self.clean_volatile_memory()
def clean_short_term_memory(self):
self.short_term_memory = []
def clean_volatile_memory(self):
self.volatile_memory = []
def clean_long_term_memory(self):
self.long_term_memory = []
def clean_json_text(self, text):
text = text.strip()
text = text.replace("", "'").replace("", "'")
text = text.replace("", "'").replace("", "'")
text = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", text)
text = re.sub(r",\s*}", "}", text)
text = re.sub(r",\s*]", "]", text)
text = "".join(ch for ch in text if ch.isprintable())
text = text.replace("][", ",")
return text
def safe_parse_json(self, text):
text = text.strip()
match = re.search(r"\{.*\}", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
elif not text.endswith("}"):
text += "}"
match = re.search(r"\{.*\}", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
print("[!] JSON parse failed")
return None
def safe_parse_json_list(self, text):
text = text.strip()
match = re.search(r"\[.*\]", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
elif not text.endswith("]"):
text += "]"
match = re.search(r"\[.*\]", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
print("[!] JSON parse failed")
return None
async def validate_response(self, response, fields, volatile=False):
if (
not response
or not isinstance(response, dict)
or not all(field in response for field in fields)
):
print("[!] The JSON failed to be parsed. Trying again.")
content = (
"Failed to parse the JSON from the previous response. Please try again."
)
response = await self.invoke(content, volatile=volatile)
response = self.safe_parse_json(response)
if (
not response
or not isinstance(response, dict)
or not all(field in response for field in fields)
):
print("[!] Retry failed.")
return None
return response
def get_last_response(self):
if len(self.long_term_memory) >= 2:
last_msg = self.long_term_memory[-1]
if isinstance(last_msg, AIMessage):
return self.safe_parse_json(last_msg.content)
return None
def set_system_message(self, content, local=True):
system_message = SystemMessage(content=content)
self.log(system_message, local)
self.long_term_memory.append(system_message)
async def invoke(self, content, volatile=False, local=True):
messages = self.long_term_memory.copy()
if volatile:
messages.extend(self.volatile_memory)
else:
messages.extend(self.short_term_memory)
messages.append(HumanMessage(content=content))
async def solve(self, content: str, sample_idx: int, example_idx: int):
messages = []
if self.system_message:
messages.append({"role": "system", "content": self.system_message})
user_message = {"role": "user", "content": content}
messages.append(user_message)
try:
response = await self.model_pool.invoke(messages)
self.log_tokens(messages, response)
if volatile:
self.volatile_memory.extend([HumanMessage(content=content), response])
else:
self.short_term_memory.extend([HumanMessage(content=content), response])
local_ = not volatile and local
self.log(HumanMessage(content=content), local=local_)
self.log(response, local=local_)
return response.content.strip()
except Exception as e: # pylint: disable=broad-exception-caught
print(f"[Error] Error occurred while invoking LLM: {e}")
response, logprobs = await self.model_pool.invoke(messages)
log_filename = f"llm_log/sample_{sample_idx}/example_{example_idx}.txt"
self.logger.log(
f"[USER]\n{content}\n\n[RESPONSE]\n{response}\n",
log_filename,
print_log=False,
)
return response, logprobs
except Exception as e:
self.logger.log(f"[Agent] invoke failed: {e}")
return None, None

View File

@@ -1,200 +1,81 @@
import os
import json
import datasets
import numpy as np
from glob import glob
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from .embedding_index import EmbeddingIndex
from glob import glob
from typing import Optional, List
class DataLoader:
def __init__(
self,
data_path,
user_id,
selection_criteria="out_random",
num_examples=1,
embedding_index: Optional["EmbeddingIndex"] = None,
self,
data_path: str,
target_user: str,
shuffle: bool = False,
seed: int = 0,
):
self.is_valid = False
self.embedding_index = embedding_index
self.data_path = data_path
if not os.path.exists(os.path.join(data_path, "info.json")):
return
if not os.path.exists(os.path.join(data_path, f"{user_id}", "1")):
return
if not os.path.exists(os.path.join(data_path, f"{user_id}", "2")):
return
if selection_criteria in ["out_similar", "in_similar"] and embedding_index is None:
print(f"[WARNING] {selection_criteria} requires embedding_index, falling back to random")
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
self.seed = seed
self.task_metadata = self.load_task_metadata(data_path)
self.user_metadata = self.load_user_metadata(data_path)
self.target_dataset = self.load_target_dataset(data_path, target_user, shuffle)
self.metadata = json.load(open(os.path.join(data_path, "info.json"), "r", encoding="utf-8"))
self.test_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user_id}", "2"))
self.example_dataset = datasets.Dataset.from_list([])
users = glob(os.path.join(data_path, "*"))
users = [path.split("/")[-1] for path in users]
if "info.json" in users:
users.remove("info.json")
for user in users:
user_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user}", "1"))
self.example_dataset = datasets.concatenate_datasets([self.example_dataset, user_dataset])
all_users = glob(os.path.join(data_path, "*"))
all_users = [os.path.basename(p) for p in all_users if os.path.isdir(p)]
self.source_users = [u for u in all_users if u != target_user]
self.source_dataset = self.load_source_dataset(data_path, self.source_users)
self.classes = list(self.task_metadata["class"].keys())
self.test_dataset = self.test_dataset.shuffle(seed=0)
self.example_dataset = self.example_dataset.shuffle(seed=0)
def load_task_metadata(self, data_path: str):
task_metadata_path = os.path.join(data_path, "task_metadata.json")
with open(task_metadata_path, "r", encoding="utf-8") as f:
return json.load(f)
self.user_id = user_id
self.selection_criteria = selection_criteria
self.num_examples = num_examples
def load_user_metadata(self, data_path: str):
user_metadata_path = os.path.join(data_path, "user_metadata.json")
with open(user_metadata_path, "r", encoding="utf-8") as f:
return json.load(f)
self.classes = sorted(list(self.metadata["class"].keys()))
# Build lookup index for fast example retrieval: (user_id, idx) -> dataset_index
self._example_lookup = {}
for i, example in enumerate(self.example_dataset):
key = (str(example["user_id"]), int(example["idx"]))
self._example_lookup[key] = i
if selection_criteria in ["out_similar", "in_similar"]:
self.selected_examples = None
else:
self.selected_examples = self.sample_examples()
if self.selected_examples is None:
return
def load_target_dataset(
self, data_path: str, target_user: str, shuffle: bool = False
):
target_dataset_path = os.path.join(data_path, target_user)
target_dataset = datasets.load_from_disk(target_dataset_path)
if shuffle:
return target_dataset.shuffle(seed=self.seed)
return target_dataset
self.is_valid = True
def load_source_dataset(self, data_path: str, source_users: List[str]):
source_dataset = datasets.Dataset.from_list([])
for user in source_users:
user_dataset = datasets.load_from_disk(os.path.join(data_path, user))
source_dataset = datasets.concatenate_datasets(
[source_dataset, user_dataset]
)
source_dataset = source_dataset.shuffle(seed=self.seed)
return source_dataset
def __len__(self):
return len(self.test_dataset)
return len(self.target_dataset)
def __getitem__(self, idx):
sample = self.test_dataset[idx]
if self.selection_criteria in ["out_similar", "in_similar"]:
examples = self.sample_similar_examples(sample)
else:
examples = self.selected_examples
return sample, examples
def __getitem__(self, idx: int):
return self.target_dataset[idx]
def __iter__(self):
for sample in self.test_dataset:
if self.selection_criteria in ["out_similar", "in_similar"]:
examples = self.sample_similar_examples(sample)
else:
examples = self.selected_examples
yield sample, examples
def sample_similar_examples(self, sample):
"""
Sample examples based on embedding similarity to the given sample.
For each class, finds the most similar example from the example dataset
using Chronos-2 embeddings.
Args:
sample: The test sample to find similar examples for
Returns:
HuggingFace Dataset containing similar examples (one per class)
"""
if self.embedding_index is None:
return self.sample_examples()
query_embedding = self.embedding_index.get_embedding_by_key(
user_id=str(sample["user_id"]),
session_id=str(sample["session_id"]),
idx=int(sample["idx"]),
)
if query_embedding is None:
print(f"[WARNING] No embedding found for sample {sample['user_id']}/{sample['session_id']}/{sample['idx']}")
return self.sample_examples()
if self.selection_criteria == "out_similar":
exclude_user = str(self.user_id)
include_user = None
else:
exclude_user = None
include_user = str(self.user_id)
similar_per_class = self.embedding_index.find_similar_per_class(
query_embedding=query_embedding,
classes=self.classes,
k_per_class=self.num_examples,
exclude_user=exclude_user,
include_user=include_user,
filter_session="1",
)
example_list = []
for cls, similar_samples in similar_per_class.items():
for global_idx, similarity, metadata in similar_samples:
example = self._find_example_by_metadata(metadata)
if example is not None:
example_list.append(example)
if len(example_list) == 0:
print(f"[WARNING] No similar examples found, falling back to random")
return self.sample_examples()
return datasets.Dataset.from_list(example_list)
def _find_example_by_metadata(self, metadata):
"""Find an example in the example_dataset by its metadata using O(1) lookup."""
user_id = str(metadata["user_id"])
idx = int(metadata["idx"])
key = (user_id, idx)
if key not in self._example_lookup:
return None
dataset_idx = self._example_lookup[key]
example = self.example_dataset[dataset_idx]
return {
"user_id": example["user_id"],
"session_id": example["session_id"],
"idx": example["idx"],
"label": example["label"],
"features": example["features"],
"data": example.get("data", {}),
}
yield from self.target_dataset
def sample_examples(self):
example_dataset = datasets.Dataset.from_list([])
if self.selection_criteria == "out_random":
filtered_example_dataset = self.example_dataset.filter(lambda x, user_id=self.user_id: x["user_id"] != user_id)
for c in self.classes:
class_dataset = filtered_example_dataset.filter(lambda x, c_=c: x["label"] == c_)
if len(class_dataset) < self.num_examples:
return None
sampled_examples = class_dataset.select(np.random.choice(len(class_dataset), self.num_examples, replace=False))
example_dataset = datasets.concatenate_datasets([example_dataset, sampled_examples])
elif self.selection_criteria == "in_random":
filtered_example_dataset = self.example_dataset.filter(lambda x, user_id=self.user_id: x["user_id"] == user_id)
for c in self.classes:
class_dataset = filtered_example_dataset.filter(lambda x, c_=c: x["label"] == c_)
if len(class_dataset) < self.num_examples:
return None
sampled_examples = class_dataset.select(np.random.choice(len(class_dataset), self.num_examples, replace=False))
example_dataset = datasets.concatenate_datasets([example_dataset, sampled_examples])
return example_dataset
def get_source_dataset(self):
return self.source_dataset
def get_metadata(self):
return self.metadata
def get_task_metadata(self):
return self.task_metadata
def get_sensor_info(self):
return self.metadata["feature"]
def get_user_metadata(self, user: Optional[str] = None):
if user is None:
return self.user_metadata
return self.user_metadata[user]
def get_task_info(self):
task_info = f"**Task**:\n{self.metadata['task']}\n\n"
classes_info = [f" - {k}: {v}" for k, v in self.metadata["class"].items()]
classes_info = "\n".join(classes_info)
task_info += f"**Classes**:\n{classes_info}"
return task_info
def get_source_users(self):
return self.source_users
def get_classes_info(self):
classes_info = [k for k in self.metadata["class"].keys()]
return classes_info
def get_classes(self):
return self.classes

View File

@@ -1,281 +0,0 @@
"""
Embedding Index for Similarity-based Example Selection
This module provides functionality to:
1. Load pre-computed Chronos-2 embeddings
2. Build an index for fast nearest neighbor search
3. Find similar examples based on embedding distance
The embedding index enables ICL (In-Context Learning) example selection
based on semantic similarity rather than random sampling.
Similarity Strategies:
- out_similar: Find similar examples from OTHER users (cross-user transfer)
- in_similar: Find similar examples from SAME user (personalization)
Usage:
index = EmbeddingIndex(embedding_path)
similar_indices = index.find_similar(query_embedding, k=5, exclude_user="01")
Author: NMSL Research Team
Date: 2026-01-16
"""
import os
import numpy as np
from typing import List, Tuple, Optional, Dict, Any
from datasets import load_from_disk, Dataset
class EmbeddingIndex:
"""
Index for fast similarity search over pre-computed embeddings.
Uses cosine similarity for finding nearest neighbors in embedding space.
Supports filtering by user_id and session_id for controlled experiments.
Attributes:
embeddings: numpy array of shape (num_samples, embedding_dim)
user_ids: list of user identifiers for each sample
session_ids: list of session identifiers for each sample
labels: list of sleep stage labels for each sample
indices: list of original indices in the dataset
"""
def __init__(self, embedding_path: str):
"""
Initialize the embedding index from a HuggingFace dataset.
Args:
embedding_path: Path to directory containing saved embeddings dataset
(output from gen_plot.py extract command)
"""
if not os.path.isdir(embedding_path):
raise FileNotFoundError(
f"Embedding directory not found: {embedding_path}. "
"Run 'python gen_plot.py extract' first to generate embeddings."
)
print(f"[EmbeddingIndex] Loading embeddings from: {embedding_path}")
self.dataset = load_from_disk(embedding_path)
# Extract arrays for fast access
self.embeddings = np.array(self.dataset["embedding"], dtype=np.float32)
self.user_ids = np.array([str(uid) for uid in self.dataset["user_id"]])
self.session_ids = np.array([str(sid) for sid in self.dataset["session_id"]])
self.labels = np.array([str(label) for label in self.dataset["label"]])
self.indices = np.array(self.dataset["idx"])
# Normalize embeddings for cosine similarity (dot product of unit vectors)
self.embeddings_normalized = self._normalize(self.embeddings)
# Build lookup indices for fast filtering
self._build_lookup_indices()
print(f"[EmbeddingIndex] Loaded {len(self.embeddings)} samples")
print(f"[EmbeddingIndex] Embedding dimension: {self.embeddings.shape[1]}")
print(f"[EmbeddingIndex] Unique users: {len(np.unique(self.user_ids))}")
print(f"[EmbeddingIndex] Unique sessions: {len(np.unique(self.session_ids))}")
def _normalize(self, vectors: np.ndarray) -> np.ndarray:
"""L2 normalize vectors for cosine similarity computation."""
norms = np.linalg.norm(vectors, axis=1, keepdims=True)
norms = np.where(norms == 0, 1, norms) # Avoid division by zero
return vectors / norms
def _build_lookup_indices(self):
"""Build dictionaries for fast filtering by user/session/label."""
self.user_to_indices: Dict[str, np.ndarray] = {}
self.session_to_indices: Dict[str, np.ndarray] = {}
self.label_to_indices: Dict[str, np.ndarray] = {}
for user_id in np.unique(self.user_ids):
self.user_to_indices[user_id] = np.where(self.user_ids == user_id)[0]
for session_id in np.unique(self.session_ids):
self.session_to_indices[session_id] = np.where(self.session_ids == session_id)[0]
for label in np.unique(self.labels):
self.label_to_indices[label] = np.where(self.labels == label)[0]
def get_embedding_by_key(
self,
user_id: str,
session_id: str,
idx: int
) -> Optional[np.ndarray]:
"""
Get embedding for a specific sample identified by (user_id, session_id, idx).
Args:
user_id: User identifier
session_id: Session identifier (1 or 2)
idx: Sample index within the session
Returns:
Embedding vector or None if not found
"""
mask = (
(self.user_ids == str(user_id)) &
(self.session_ids == str(session_id)) &
(self.indices == idx)
)
matches = np.where(mask)[0]
if len(matches) == 0:
return None
return self.embeddings[matches[0]]
def cosine_similarity(
self,
query: np.ndarray,
candidates: np.ndarray
) -> np.ndarray:
"""
Compute cosine similarity between query and candidate vectors.
Args:
query: Query vector of shape (embedding_dim,)
candidates: Candidate matrix of shape (num_candidates, embedding_dim)
Returns:
Similarity scores of shape (num_candidates,)
"""
query_normalized = query / (np.linalg.norm(query) + 1e-8)
return candidates @ query_normalized
def find_similar(
self,
query_embedding: np.ndarray,
k: int = 5,
exclude_user: Optional[str] = None,
include_user: Optional[str] = None,
filter_session: Optional[str] = None,
filter_label: Optional[str] = None,
) -> List[Tuple[int, float, Dict[str, Any]]]:
"""
Find k most similar samples to the query embedding.
Args:
query_embedding: Query vector of shape (embedding_dim,)
k: Number of nearest neighbors to return
exclude_user: User ID to exclude from search (for out_similar)
include_user: User ID to include only (for in_similar)
filter_session: Only search in this session (e.g., "1" for train set)
filter_label: Only return samples with this label
Returns:
List of (index, similarity, metadata) tuples sorted by similarity (descending)
metadata contains: user_id, session_id, idx, label
"""
# Build candidate mask based on filters
candidate_mask = np.ones(len(self.embeddings), dtype=bool)
if exclude_user is not None:
candidate_mask &= (self.user_ids != str(exclude_user))
if include_user is not None:
candidate_mask &= (self.user_ids == str(include_user))
if filter_session is not None:
candidate_mask &= (self.session_ids == str(filter_session))
if filter_label is not None:
candidate_mask &= (self.labels == str(filter_label))
candidate_indices = np.where(candidate_mask)[0]
if len(candidate_indices) == 0:
return []
# Compute similarities
candidate_embeddings = self.embeddings_normalized[candidate_indices]
similarities = self.cosine_similarity(query_embedding, candidate_embeddings)
# Get top-k
k = min(k, len(similarities))
top_k_local = np.argsort(similarities)[::-1][:k]
results = []
for local_idx in top_k_local:
global_idx = candidate_indices[local_idx]
sim = similarities[local_idx]
metadata = {
"user_id": self.user_ids[global_idx],
"session_id": self.session_ids[global_idx],
"idx": int(self.indices[global_idx]),
"label": self.labels[global_idx],
}
results.append((global_idx, float(sim), metadata))
return results
def find_similar_per_class(
self,
query_embedding: np.ndarray,
classes: List[str],
k_per_class: int = 1,
exclude_user: Optional[str] = None,
include_user: Optional[str] = None,
filter_session: Optional[str] = "1",
) -> Dict[str, List[Tuple[int, float, Dict[str, Any]]]]:
"""
Find k most similar samples for each class.
This ensures balanced class representation in ICL examples,
similar to the original random sampling approach.
Args:
query_embedding: Query vector
classes: List of class labels to search
k_per_class: Number of examples per class
exclude_user: User to exclude (out_similar mode)
include_user: User to include only (in_similar mode)
filter_session: Session to search in (default: "1" = train set)
Returns:
Dictionary mapping class label to list of similar samples
"""
results = {}
for cls in classes:
similar = self.find_similar(
query_embedding=query_embedding,
k=k_per_class,
exclude_user=exclude_user,
include_user=include_user,
filter_session=filter_session,
filter_label=cls,
)
results[cls] = similar
return results
def get_sample_metadata(self, global_idx: int) -> Dict[str, Any]:
"""Get metadata for a sample by its global index."""
return {
"user_id": self.user_ids[global_idx],
"session_id": self.session_ids[global_idx],
"idx": int(self.indices[global_idx]),
"label": self.labels[global_idx],
"embedding": self.embeddings[global_idx],
}
def create_embedding_index(embedding_path: str) -> Optional[EmbeddingIndex]:
"""
Factory function to create an EmbeddingIndex, returning None if path doesn't exist.
Args:
embedding_path: Path to embeddings directory
Returns:
EmbeddingIndex instance or None
"""
if not os.path.isdir(embedding_path):
print(f"[WARNING] Embedding path not found: {embedding_path}")
return None
try:
return EmbeddingIndex(embedding_path)
except Exception as e:
print(f"[ERROR] Failed to load embeddings: {e}")
return None

37
core/example_queue.py Normal file
View File

@@ -0,0 +1,37 @@
from typing import List, Dict, Any
from core.logger import Logger
class ExampleQueue:
def __init__(self, queue_size: int, logger: Logger):
self.queue_size = queue_size
self.logger = logger
self.queue: List[Dict[str, Any]] = []
def __iter__(self):
yield from self.queue
def update(self, results: List[Dict[str, Any]], new_examples: List[Dict[str, Any]]):
n_replace = len(new_examples)
if n_replace == 0:
return
if results:
ranked = sorted(range(len(results)), key=lambda i: results[i]["score"])
drop_indices = sorted(ranked[:n_replace], reverse=True)
for idx in drop_indices:
self.logger.log(
f"[Queue] Dropping index {idx} (score={results[idx]['score']:.4f})"
)
self.queue.pop(idx)
slots = self.queue_size - len(self.queue)
if slots < len(new_examples):
self.logger.log(
f"[Queue] Capping: {len(new_examples)} new but only {slots} slot(s) free"
)
new_examples = new_examples[:slots]
self.queue.extend(new_examples)
self.logger.log(f"[Queue] Updated, queue size = {len(self.queue)}")

29
core/json_utils.py Normal file
View File

@@ -0,0 +1,29 @@
import re
import json
def clean_json_text(text: str):
text = text.strip()
text = text.replace("\u2018", "'").replace("\u2019", "'") # smart single quotes
text = text.replace("\u201c", '"').replace("\u201d", '"') # smart double quotes
text = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", text) # escape stray backslashes
text = re.sub(r",\s*}", "}", text) # trailing comma in object
text = re.sub(r",\s*]", "]", text) # trailing comma in array
text = "".join(ch for ch in text if ch.isprintable()) # strip control chars
text = text.replace("][", ",") # merge adjacent arrays
return text
def safe_parse_json(text: str):
if not text:
return None
text = text.strip()
match = re.search(r"\{.*\}", text, re.DOTALL)
if not match and not text.endswith("}"):
match = re.search(r"\{.*\}", text + "}", re.DOTALL)
if match:
try:
return json.loads(clean_json_text(match.group(0)))
except json.JSONDecodeError as e:
return None
return None

58
core/logger.py Normal file
View File

@@ -0,0 +1,58 @@
import os
import yaml
from datetime import datetime
from sklearn.metrics import accuracy_score, f1_score
class Logger:
def __init__(self, log_path: str):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.log_path = f"{log_path}_{timestamp}"
os.makedirs(self.log_path, exist_ok=True)
self.answers = []
def log(self, message: str, filename: str = "log.txt", print_log: bool = True):
if print_log:
print(message)
log_file_path = os.path.join(self.log_path, filename)
base_dir = os.path.dirname(log_file_path)
os.makedirs(base_dir, exist_ok=True)
with open(log_file_path, "a", encoding="utf-8") as f:
f.write(message + "\n")
def log_config(self, config: dict):
message = yaml.dump(config, default_flow_style=False, sort_keys=False).strip()
self.log(message, "config.yaml", print_log=False)
def log_result(self, idx: int, answer: str, ground_truth: str):
self.answers.append(
{
"idx": idx,
"answer": answer,
"ground_truth": ground_truth,
}
)
self.log(
f"[RESULT] {idx}: answer={answer}, ground_truth={ground_truth}",
"result.txt",
)
def report(self, elapsed_seconds: float = None):
if not self.answers:
self.log("[REPORT] No valid answers recorded", "report.txt")
return
answers = [a["answer"] for a in self.answers]
ground_truths = [a["ground_truth"] for a in self.answers]
accuracy = accuracy_score(ground_truths, answers)
f1 = f1_score(ground_truths, answers, average="macro")
n = len(self.answers)
time_str = ""
if elapsed_seconds is not None:
m, s = divmod(int(elapsed_seconds), 60)
h, m = divmod(m, 60)
time_str = f", time={h}h{m:02d}m{s:02d}s"
self.log(
f"[REPORT] accuracy={accuracy:.4f}, f1={f1:.4f}, n={n}{time_str}",
"report.txt",
)

View File

@@ -1,95 +1,133 @@
import asyncio
import requests
import numpy as np
from langchain_ollama import ChatOllama
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
def load_models(models):
model_pool = AsyncModelPool()
for model in models:
model_pool.add_model(Model(model))
model_pool.init_models()
return model_pool
from typing import Dict, List
class Model:
def __init__(self, model, temperature=0.7):
if model.startswith("ollama:"):
model = model.replace("ollama:", "")
if "url:" in model:
model = model.replace("url:", "")
base_url = model.split("/")[0]
if not base_url.startswith("http"):
base_url = "http://" + base_url
model_type = model.split("/")[1]
self.model = ChatOllama(
model=model_type,
base_url=base_url,
temperature=temperature,
num_ctx=12000,
)
else:
self.model = ChatOllama(
model=model.replace("ollama:", ""),
temperature=temperature,
num_ctx=12000,
)
else:
self.model = init_chat_model(
model=model,
temperature=temperature,
)
def __init__(
self,
model_path: str,
temperature: float = 0.0,
num_ctx: int = 131072,
max_tokens: int = -1,
logprobs: bool = True,
top_logprobs: int = 20,
top_p: float = 1.0,
top_k: int = 0,
stream: bool = False,
think: str = "low",
):
self.backend = None
def invoke(self, messages):
if model_path.startswith("ollama:"):
raw = model_path.split("url:")[1]
self.backend = "ollama"
self.base_url = f"http://{raw.split('/')[0]}"
self.model_name = raw.split("/")[1]
self.temperature = temperature
self.num_ctx = num_ctx
self.max_tokens = max_tokens
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.top_p = top_p
self.top_k = top_k
self.stream = stream
self.think = think
else:
raise ValueError(f"Unknown model prefix: {model_path}")
def invoke(self, messages: List[Dict[str, str]]):
try:
response = self.model.invoke(messages)
return response
return self._invoke_ollama(messages)
except Exception as e:
print(f"[Error] Error occurred while invoking LLM: {e}")
return e
print(f"[Error] invoke failed: {e}")
return "", np.array([])
def _invoke_ollama(self, messages: List[Dict[str, str]]):
resp = requests.post(
f"{self.base_url}/api/chat",
json={
"messages": messages,
"model": self.model_name,
"temperature": self.temperature,
"num_ctx": self.num_ctx,
"num_predict": self.max_tokens,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs,
"top_p": self.top_p,
"top_k": self.top_k,
"stream": self.stream,
"think": self.think,
},
timeout=300,
)
resp.raise_for_status()
data = resp.json()
response = data["message"]["content"]
logprobs = []
for lp in data.get("logprobs", []):
logprobs.append([tp["logprob"] for tp in lp.get("top_logprobs", [])])
return response, np.array(logprobs) if logprobs else np.array([])
class AsyncModel:
def __init__(self, model):
def __init__(self, model: Model):
self.model = model
async def invoke(self, content):
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(
async def invoke(self, messages: List[Dict[str, str]]):
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None,
lambda: self.model.invoke(content),
lambda: self.model.invoke(messages),
)
return response
class AsyncModelPool:
def __init__(self):
self.models = []
self._available_models = None
self._model_semaphore = None
self._queue: asyncio.Queue = asyncio.Queue()
def add_model(self, model):
self.models.append(model)
def add_model(self, model: Model):
self._queue.put_nowait(AsyncModel(model))
def init_models(self):
# Initialize the queue and semaphore in the current event loop
self._available_models = asyncio.Queue()
for model in self.models:
async_model = AsyncModel(model)
self._available_models.put_nowait(async_model)
self._model_semaphore = asyncio.Semaphore(len(self.models))
def warmup(self):
for model in self.models:
model.invoke([HumanMessage(content="Hello world!")])
async def invoke(self, content):
if self._available_models is None:
raise RuntimeError("Model pool not initialized. Call init_models() first.")
async_model = await self._available_models.get()
async def invoke(self, messages: List[Dict[str, str]]):
async_model = await self._queue.get()
try:
response = await async_model.invoke(content)
return response
return await async_model.invoke(messages)
finally:
self._available_models.put_nowait(async_model)
self._queue.put_nowait(async_model)
def load_models(
model_paths: List[str],
temperature: float = 0.0,
num_ctx: int = 131072,
max_tokens: int = -1,
logprobs: bool = True,
top_logprobs: int = 20,
top_p: float = 1.0,
top_k: int = 0,
stream: bool = False,
think: str = "low",
):
pool = AsyncModelPool()
for path in model_paths:
pool.add_model(
Model(
path,
temperature=temperature,
num_ctx=num_ctx,
max_tokens=max_tokens,
logprobs=logprobs,
top_logprobs=top_logprobs,
top_p=top_p,
top_k=top_k,
stream=stream,
think=think,
)
)
print(f"[ModelPool] Loaded a model from {path}")
return pool

66
core/prompt.py Normal file
View File

@@ -0,0 +1,66 @@
from typing import Any, Dict, List
def gen_system_message(metadata: Dict[str, Any]):
task_info = metadata["task"]
classes_info = [f" - {k}: {v}" for k, v in metadata["class"].items()]
classes_info = "\n".join(classes_info)
data_info = metadata["data"]
feature_info = metadata["feature"]
system_message = (
f"You are an assistant who interprets sensor data to solve a task.\n\n"
f"1. Task:\n"
f"{task_info}\n\n"
f"2. Classes:\n"
f"{classes_info}\n\n"
f"3. Data:\n"
f"{data_info}\n\n"
f"4. Features:\n"
f"{feature_info}\n\n"
"Your goal is to analyze the sensor data and "
"provide a reasoned answer for the task.\n"
"Do not output analysis."
)
return system_message
def gen_task_message(
sample: Dict[str, Any],
example_set: List[Dict[str, Any]],
):
def format_feature(value: Any):
if isinstance(value, float):
if abs(value) >= 1e4 or (abs(value) < 1e-2 and value != 0):
return f"{value:.2e}"
return f"{value:.2f}"
return str(value)
example_info = ""
for cls, examples in example_set.items():
for i, example in enumerate(examples):
example_info += f"Example {i+1} of {cls}:\n"
for k, v in example["features"].items():
example_info += f" - {k}: {format_feature(v)}\n"
example_info += "\n"
example_info = example_info.strip()
test_info = "Current sensor data:\n"
for k, v in sample["features"].items():
test_info += f" - {k}: {format_feature(v)}\n"
test_info = test_info.strip()
classes = list(example_set.keys())
task_message = (
"You have a few labeled examples of sensor data:\n"
f"{example_info}\n\n"
f"And you have the current sensor data:\n"
f"{test_info}\n\n"
f"Please provide your answer among {classes} "
"and the reasoning for your answer.\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<reasoning for the answer>",\n'
f' "ANSWER": "<answer among {classes}>"\n'
"}"
)
return task_message

52
core/recruiter.py Normal file
View File

@@ -0,0 +1,52 @@
import random
import numpy as np
from datasets import Dataset
from typing import Any, Dict, List
from core.logger import Logger
class Recruiter:
def __init__(
self,
source_dataset: Dataset,
source_users: List[str],
classes: List[str],
num_shot: int,
logger: Logger,
):
self.source_dataset = source_dataset
self.num_shot = num_shot
self.source_users = source_users
self.classes = classes
self.logger = logger
def recruit(self, num_example_set: int):
# place holder for the recruiter strategy
self.logger.log(f"[Recruiter] Recruiting {num_example_set} example set(s)")
# randomly select examples from the user dataset
list_example_sets = []
for _ in range(num_example_set):
# randomly select user
user = random.choice(self.source_users)
user_dataset = self.source_dataset.filter(lambda x: x["user_id"] == user)
example_set = {}
for cls in self.classes:
cls_dataset = user_dataset.filter(lambda x: x["label"] == cls)
if len(cls_dataset) < self.num_shot:
raise ValueError(
f"Not enough examples for class {cls} in user {user}"
)
random_index = np.random.choice(
len(cls_dataset), self.num_shot, replace=False
)
example_set[cls] = cls_dataset.select(random_index)
self.logger.log(
f"[Recruiter] Recruited {self.num_shot} example(s) from user {user}"
)
list_example_sets.append(example_set)
return list_example_sets
def update_strategy(self, results: List[Dict[str, Any]]):
# TODO: implement adaptive recruitment strategy based on results
pass

27
core/scores.py Normal file
View File

@@ -0,0 +1,27 @@
import math
import numpy as np
def self_certainty(logprobs: np.ndarray, vocab_size: int) -> float:
"""Returns -inf if logprobs is None or empty."""
if logprobs is None or logprobs.size == 0:
return float("-inf")
# probability mass in top-k
probs = np.exp(logprobs)
k = probs.shape[-1]
topk_sum = probs.sum(axis=-1)
# remaining probability mass
tail_mass = 1.0 - topk_sum
tail_mass = np.clip(tail_mass, 1e-12, None)
# uniform distribution over remaining tokens
tail_prob = tail_mass / (vocab_size - k)
# sum log probabilities
logprob_sum_topk = logprobs.sum(axis=-1)
logprob_sum_tail = (vocab_size - k) * np.log(tail_prob)
logprob_sum = logprob_sum_topk + logprob_sum_tail
# self-certainty score
score = (-1.0 / vocab_size) * logprob_sum - math.log(vocab_size)
return float(np.mean(score))

View File

@@ -1,181 +0,0 @@
import json
import copy
import os
from .agent import Agent
class SensingAgent(Agent):
def __init__(
self,
name,
model_pool,
task_info,
classes_info,
sensor_info,
sample,
examples,
log_path,
):
super().__init__(
name=name,
model_pool=model_pool,
log_path=log_path,
)
self.task_info = task_info
self.classes_info = classes_info
self.sensor_info = sensor_info
self.sample = sample
self.examples = examples
self.init_system_message()
def init_system_message(self):
content = (
f"You are {self.name} agent that interprets sensor data to solve a task.\n"
"You have the following information about the task:\n"
f"{self.task_info}\n\n"
"You have the following information about the sensor data:\n"
f"{self.sensor_info}\n\n"
"Your goal is to analyze the features and "
"provide a reasoned answer using your knowledge."
)
self.set_system_message(content)
def gen_feature_info(self):
feature_info = f"{self.name} features:\n"
if len(self.examples) > 0:
feature_info += f"{self.gen_example_info()}\n\n"
feature_info += "**Current sample features**:\n"
for k, v in self.sample["features"].items():
feature_info += f" - {k}: {self.format_feature(v)}\n"
feature_info = feature_info.strip()
return feature_info
def gen_example_info(self):
example_info = (
"**Examples**\n"
"Sensor values might not always align with your inherent "
"knowledge due to differences in data collection or processing. "
"So, we included a few labeled examples to help your interpretation:\n"
)
for example in self.examples:
example_info += f"*Example of {example['label']}*:\n"
for k, v in example["features"].items():
example_info += f" - {k}: {self.format_feature(v)}\n"
example_info += "\n"
example_info = example_info.strip()
return example_info
def format_feature(self, value):
if isinstance(value, float):
if abs(value) >= 1e4 or abs(value) < 1e-2:
return f"{value:.2e}"
return f"{value:.2f}"
return value
def log_summary(self, message, print_log=True):
path = os.path.join(self.log_path, "summary.txt")
with open(path, "a", encoding="utf-8") as f:
f.write(f"{message}\n")
if print_log:
print(message)
async def solve(self, sample, examples, ground_truth):
self.sample = sample
self.examples = examples
feature_info = self.gen_feature_info()
content = (
f"You have received sensor features from {self.name} modality:\n"
f"{feature_info}\n\n"
f"Please provide your answer for the task among {self.classes_info} "
"and the reasoning for your answer.\n"
"Note that the sensor features might be wrong due to the data collection or processing.\n"
"You can evaluate the quality of the features by checking the examples you have.\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<Reasoning for the answer>",\n'
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(content)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["REASON", "ANSWER"]
)
self.clean_short_term_memory()
self.clean_long_term_memory()
answer = parsed_response["ANSWER"]
self.log_summary(f"Answer: {answer} (Ground truth: {ground_truth})", print_log=True)
return parsed_response
async def interpret(self):
feature_info = self.gen_feature_info()
content = (
f"You have received sensor features from {self.name} modality:\n"
f"{feature_info}\n\n"
f"Please provide your answer for the task among {self.classes_info} "
"and the reasoning for your answer.\n"
"Note that the sensor features might be wrong due to the data collection or processing.\n"
"You can evaluate the quality of the features by checking the examples you have.\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<Reasoning for the answer>",\n'
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(content)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["REASON", "ANSWER"]
)
return parsed_response
async def evaluate(self, target_name, initial_response):
initial_response_info = json.dumps(initial_response, indent=2)
content = (
f"Other agent, <{target_name}> provided the following answer for the same task:\n"
f"{initial_response_info}\n\n"
"Please evaluate the given reasoning and answer based on your judgement. "
"You may either support with it or disagree.\n"
"If you agree, explain why the reasoning and answer are valid. "
"If you disagree, explain why the reasoning or answer may be flawed, "
f"and provide constructive feedback on how <{target_name}> can improve its response.\n"
"Respond in the following strict JSON format:\n"
"{\n"
f' "EVALUATION": "<Evaluation to <{target_name}>"\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(content, volatile=True)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["EVALUATION"], volatile=True
)
self.clean_volatile_memory()
return parsed_response
async def reflect(self, evaluations):
evaluations_info = json.dumps(evaluations, indent=2)
content = (
f"Other agents have evaluated your answer for the same task:\n"
f"{evaluations_info}\n\n"
"Please reflect on the evaluations and provide a refined answer for the same task.\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<Reasoning for the answer>",\n'
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(content)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["REASON", "ANSWER"]
)
return parsed_response

24
core/vote.py Normal file
View File

@@ -0,0 +1,24 @@
from typing import Any, Dict, List
from collections import Counter
def borda_vote(results: List[Dict[str, Any]], borda_p: float = 1.0):
parsed = [
{"answer": r["answer"], "score": r["score"]}
for r in results
if r.get("answer") is not None
]
if not parsed:
return None, {}
parsed.sort(key=lambda x: x["score"], reverse=True)
n = len(parsed)
tally: Counter = Counter()
for rank, entry in enumerate(parsed, start=1):
votes = int((n - rank + 1) ** borda_p)
tally[entry["answer"]] += votes
winner, _ = tally.most_common(1)[0]
return winner, dict(tally.most_common())

View File

@@ -0,0 +1,78 @@
"""
Generate per-(method, user) config files.
Usage:
python experiments/gen_configs.py \
--dataset sleepedf \
--data_path /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new \
--users 00,01,02
"""
import os
import yaml
from fire import Fire
METHODS = [
"random_fixed_single",
"random_dynamic_single",
"random_fixed_sc",
"random_dynamic_sc",
"random_dynamic_borda",
"random_fixed_borda",
"ours",
]
DEFAULTS = {
"queue_size": 5,
"num_shot": 1,
"update_size": 3,
"borda_p": 1.0,
"vocab_size": 200064,
"model_paths": [
"ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b",
"ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b",
"ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b",
"ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b",
"ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b",
"ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b",
"ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b",
"ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b",
],
}
def main(dataset: str, data_path: str, users: str):
"""
Args:
dataset: Dataset name (e.g. "sleepedf").
data_path: Absolute path to the dataset.
users: Comma-separated user IDs (e.g. "00,01,02").
"""
users = [u.strip() for u in users.split(",")]
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
for method in METHODS:
for user in users:
config_dir = os.path.join(project_root, "config", method, dataset)
os.makedirs(config_dir, exist_ok=True)
config = dict(DEFAULTS)
config["data_path"] = data_path
config["target_user"] = user
config["log_path"] = (
f"/mnt/sting/hjyoon/projects/llm_personalization/logs"
f"/{method}/{dataset}/{user}"
)
config_path = os.path.join(config_dir, f"user{user}.yaml")
with open(config_path, "w", encoding="utf-8") as f:
yaml.dump(config, f, default_flow_style=False, sort_keys=False)
print(f" {config_path}")
print(f"\nGenerated {len(METHODS) * len(users)} configs.")
if __name__ == "__main__":
Fire(main)

75
experiments/run.sh Executable file
View File

@@ -0,0 +1,75 @@
#!/bin/bash
#
# Generic experiment runner. Finds configs under config/<method>/<dataset>/
# and runs the corresponding Python script.
#
# Usage:
# bash experiments/run.sh <method> <dataset> # one method, all users
# bash experiments/run.sh <method> <dataset> <config> # one config file
# bash experiments/run.sh all <dataset> # all methods
#
# Examples:
# bash experiments/run.sh ours sleepedf
# bash experiments/run.sh random_fixed_sc sleepedf
# bash experiments/run.sh random_fixed_sc sleepedf user00.yaml
# bash experiments/run.sh all sleepedf
#
set -e
METHOD="${1:?Usage: $0 <method|all> <dataset> [config_file]}"
DATASET="${2:?Usage: $0 <method|all> <dataset> [config_file]}"
CONFIG_FILE="$3"
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
CONFIG_ROOT="$PROJECT_ROOT/config"
run_config() {
local method="$1"
local config="$2"
echo "============================================"
echo " Method: $method | Config: $(basename "$config")"
echo "============================================"
if [ "$method" = "ours" ]; then
python "$PROJECT_ROOT/run.py" --config_path "$config"
else
python "$PROJECT_ROOT/baselines/${method}.py" --config_path "$config"
fi
echo ""
}
if [ "$METHOD" != "all" ] && [ -n "$CONFIG_FILE" ]; then
# Single config
config="$CONFIG_ROOT/$METHOD/$DATASET/$CONFIG_FILE"
if [ ! -f "$config" ]; then
echo "Error: config not found: $config"
exit 1
fi
run_config "$METHOD" "$config"
elif [ "$METHOD" != "all" ]; then
# All configs for one method
config_dir="$CONFIG_ROOT/$METHOD/$DATASET"
if [ ! -d "$config_dir" ]; then
echo "Error: config directory not found: $config_dir"
exit 1
fi
for config in "$config_dir"/*.yaml; do
run_config "$METHOD" "$config"
done
else
# All methods
for method_dir in "$CONFIG_ROOT"/*/; do
method="$(basename "$method_dir")"
dataset_dir="$method_dir$DATASET"
[ -d "$dataset_dir" ] || continue
for config in "$dataset_dir"/*.yaml; do
run_config "$method" "$config"
done
done
fi
echo "All experiments complete."

View File

@@ -0,0 +1,324 @@
import os
import json
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm
from fire import Fire
from datasets import Dataset
from datetime import timedelta
warnings.filterwarnings("ignore")
GLOBEM_PATH = "/mnt/sting/hjyoon/projects/llm_personalization/dataset/GLOBEM/physionet.org/files/globem/1.1"
OUT_DIR = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/GLOBEM"
PHASES = {
"INS-W": [1, 2, 3, 4],
}
FEATURE_TYPES = ["f_loc", "f_screen", "f_slp", "f_steps"]
FEATURE_COLUMNS = {
"f_loc": [
"phone_locations_barnett_avgflightdur", "phone_locations_barnett_avgflightlen",
"phone_locations_barnett_circdnrtn", "phone_locations_barnett_disttravelled",
"phone_locations_barnett_hometime", "phone_locations_barnett_maxdiam",
"phone_locations_barnett_maxhomedist", "phone_locations_barnett_probpause",
"phone_locations_barnett_rog", "phone_locations_barnett_siglocentropy",
"phone_locations_barnett_siglocsvisited", "phone_locations_barnett_stdflightdur",
"phone_locations_barnett_stdflightlen", "phone_locations_barnett_wkenddayrtn",
"phone_locations_doryab_avglengthstayatclusters", "phone_locations_doryab_avgspeed",
"phone_locations_doryab_homelabel", "phone_locations_doryab_locationentropy",
"phone_locations_doryab_locationvariance", "phone_locations_doryab_loglocationvariance",
"phone_locations_doryab_maxlengthstayatclusters", "phone_locations_doryab_minlengthstayatclusters",
"phone_locations_doryab_movingtostaticratio", "phone_locations_doryab_normalizedlocationentropy",
"phone_locations_doryab_numberlocationtransitions", "phone_locations_doryab_numberofsignificantplaces",
"phone_locations_doryab_outlierstimepercent", "phone_locations_doryab_radiusgyration",
"phone_locations_doryab_stdlengthstayatclusters", "phone_locations_doryab_timeathome",
"phone_locations_doryab_timeattop1location", "phone_locations_doryab_timeattop2location",
"phone_locations_doryab_timeattop3location", "phone_locations_doryab_totaldistance",
"phone_locations_doryab_varspeed",
"phone_locations_locmap_duration_in_locmap_study", "phone_locations_locmap_percent_in_locmap_study",
"phone_locations_locmap_duration_in_locmap_exercise", "phone_locations_locmap_percent_in_locmap_exercise",
"phone_locations_locmap_duration_in_locmap_greens", "phone_locations_locmap_percent_in_locmap_greens",
],
"f_screen": [
"phone_screen_rapids_countepisodeunlock", "phone_screen_rapids_sumdurationunlock",
"phone_screen_rapids_maxdurationunlock", "phone_screen_rapids_mindurationunlock",
"phone_screen_rapids_avgdurationunlock", "phone_screen_rapids_stddurationunlock",
"phone_screen_rapids_firstuseafter00unlock",
"phone_screen_rapids_countepisodeunlock_locmap_exercise", "phone_screen_rapids_sumdurationunlock_locmap_exercise",
"phone_screen_rapids_maxdurationunlock_locmap_exercise", "phone_screen_rapids_mindurationunlock_locmap_exercise",
"phone_screen_rapids_avgdurationunlock_locmap_exercise", "phone_screen_rapids_stddurationunlock_locmap_exercise",
"phone_screen_rapids_firstuseafter00unlock_locmap_exercise",
"phone_screen_rapids_countepisodeunlock_locmap_greens", "phone_screen_rapids_sumdurationunlock_locmap_greens",
"phone_screen_rapids_maxdurationunlock_locmap_greens", "phone_screen_rapids_mindurationunlock_locmap_greens",
"phone_screen_rapids_avgdurationunlock_locmap_greens", "phone_screen_rapids_stddurationunlock_locmap_greens",
"phone_screen_rapids_firstuseafter00unlock_locmap_greens",
"phone_screen_rapids_countepisodeunlock_locmap_living", "phone_screen_rapids_sumdurationunlock_locmap_living",
"phone_screen_rapids_maxdurationunlock_locmap_living", "phone_screen_rapids_mindurationunlock_locmap_living",
"phone_screen_rapids_avgdurationunlock_locmap_living", "phone_screen_rapids_stddurationunlock_locmap_living",
"phone_screen_rapids_firstuseafter00unlock_locmap_living",
"phone_screen_rapids_countepisodeunlock_locmap_study", "phone_screen_rapids_sumdurationunlock_locmap_study",
"phone_screen_rapids_maxdurationunlock_locmap_study", "phone_screen_rapids_mindurationunlock_locmap_study",
"phone_screen_rapids_avgdurationunlock_locmap_study", "phone_screen_rapids_stddurationunlock_locmap_study",
"phone_screen_rapids_firstuseafter00unlock_locmap_study",
"phone_screen_rapids_countepisodeunlock_locmap_home", "phone_screen_rapids_sumdurationunlock_locmap_home",
"phone_screen_rapids_maxdurationunlock_locmap_home", "phone_screen_rapids_mindurationunlock_locmap_home",
"phone_screen_rapids_avgdurationunlock_locmap_home", "phone_screen_rapids_stddurationunlock_locmap_home",
"phone_screen_rapids_firstuseafter00unlock_locmap_home",
],
"f_slp": [
"fitbit_sleep_summary_rapids_sumdurationafterwakeupmain", "fitbit_sleep_summary_rapids_sumdurationasleepmain",
"fitbit_sleep_summary_rapids_sumdurationawakemain", "fitbit_sleep_summary_rapids_sumdurationtofallasleepmain",
"fitbit_sleep_summary_rapids_sumdurationinbedmain", "fitbit_sleep_summary_rapids_avgefficiencymain",
"fitbit_sleep_summary_rapids_avgdurationafterwakeupmain", "fitbit_sleep_summary_rapids_avgdurationasleepmain",
"fitbit_sleep_summary_rapids_avgdurationawakemain", "fitbit_sleep_summary_rapids_avgdurationtofallasleepmain",
"fitbit_sleep_summary_rapids_avgdurationinbedmain", "fitbit_sleep_summary_rapids_countepisodemain",
"fitbit_sleep_summary_rapids_firstbedtimemain", "fitbit_sleep_summary_rapids_lastbedtimemain",
"fitbit_sleep_summary_rapids_firstwaketimemain", "fitbit_sleep_summary_rapids_lastwaketimemain",
"fitbit_sleep_intraday_rapids_avgdurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_avgdurationawakeunifiedmain",
"fitbit_sleep_intraday_rapids_maxdurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_maxdurationawakeunifiedmain",
"fitbit_sleep_intraday_rapids_sumdurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_sumdurationawakeunifiedmain",
"fitbit_sleep_intraday_rapids_countepisodeasleepunifiedmain", "fitbit_sleep_intraday_rapids_countepisodeawakeunifiedmain",
"fitbit_sleep_intraday_rapids_stddurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_stddurationawakeunifiedmain",
"fitbit_sleep_intraday_rapids_mindurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_mindurationawakeunifiedmain",
"fitbit_sleep_intraday_rapids_mediandurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_mediandurationawakeunifiedmain",
"fitbit_sleep_intraday_rapids_ratiocountasleepunifiedwithinmain", "fitbit_sleep_intraday_rapids_ratiocountawakeunifiedwithinmain",
"fitbit_sleep_intraday_rapids_ratiodurationasleepunifiedwithinmain", "fitbit_sleep_intraday_rapids_ratiodurationawakeunifiedwithinmain",
],
"f_steps": [
"fitbit_steps_summary_rapids_maxsumsteps", "fitbit_steps_summary_rapids_minsumsteps",
"fitbit_steps_summary_rapids_avgsumsteps", "fitbit_steps_summary_rapids_mediansumsteps",
"fitbit_steps_summary_rapids_stdsumsteps",
"fitbit_steps_intraday_rapids_sumsteps", "fitbit_steps_intraday_rapids_maxsteps",
"fitbit_steps_intraday_rapids_minsteps", "fitbit_steps_intraday_rapids_avgsteps",
"fitbit_steps_intraday_rapids_stdsteps",
"fitbit_steps_intraday_rapids_countepisodesedentarybout", "fitbit_steps_intraday_rapids_sumdurationsedentarybout",
"fitbit_steps_intraday_rapids_maxdurationsedentarybout", "fitbit_steps_intraday_rapids_mindurationsedentarybout",
"fitbit_steps_intraday_rapids_avgdurationsedentarybout", "fitbit_steps_intraday_rapids_stddurationsedentarybout",
"fitbit_steps_intraday_rapids_countepisodeactivebout", "fitbit_steps_intraday_rapids_sumdurationactivebout",
"fitbit_steps_intraday_rapids_maxdurationactivebout", "fitbit_steps_intraday_rapids_mindurationactivebout",
"fitbit_steps_intraday_rapids_avgdurationactivebout", "fitbit_steps_intraday_rapids_stddurationactivebout",
],
}
TIME_SEGMENT = "allday"
WINDOW_DAYS = 28
LABEL_COL = "dep"
PREDICTION_TARGET = "dep_weekly"
CLASS_DICT = {True: "depressed", False: "not_depressed"}
PRE_SURVEY_COLS = [
"UCLA_10items_PRE", "SocialFit_PRE",
"2waySSS_receiving_emotional_PRE", "2waySSS_giving_emotional_PRE",
"2waySSS_giving_instrumental_PRE", "2waySSS_receiving_instrumental_PRE",
"ERQ_reappraisal_PRE", "ERQ_suppression_PRE",
"BRS_PRE", "CHIPS_PRE", "PSS_10items_PRE", "STAIS_PRE", "MAAS_7items_PRE",
"CESD_9items_PRE", "CESD_10items_PRE",
"BFI10_extroversion_PRE", "BFI10_agreeableness_PRE",
"BFI10_conscientiousness_PRE", "BFI10_neuroticism_PRE", "BFI10_openness_PRE",
]
def get_feature_col_names():
"""Build the list of full column names: f_type:feature_name:time_segment"""
cols = []
for ft in FEATURE_TYPES:
for feat in FEATURE_COLUMNS[ft]:
cols.append(f"{ft}:{feat}:{TIME_SEGMENT}")
return cols
def store_task_metadata(path):
task_metadata = {
"task": (
'Classify the user\'s depression status: ["depressed", "not_depressed"], '
"based on passive sensing data collected from a smartphone and a wearable fitness tracker."
),
"class": {
"depressed": "The user shows depressive symptoms based on self-reported weekly survey responses.",
"not_depressed": "The user does not show depressive symptoms based on self-reported weekly survey responses.",
},
"data": (
"Data were collected over a three-month study period from college students at a university. "
"Participants carried a smartphone and wore a Fitbit fitness tracker 24x7. "
"Passive sensing data includes GPS location, phone screen usage, Fitbit sleep, and Fitbit physical activity. "
"Features were extracted using the RAPIDS toolkit and computed daily over multiple time segments. "
"Each sample represents the last day of a 28-day observation window preceding a depression label date. "
"Each feature is named using the format 'sensor_type:feature_name:time_segment'."
),
"feature": (
"Location features (f_loc) include GPS-based metrics such as home time, distance travelled, "
"radius of gyration, location entropy, number of significant places, and time spent at various locations. "
"Phone usage features (f_screen) include unlock episode counts, durations, and location-specific phone usage patterns. "
"Sleep features (f_slp) include Fitbit-derived metrics such as sleep duration, efficiency, "
"time to fall asleep, bedtime/waketime, and intraday sleep/wake episode statistics. "
"Physical activity features (f_steps) include step counts, sedentary bout statistics, and active bout statistics. "
"All features use the 'allday' time segment (24 hours from midnight to midnight)."
),
}
with open(path, "w", encoding="utf-8") as f:
json.dump(task_metadata, f, indent=2)
def store_user_metadata(globem_path, out_path):
"""Build per-user metadata from platform.csv and pre.csv across all phases."""
user_metadata = {}
for institution, phases in PHASES.items():
for phase in phases:
ds_dir = os.path.join(globem_path, f"{institution}_{phase}")
platform_path = os.path.join(ds_dir, "ParticipantsInfoData", "platform.csv")
pre_path = os.path.join(ds_dir, "SurveyData", "pre.csv")
if not os.path.exists(platform_path):
print(f" Skipping {institution}_{phase}: platform.csv not found")
continue
df_platform = pd.read_csv(platform_path)
if "Unnamed: 0" in df_platform.columns:
df_platform = df_platform.drop(columns=["Unnamed: 0"])
df_platform = df_platform.set_index("pid")
df_pre = None
if os.path.exists(pre_path):
df_pre = pd.read_csv(pre_path)
if "Unnamed: 0" in df_pre.columns:
df_pre = df_pre.drop(columns=["Unnamed: 0"])
df_pre = df_pre.set_index("pid")
for pid in df_platform.index:
user_key = f"{pid}#{institution}_{phase}"
platform = str(df_platform.loc[pid, "platform"]).split(";")[0]
meta = {"platform": platform, "phase": phase}
if df_pre is not None and pid in df_pre.index:
row = df_pre.loc[pid]
for col in PRE_SURVEY_COLS:
if col in row.index:
val = row[col]
meta[col] = round(float(val), 4) if pd.notna(val) else None
user_metadata[user_key] = meta
with open(out_path, "w", encoding="utf-8") as f:
json.dump(user_metadata, f, indent=2)
def process_single_dataset(globem_path, institution, phase, feature_cols):
"""Process one dataset (institution + phase) and return per-user sample lists."""
ds_dir = os.path.join(globem_path, f"{institution}_{phase}")
rapids_path = os.path.join(ds_dir, "FeatureData", "rapids.csv")
label_path = os.path.join(ds_dir, "SurveyData", "dep_weekly.csv")
platform_path = os.path.join(ds_dir, "ParticipantsInfoData", "platform.csv")
for p in [rapids_path, label_path, platform_path]:
if not os.path.exists(p):
print(f" Skipping {institution}_{phase}: {os.path.basename(p)} not found")
return {}
print(f" Loading {institution}_{phase}...")
df_rapids = pd.read_csv(rapids_path, low_memory=False)
if "Unnamed: 0" in df_rapids.columns:
df_rapids = df_rapids.drop(columns=["Unnamed: 0"])
df_rapids["date"] = pd.to_datetime(df_rapids["date"])
df_label = pd.read_csv(label_path)
if "Unnamed: 0" in df_label.columns:
df_label = df_label.drop(columns=["Unnamed: 0"])
df_label["date"] = pd.to_datetime(df_label["date"])
df_label = df_label.dropna(subset=[LABEL_COL])
df_label = df_label.drop_duplicates(["pid", "date"], keep="last")
available_cols = [c for c in feature_cols if c in df_rapids.columns]
if len(available_cols) == 0:
print(f" Skipping {institution}_{phase}: no matching feature columns")
return {}
user_data = {}
phase_str = str(phase)
pids_few_response = df_label.groupby("pid")["date"].count()
valid_pids = set(pids_few_response[pids_few_response >= 2].index)
for _, row in tqdm(df_label.iterrows(), total=len(df_label),
desc=f" {institution}_{phase}", leave=False):
pid = row["pid"]
if pid not in valid_pids:
continue
date_end = row["date"]
date_start = date_end - timedelta(days=WINDOW_DAYS - 1)
label_val = row[LABEL_COL]
df_window = df_rapids[
(df_rapids["pid"] == pid)
& (df_rapids["date"] >= date_start)
& (df_rapids["date"] <= date_end)
]
if df_window.empty:
continue
last_day = df_window.sort_values("date").iloc[-1]
features = {}
for col in available_cols:
val = last_day[col]
if pd.notna(val):
features[col] = float(val)
else:
features[col] = None
user_key = f"{pid}#{institution}_{phase}"
if user_key not in user_data:
user_data[user_key] = []
user_data[user_key].append(dict(
user_id=user_key,
session_id=phase_str,
idx=len(user_data[user_key]),
date=str(date_end.date()),
label=CLASS_DICT[label_val],
features=features,
))
return user_data
def run(path=GLOBEM_PATH, out_dir=OUT_DIR):
if not os.path.exists(out_dir):
os.makedirs(out_dir)
feature_cols = get_feature_col_names()
print(f"Using {len(feature_cols)} feature columns ({TIME_SEGMENT} time segment)")
print("Saving task metadata...")
store_task_metadata(os.path.join(out_dir, "task_metadata.json"))
print("Saving user metadata...")
store_user_metadata(path, os.path.join(out_dir, "user_metadata.json"))
all_user_data = {}
for institution, phases in PHASES.items():
for phase in phases:
user_data = process_single_dataset(path, institution, phase, feature_cols)
for user_key, samples in user_data.items():
all_user_data.setdefault(user_key, []).extend(samples)
print(f"\nSaving datasets for {len(all_user_data)} users...")
total_samples = 0
for user_key, data in sorted(all_user_data.items()):
safe_name = user_key.replace("#", "_")
user_dir = os.path.join(out_dir, safe_name)
dataset = Dataset.from_list(data)
dataset.save_to_disk(user_dir)
total_samples += len(data)
print(f"\nDone. {total_samples} total samples from {len(all_user_data)} users saved to {out_dir}")
if __name__ == "__main__":
Fire(run)

View File

@@ -4,6 +4,7 @@ import json
import warnings
import numpy as np
import neurokit2 as nk
import pandas as pd
from tqdm import tqdm
from fire import Fire
@@ -21,7 +22,7 @@ warnings.simplefilter("ignore", NeuroKitWarning)
warnings.filterwarnings("ignore", message="All-NaN slice encountered")
SLEEPEDF_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/sleep-cassette/"
OUT_DIR = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
OUT_DIR = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new"
EPOCH_SEC_SIZE = 30
SAMPLING_RATE = 100
@@ -45,8 +46,8 @@ ann2label = {
}
def store_info(info_path):
info = {
def store_task_metadata(path):
task_metadata = {
"task": 'Classify the user\'s sleep stage: ["W", "N1", "N2", "N3", "REM"], based on physiological signals collected from wearable sensors.',
"class": {
"W": "Wakefulness. This includes periods before sleep onset or after final awakening, and short awakenings during the night.",
@@ -72,8 +73,21 @@ def store_info(info_path):
"Ratio features such as delta/theta, theta/alpha, alpha/beta, and (delta+theta)/(alpha+beta) were also included."
),
}
with open(info_path, "w", encoding="utf-8") as f:
json.dump(info, f, indent=2)
with open(path, "w", encoding="utf-8") as f:
json.dump(task_metadata, f, indent=2)
def store_user_metadata(src_path, out_path):
# load xls file
df = pd.read_excel(src_path)
user_metadata = {}
for _, row in df.iterrows():
user_id = str(row["subject"])
age = int(row["age"])
sex = int(row["sex (F=1)"])-1 # 0: female, 1: male
user_metadata[user_id] = {"age": age, "sex": sex}
with open(out_path, "w", encoding="utf-8") as f:
json.dump(user_metadata, f, indent=2)
def lowpass_filter(data, cutoff=50, fs=1000, order=4):
@@ -197,122 +211,7 @@ def process_by_mod(modality, data):
features[f"{modality}_theta/alpha_ratio"] = theta_alpha_ratio
features[f"{modality}_alpha/beta_ratio"] = alpha_beta_ratio
features[f"{modality}_(delta+theta)/(alpha+beta)_ratio"] = slow_fast_ratio
# elif "EOG" in modality:
# eog_cleaned = data
# try:
# eog_cleaned = nk.eog_clean(data, sr)
# except IndexError as e:
# print(f"Error processing EOG data for {modality}: {e}")
# return None
# eog_mean = np.mean(eog_cleaned)
# eog_std = np.std(eog_cleaned)
# eog_var = np.var(eog_cleaned)
# features[f"{modality}_mean"] = eog_mean
# features[f"{modality}_std"] = eog_std
# features[f"{modality}_variance"] = eog_var
# dynamic_range = np.max(eog_cleaned) - np.min(eog_cleaned)
# features[f"{modality}_dynamic_range"] = dynamic_range
# peaks = signal.find_peaks(eog_cleaned - eog_mean, height=3 * eog_std)[0]
# features[f"{modality}_num_peaks"] = len(peaks)
# zero_crossings = np.where(np.diff(np.sign(eog_cleaned - eog_mean)))[0]
# features[f"{modality}_num_zero_crossings"] = len(zero_crossings)
# differences = eog_cleaned[1:] - eog_cleaned[:-1]
# difference_variance = np.var(differences)
# features[f"{modality}_difference_variance"] = difference_variance
# features[f"{modality}_num_large_eye_movements"] = count_large_eye_movements(
# eog_cleaned, sr, amp_thresh=120, time_thresh=1.5
# )
# eog_large_movement_removed = remove_large_eye_movements(
# eog_cleaned, fs=sr, amp_thresh=120, time_thresh=1.5, pad=0.75
# )
# differences = eog_large_movement_removed[1:] - eog_large_movement_removed[:-1]
# difference_variance = np.var(differences)
# features[f"{modality}_difference_variance_without_large_movements"] = (
# difference_variance
# )
# freqs, psd = signal.welch(eog_cleaned, fs=sr, nperseg=sr * 2)
# total_idx = np.logical_and(freqs >= 0.5, freqs <= 30)
# total_power = np.trapezoid(psd[total_idx], freqs[total_idx])
# slow_idx = np.logical_and(freqs >= 0.5, freqs <= 2)
# rapid_idx = np.logical_and(freqs >= 2, freqs <= 5)
# slow_power = np.trapezoid(psd[slow_idx], freqs[slow_idx])
# rapid_power = np.trapezoid(psd[rapid_idx], freqs[rapid_idx])
# slow_power_ratio = slow_power / total_power if total_power > 0 else 0
# rapid_power_ratio = rapid_power / total_power if total_power > 0 else 0
# features[f"{modality}_slow_movement_power_ratio"] = slow_power_ratio
# features[f"{modality}_rapid_movement_power_ratio"] = rapid_power_ratio
# elif "Resp" in modality:
# rsp_signals = data
# try:
# rsp_signals, _ = nk.rsp_process(data, sampling_rate=sr, method="biosppy")
# except IndexError as e:
# print(f"Error processing respiration data for {modality}: {e}")
# return None
# clean = rsp_signals["RSP_Clean"]
# phase = rsp_signals["RSP_Phase"]
# rate = rsp_signals["RSP_Rate"]
# amplitude = rsp_signals["RSP_Amplitude"]
# peaks = np.where(rsp_signals["RSP_Peaks"] == 1)[0]
# troughs = np.where(rsp_signals["RSP_Troughs"] == 1)[0]
# inhale_durations = []
# for t in troughs:
# next_peaks = peaks[peaks > t]
# if len(next_peaks) == 0:
# continue
# inhale_durations.append((next_peaks[0] - t) / sr)
# inhale_durations = np.array(inhale_durations)
# exhale_durations = []
# for p in peaks:
# next_troughs = troughs[troughs > p]
# if len(next_troughs) == 0:
# continue
# exhale_durations.append((next_troughs[0] - p) / sr)
# exhale_durations = np.array(exhale_durations)
# features[f"{modality}_inhale_duration_mean"] = np.mean(inhale_durations)
# features[f"{modality}_inhale_duration_std"] = np.std(inhale_durations)
# features[f"{modality}_exhale_duration_mean"] = np.mean(exhale_durations)
# features[f"{modality}_exhale_duration_std"] = np.std(exhale_durations)
# features[f"{modality}_inhale_exhale_ratio"] = (
# np.mean(inhale_durations) / np.mean(exhale_durations)
# if np.mean(exhale_durations) > 0
# else np.nan
# )
# features[f"{modality}_stretch"] = np.max(clean) - np.min(clean)
# inhale_mask = phase == 1
# features[f"{modality}_inspiration_volume"] = np.trapezoid(
# amplitude[inhale_mask], dx=1 / sr
# )
# features[f"{modality}_respiration_rate"] = np.mean(rate)
# resp_durations = np.diff(troughs) / sr
# features[f"{modality}_respiration_duration"] = np.mean(resp_durations)
# elif "EMG" in modality:
# emg_mean = np.mean(data)
# emg_std = np.std(data)
# features[f"{modality}_mean"] = emg_mean
# features[f"{modality}_std"] = emg_std
# features[f"{modality}_dynamic_range"] = np.max(data) - np.min(data)
# features[f"{modality}_absolute_integral"] = np.sum(np.abs(data)) / sr
# features[f"{modality}_median"] = np.median(data)
# features[f"{modality}_10th_percentile"] = np.percentile(data, 10)
# features[f"{modality}_90th_percentile"] = np.percentile(data, 90)
# peaks, _ = signal.find_peaks(data, height=3 * emg_std)
# peak_values = data[peaks]
# features[f"{modality}_num_peaks"] = len(peaks)
# features[f"{modality}_peak_amplitude_mean"] = (
# np.mean(peak_values) if len(peak_values) > 0 else 0
# )
# features[f"{modality}_peak_amplitude_std"] = (
# np.std(peak_values) if len(peak_values) > 0 else 0
# )
# features[f"{modality}_peak_amplitude_sum"] = (
# np.sum(peak_values) if len(peak_values) > 0 else 0
# )
# features[f"{modality}_peak_amplitude_norm_sum"] = (
# np.sum(peak_values) / np.sum(np.abs(data))
# if np.sum(np.abs(data)) > 0
# else 0
# )
return features
@@ -450,9 +349,17 @@ def run(path=SLEEPEDF_PATH, out_dir=OUT_DIR, num_examples=1, num_workers=32, see
np.random.seed(seed)
if not os.path.exists(out_dir):
os.makedirs(out_dir)
info_path = os.path.join(out_dir, "info.json")
store_info(info_path)
print(f"Saved info to {info_path}")
task_metadata_out_path = os.path.join(out_dir, "task_metadata.json")
store_task_metadata(task_metadata_out_path)
print(f"Saved info to {task_metadata_out_path}")
user_metadata_out_path = os.path.join(out_dir, "user_metadata.json")
user_metadata_src_path = os.path.join(path, "..", "SC-subjects.xls")
store_user_metadata(user_metadata_src_path, user_metadata_out_path)
print(f"Saved info to {user_metadata_out_path}")
psg_file_paths = glob(os.path.join(path, "*PSG.edf"))
ann_file_paths = glob(os.path.join(path, "*Hypnogram.edf"))
psg_file_paths.sort()
@@ -466,16 +373,20 @@ def run(path=SLEEPEDF_PATH, out_dir=OUT_DIR, num_examples=1, num_workers=32, see
elif basename.startswith("SC41"):
filtered_2013_file_paths.append(file_path)
user_data = {}
with Pool(processes=num_workers) as pool:
for data in pool.imap_unordered(preprocess, filtered_2013_file_paths):
if len(data) == 0:
continue
user_id = data[0]["user_id"]
session_id = data[0]["session_id"]
dataset = Dataset.from_list(data)
test_dir = os.path.join(out_dir, f"{user_id}", f"{session_id}")
dataset.save_to_disk(test_dir)
print(f"Saved dataset to {test_dir}")
user_data.setdefault(user_id, []).extend(data)
for user_id, data in user_data.items():
dataset = Dataset.from_list(data)
test_dir = os.path.join(out_dir, f"{user_id}")
dataset.save_to_disk(test_dir)
print(f"Saved dataset to {test_dir} ({len(data)} samples, "
f"{len(set(d['session_id'] for d in data))} session(s))")
if __name__ == "__main__":

370
run.py
View File

@@ -1,294 +1,114 @@
import os
import re
import asyncio
import yaml
import json
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any, Optional
import time
from glob import glob
import yaml
from fire import Fire
from core.model import load_models
from core.data_loader import DataLoader
from core.sensing_agent import SensingAgent
from core.embedding_index import EmbeddingIndex, create_embedding_index
_EMBEDDING_INDEX: Optional[EmbeddingIndex] = None
def init_embedding_index(embedding_path: Optional[str]) -> Optional[EmbeddingIndex]:
"""Initialize global embedding index for similarity-based selection."""
global _EMBEDDING_INDEX
if embedding_path is None:
_EMBEDDING_INDEX = None
return None
if _EMBEDDING_INDEX is not None:
return _EMBEDDING_INDEX
_EMBEDDING_INDEX = create_embedding_index(embedding_path)
return _EMBEDDING_INDEX
from core.example_queue import ExampleQueue
from core.model import load_models
from core.recruiter import Recruiter
from core.agent import Agent
from core.logger import Logger
from core.prompt import gen_system_message, gen_task_message
from core.json_utils import safe_parse_json
from core.scores import self_certainty
from core.vote import borda_vote
def load_user_data_sync(
data_path: str,
user: str,
seed: int,
log_path_base: str,
selection_criteria: str = "out_random",
num_examples: int = 1,
embedding_index: Optional[EmbeddingIndex] = None,
) -> List[Dict[str, Any]]:
"""Load data for a single user with specified selection criteria."""
print(f"[DATA LOADING] Starting: user={user}, seed={seed}, criteria={selection_criteria}")
# Set random seed for reproducibility
np.random.seed(seed)
data_loader = DataLoader(
data_path,
user,
selection_criteria=selection_criteria,
num_examples=num_examples,
embedding_index=embedding_index,
async def run(config_path: str):
print("[Main] Loading config")
with open(config_path, "r", encoding="utf-8") as f:
config = yaml.load(f, Loader=yaml.SafeLoader)
logger = Logger(config.get("log_path"))
logger.log_config(config)
logger.log("[Main] Loaded config")
logger.log("[Main] Loading data loader")
dataloader = DataLoader(config.get("data_path"), config.get("target_user"))
logger.log("[Main] Loaded data loader")
logger.log("[Main] Initializing example queue")
example_q = ExampleQueue(
queue_size=config.get("queue_size"),
logger=logger,
)
if not data_loader.is_valid:
print(f"[DATA LOADING] Skipping invalid user: {user}")
return []
tasks = []
idx = 0
dataset_size = len(data_loader)
print(f"[DATA LOADING] User {user} has {dataset_size} samples")
for sample, examples in data_loader:
if idx % 10 != 0:
idx += 1
continue
log_path = os.path.join(log_path_base, user, f"{idx:02d}", str(seed))
os.makedirs(log_path, exist_ok=True)
kwargs = {
"task_info": data_loader.get_task_info(),
"classes_info": data_loader.get_classes_info(),
"sensor_info": data_loader.get_sensor_info(),
"sample": sample,
"examples": examples,
"log_path": log_path,
"ground_truth": sample["label"],
}
tasks.append(kwargs)
idx += 1
print(f"[DATA LOADING] Completed user {user}, seed {seed}: {len(tasks)} tasks")
return tasks
recruiter = Recruiter(
source_dataset=dataloader.get_source_dataset(),
source_users=dataloader.get_source_users(),
num_shot=config.get("num_shot"),
classes=dataloader.get_classes(),
logger=logger,
)
list_examples = recruiter.recruit(config.get("queue_size"))
example_q.update([], list_examples)
logger.log("[Main] Initialized example queue")
logger.log("[Main] Loading model pool")
model_pool = load_models(config.get("model_paths"))
logger.log("[Main] Loaded model pool")
async def load_data_parallel(config: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Load data for all users in parallel with configurable selection criteria."""
print("[DATA LOADING] Starting parallel data loading...")
# Get configuration parameters
selection_criteria = config.get("selection_criteria", "out_random")
embedding_path = config.get("embedding_path", None)
num_examples = config.get("num_examples", 1)
# Initialize embedding index if needed for similarity-based selection
embedding_index = None
if selection_criteria in ["out_similar", "in_similar"]:
if embedding_path is None:
print(f"[WARNING] {selection_criteria} requires embedding_path in config")
print("[WARNING] Falling back to random selection")
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
else:
embedding_index = init_embedding_index(embedding_path)
if embedding_index is None:
print(f"[WARNING] Failed to load embeddings, falling back to random")
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
user_paths = glob(os.path.join(config["data_path"], "*"))
user_paths = [path for path in user_paths if os.path.isdir(path)]
users = [path.split("/")[-1] for path in user_paths]
print(f"[DATA LOADING] Found {len(users)} users: {users}")
print(f"[DATA LOADING] Selection criteria: {selection_criteria}")
print(f"[DATA LOADING] Num examples per class: {num_examples}")
max_workers = config.get("data_workers", 96)
print(f"[DATA LOADING] Using {max_workers} workers for data loading")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
loop = asyncio.get_event_loop()
futures = []
for seed in range(config["num_seeds"]):
for user in users:
future = loop.run_in_executor(
executor,
load_user_data_sync,
config["data_path"],
user,
seed,
config["log_path"],
selection_criteria,
num_examples,
embedding_index,
)
futures.append(future)
print(f"[DATA LOADING] Created {len(futures)} parallel data loading tasks")
results = await asyncio.gather(*futures, return_exceptions=True)
# Flatten results and filter out exceptions
all_tasks = []
for result in results:
if isinstance(result, Exception):
print(f"[DATA LOADING] Error loading data: {result}")
else:
all_tasks.extend(result)
logger.log("[Main] Initializing agent")
agent = Agent(
model_pool=model_pool,
logger=logger,
)
system_message = gen_system_message(metadata=dataloader.get_task_metadata())
agent.set_system_message(system_message)
logger.log("[Main] Initialized agent")
print(f"[DATA LOADING] Total tasks: {len(all_tasks)}")
return all_tasks
async def process_example_set(sample_idx, example_idx, example_set, sample):
logger.log(f"[Main] Processing {sample_idx} - {example_idx} (queue index)")
try:
task_message = gen_task_message(sample, example_set)
response, logprobs = await agent.solve(
task_message, sample_idx, example_idx
)
score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
parsed = safe_parse_json(response)
answer = parsed.get("ANSWER") if parsed else None
logger.log(
f"[Main] Done {sample_idx} - {example_idx}: "
f"answer={answer}, score={score:.4f}"
)
return {"example_set": example_set, "answer": answer, "score": score}
except Exception as e:
logger.log(
f"[Main] Error {sample_idx} - {example_idx}: {e}",
filename="errors.txt",
)
return {
"example_set": example_set,
"answer": None,
"score": float("-inf"),
}
async def run_parallel(kwargs_list: List[Dict[str, Any]], model_pool, config: Dict[str, Any]) -> None:
"""Run classification tasks in parallel using the model pool."""
print(f"[EXECUTION] Starting {len(kwargs_list)} classification tasks...")
tasks = []
for kwargs in kwargs_list:
agent = SensingAgent(
name="EEG sensing",
model_pool=model_pool,
task_info=kwargs["task_info"],
classes_info=kwargs["classes_info"],
sensor_info=kwargs["sensor_info"],
sample=kwargs["sample"],
examples=kwargs["examples"],
log_path=kwargs["log_path"],
)
task = asyncio.create_task(agent.solve(kwargs["sample"], kwargs["examples"], kwargs["ground_truth"]))
tasks.append(task)
await asyncio.gather(*tasks)
print("[EXECUTION] All tasks completed")
def run(config_path: str) -> None:
"""
Main entry point for running experiments.
Args:
config_path: Path to YAML configuration file
"""
print(f"[MAIN] Loading config from: {config_path}")
config = yaml.load(open(config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
print("=" * 60)
print("EXPERIMENT CONFIGURATION")
print("=" * 60)
print(f" Data path: {config.get('data_path', 'N/A')}")
print(f" Log path: {config.get('log_path', 'N/A')}")
print(f" Selection criteria: {config.get('selection_criteria', 'out_random')}")
print(f" Num examples: {config.get('num_examples', 1)}")
print(f" Num seeds: {config.get('num_seeds', 1)}")
print(f" Embedding path: {config.get('embedding_path', 'N/A')}")
print(f" Num models: {len(config.get('models', []))}")
print("=" * 60)
model_pool = load_models(config["models"])
kwargs_list = asyncio.run(load_data_parallel(config))
if len(kwargs_list) == 0:
print("[ERROR] No valid tasks to run. Check data paths and configuration.")
return
# Warmup models
print("[MAIN] Warming up models...")
model_pool.warmup()
# Run experiments
print("[MAIN] Starting experiments...")
logger.log("[Main] Starting main loop")
start_time = time.time()
asyncio.run(run_parallel(kwargs_list, model_pool, config))
elapsed = time.time() - start_time
print(f"[MAIN] Experiment completed in {elapsed:.2f} seconds")
print(f"[MAIN] Results saved to: {config['log_path']}")
for idx, sample in enumerate(dataloader):
logger.log(f"[Main] Processing sample {idx} / {len(dataloader)} (sample index)")
tasks = [
process_example_set(idx, example_idx, example_set, sample)
for example_idx, example_set in enumerate(example_q)
]
results = await asyncio.gather(*tasks)
winner, tally = borda_vote(results, config.get("borda_p", 1.0))
ground_truth = sample["label"]
if winner is not None:
tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
logger.log(f"[Vote] votes={{ {tally_str} }}")
logger.log_result(idx, winner, ground_truth)
else:
logger.log(f"[Vote] Sample {idx} | no valid answer parsed, skipping")
recruiter.update_strategy(results)
list_examples = recruiter.recruit(num_example_set=config.get("update_size"))
example_q.update(results, list_examples)
logger.report(elapsed_seconds=time.time() - start_time)
def run_comparison(
base_config_path: str,
criteria_list: str = "out_random,in_random,out_similar,in_similar",
embedding_path: str = None,
) -> None:
"""
Run experiments comparing multiple selection criteria.
Args:
base_config_path: Path to base YAML config file
criteria_list: Comma-separated list of selection criteria to compare
embedding_path: Path to embeddings (required for *_similar criteria)
Example:
python run.py compare config/sleepedf.yaml \\
--criteria_list="out_random,out_similar" \\
--embedding_path="./embeddings_full"
"""
base_config = yaml.load(open(base_config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
criteria = [c.strip() for c in criteria_list.split(",")]
print("=" * 60)
print("COMPARISON EXPERIMENT")
print("=" * 60)
print(f" Selection criteria to compare: {criteria}")
print(f" Embedding path: {embedding_path}")
print("=" * 60)
for criterion in criteria:
print(f"\n{'='*60}")
print(f"Running experiment: {criterion}")
print(f"{'='*60}")
config = base_config.copy()
config["selection_criteria"] = criterion
base_log_path = config["log_path"]
if not base_log_path.endswith(criterion):
config["log_path"] = f"{base_log_path}_{criterion}"
if criterion in ["out_similar", "in_similar"]:
if embedding_path:
config["embedding_path"] = embedding_path
else:
print(f"[WARNING] {criterion} requires --embedding_path argument")
continue
model_pool = load_models(config["models"])
kwargs_list = asyncio.run(load_data_parallel(config))
if len(kwargs_list) == 0:
print(f"[WARNING] No tasks for {criterion}, skipping")
continue
model_pool.warmup()
asyncio.run(run_parallel(kwargs_list, model_pool, config))
print(f"[DONE] Results for {criterion} saved to: {config['log_path']}")
print("\n" + "=" * 60)
print("ALL COMPARISON EXPERIMENTS COMPLETED")
print("=" * 60)
def main(config_path: str):
asyncio.run(run(config_path))
if __name__ == "__main__":
Fire({
"run": run,
"compare": run_comparison,
})
Fire(main)

View File

@@ -1,412 +0,0 @@
"""
Self-Consistency Results Analysis Script
This module provides analysis tools for Self-Consistency experiment results:
- Load and aggregate results from experiment directories
- Compute detailed statistics and metrics
- Generate visualizations and reports
- Compare results across different configurations
Usage:
# Analyze single experiment
python -m sc.analysis.analyze_sc_results analyze /path/to/results
# Compare multiple experiments
python -m sc.analysis.analyze_sc_results compare /path/to/exp1 /path/to/exp2
# Generate summary report
python -m sc.analysis.analyze_sc_results report /path/to/results --output report.md
Author: NMSL Research Team
Date: 2026-01-21
"""
import os
import json
import yaml
import numpy as np
import pandas as pd
from glob import glob
from typing import List, Dict, Any, Optional, Tuple
from collections import defaultdict
from fire import Fire
class SCResultsAnalyzer:
"""
Analyzer for Self-Consistency experiment results.
Loads experiment results and provides various analysis methods
including accuracy computation, consistency analysis, and
per-class/per-user breakdowns.
Attributes:
results_path: Path to experiment results directory
results: List of result dictionaries
config: Experiment configuration
stats: Computed statistics
"""
def __init__(self, results_path: str):
"""
Initialize analyzer with results directory.
Args:
results_path: Path to experiment results directory
Should contain all_results.json and statistics.json
"""
self.results_path = results_path
self.results = []
self.config = {}
self.stats = {}
self._load_results()
def _load_results(self) -> None:
"""Load results, config, and statistics from files."""
# Load all results
results_file = os.path.join(self.results_path, "all_results.json")
if os.path.exists(results_file):
with open(results_file, "r", encoding="utf-8") as f:
self.results = json.load(f)
print(f"[LOAD] Loaded {len(self.results)} results from {results_file}")
else:
print(f"[WARNING] Results file not found: {results_file}")
# Load config
config_file = os.path.join(self.results_path, "config.yaml")
if os.path.exists(config_file):
with open(config_file, "r", encoding="utf-8") as f:
self.config = yaml.safe_load(f)
print(f"[LOAD] Loaded config from {config_file}")
# Load pre-computed statistics
stats_file = os.path.join(self.results_path, "statistics.json")
if os.path.exists(stats_file):
with open(stats_file, "r", encoding="utf-8") as f:
self.stats = json.load(f)
print(f"[LOAD] Loaded statistics from {stats_file}")
def get_dataframe(self) -> pd.DataFrame:
"""
Convert results to pandas DataFrame.
Returns:
DataFrame with one row per result
"""
if not self.results:
return pd.DataFrame()
return pd.DataFrame(self.results)
def compute_accuracy(self) -> Dict[str, float]:
"""
Compute overall and per-class accuracy.
Returns:
Dictionary with 'overall' and per-class accuracies
"""
if not self.results:
return {"overall": 0.0}
df = self.get_dataframe()
# Overall accuracy
overall = df["is_correct"].mean()
# Per-class accuracy
accuracy = {"overall": overall}
for cls in df["ground_truth"].unique():
cls_df = df[df["ground_truth"] == cls]
accuracy[cls] = cls_df["is_correct"].mean()
return accuracy
def compute_consistency_analysis(self) -> Dict[str, Any]:
"""
Analyze relationship between consistency and accuracy.
Returns:
Dictionary with consistency statistics and correlations
"""
if not self.results:
return {}
df = self.get_dataframe()
# Consistency distribution
consistency_mean = df["consistency"].mean()
consistency_std = df["consistency"].std()
# High consistency accuracy
high_cons = df[df["consistency"] >= 0.8]
low_cons = df[df["consistency"] < 0.8]
high_cons_acc = high_cons["is_correct"].mean() if len(high_cons) > 0 else 0
low_cons_acc = low_cons["is_correct"].mean() if len(low_cons) > 0 else 0
# Consistency bins
bins = [0.0, 0.4, 0.6, 0.8, 1.0]
labels = ["0.0-0.4", "0.4-0.6", "0.6-0.8", "0.8-1.0"]
df["consistency_bin"] = pd.cut(df["consistency"], bins=bins, labels=labels)
bin_stats = {}
for label in labels:
bin_df = df[df["consistency_bin"] == label]
bin_stats[label] = {
"count": len(bin_df),
"accuracy": bin_df["is_correct"].mean() if len(bin_df) > 0 else 0,
}
return {
"consistency_mean": consistency_mean,
"consistency_std": consistency_std,
"high_consistency_count": len(high_cons),
"high_consistency_accuracy": high_cons_acc,
"low_consistency_count": len(low_cons),
"low_consistency_accuracy": low_cons_acc,
"bin_statistics": bin_stats,
}
def compute_per_user_accuracy(self) -> Dict[str, Dict[str, Any]]:
"""
Compute accuracy breakdown by user.
Returns:
Dictionary mapping user_id to accuracy metrics
"""
if not self.results:
return {}
df = self.get_dataframe()
user_stats = {}
for user_id in df["user_id"].unique():
user_df = df[df["user_id"] == user_id]
user_stats[user_id] = {
"count": len(user_df),
"accuracy": user_df["is_correct"].mean(),
"avg_consistency": user_df["consistency"].mean(),
"avg_confidence": user_df["confidence"].mean(),
}
return user_stats
def compute_confusion_matrix(self) -> Tuple[np.ndarray, List[str]]:
"""
Compute confusion matrix.
Returns:
Tuple of (confusion_matrix, class_labels)
"""
if not self.results:
return np.array([]), []
df = self.get_dataframe()
classes = sorted(df["ground_truth"].unique())
matrix = np.zeros((len(classes), len(classes)), dtype=int)
class_to_idx = {cls: i for i, cls in enumerate(classes)}
for _, row in df.iterrows():
gt_idx = class_to_idx[row["ground_truth"]]
pred = row["answer"]
if pred in class_to_idx:
pred_idx = class_to_idx[pred]
matrix[gt_idx, pred_idx] += 1
return matrix, classes
def generate_report(self) -> str:
"""
Generate comprehensive markdown report.
Returns:
Markdown formatted report string
"""
report = []
report.append("# Self-Consistency Experiment Results\n")
# Config summary
report.append("## Experiment Configuration\n")
if self.config:
report.append(f"- Selection Criteria: {self.config.get('selection_criteria', 'N/A')}")
report.append(f"- Num ICL Examples: {self.config.get('num_examples', 'N/A')}")
report.append(f"- Num SC Samples: {self.config.get('num_sc_samples', 'N/A')}")
report.append(f"- Temperature: {self.config.get('temperature', 'N/A')}")
report.append(f"- Num Seeds: {self.config.get('num_seeds', 'N/A')}")
report.append("")
# Overall statistics
report.append("## Overall Statistics\n")
accuracy = self.compute_accuracy()
report.append(f"- **Overall Accuracy**: {accuracy['overall']:.4f}")
report.append(f"- **Total Samples**: {len(self.results)}")
if self.stats:
report.append(f"- **Avg Confidence**: {self.stats.get('avg_confidence', 0):.4f}")
report.append(f"- **Avg Consistency**: {self.stats.get('avg_consistency', 0):.4f}")
report.append("")
# Per-class accuracy
report.append("## Per-Class Accuracy\n")
report.append("| Class | Accuracy |")
report.append("|-------|----------|")
for cls, acc in sorted(accuracy.items()):
if cls != "overall":
report.append(f"| {cls} | {acc:.4f} |")
report.append("")
# Consistency analysis
report.append("## Consistency Analysis\n")
cons_analysis = self.compute_consistency_analysis()
if cons_analysis:
report.append(f"- **Mean Consistency**: {cons_analysis['consistency_mean']:.4f}")
report.append(f"- **Consistency Std**: {cons_analysis['consistency_std']:.4f}")
report.append(f"- **High Consistency (≥0.8) Accuracy**: {cons_analysis['high_consistency_accuracy']:.4f}")
report.append(f"- **Low Consistency (<0.8) Accuracy**: {cons_analysis['low_consistency_accuracy']:.4f}")
report.append("\n### Accuracy by Consistency Bin\n")
report.append("| Consistency Range | Count | Accuracy |")
report.append("|-------------------|-------|----------|")
for bin_label, stats in cons_analysis["bin_statistics"].items():
report.append(f"| {bin_label} | {stats['count']} | {stats['accuracy']:.4f} |")
report.append("")
# Confusion matrix
report.append("## Confusion Matrix\n")
matrix, classes = self.compute_confusion_matrix()
if len(classes) > 0:
header = "| | " + " | ".join(classes) + " |"
separator = "|---" * (len(classes) + 1) + "|"
report.append(header)
report.append(separator)
for i, cls in enumerate(classes):
row = f"| **{cls}** | " + " | ".join(str(x) for x in matrix[i]) + " |"
report.append(row)
report.append("")
return "\n".join(report)
def save_report(self, output_path: str) -> None:
"""
Save report to file.
Args:
output_path: Path to save the report
"""
report = self.generate_report()
with open(output_path, "w", encoding="utf-8") as f:
f.write(report)
print(f"[SAVE] Report saved to: {output_path}")
def compare_experiments(
paths: List[str],
output_path: Optional[str] = None
) -> pd.DataFrame:
"""
Compare results across multiple experiments.
Args:
paths: List of paths to experiment result directories
output_path: Optional path to save comparison CSV
Returns:
DataFrame with comparison metrics
"""
comparison = []
for path in paths:
analyzer = SCResultsAnalyzer(path)
accuracy = analyzer.compute_accuracy()
cons = analyzer.compute_consistency_analysis()
comparison.append({
"experiment": os.path.basename(path),
"selection_criteria": analyzer.config.get("selection_criteria", "N/A"),
"num_samples": len(analyzer.results),
"accuracy": accuracy["overall"],
"avg_consistency": cons.get("consistency_mean", 0),
"high_cons_accuracy": cons.get("high_consistency_accuracy", 0),
})
df = pd.DataFrame(comparison)
if output_path:
df.to_csv(output_path, index=False)
print(f"[SAVE] Comparison saved to: {output_path}")
return df
# =============================================================================
# CLI Commands
# =============================================================================
def analyze(results_path: str) -> None:
"""
Analyze single experiment results.
Args:
results_path: Path to experiment results directory
"""
analyzer = SCResultsAnalyzer(results_path)
print("\n" + "=" * 60)
print("EXPERIMENT ANALYSIS")
print("=" * 60)
# Accuracy
accuracy = analyzer.compute_accuracy()
print(f"\nOverall Accuracy: {accuracy['overall']:.4f}")
print("\nPer-Class Accuracy:")
for cls, acc in sorted(accuracy.items()):
if cls != "overall":
print(f" {cls}: {acc:.4f}")
# Consistency
cons = analyzer.compute_consistency_analysis()
print(f"\nConsistency Analysis:")
print(f" Mean: {cons['consistency_mean']:.4f}")
print(f" Std: {cons['consistency_std']:.4f}")
print(f" High (≥0.8) Accuracy: {cons['high_consistency_accuracy']:.4f}")
print("=" * 60)
def report(results_path: str, output: str = None) -> None:
"""
Generate and optionally save analysis report.
Args:
results_path: Path to experiment results directory
output: Optional path to save report (default: results_path/report.md)
"""
analyzer = SCResultsAnalyzer(results_path)
if output is None:
output = os.path.join(results_path, "report.md")
analyzer.save_report(output)
print(f"Report generated: {output}")
def compare(*paths: str, output: str = None) -> None:
"""
Compare multiple experiment results.
Args:
paths: Paths to experiment result directories
output: Optional path to save comparison CSV
"""
df = compare_experiments(list(paths), output)
print("\n" + df.to_string(index=False))
if __name__ == "__main__":
Fire({
"analyze": analyze,
"report": report,
"compare": compare,
})

View File

@@ -1,73 +0,0 @@
# ==============================================================================
# Sleep Stage Classification - Confidence-based Queue Policy Experiment
# ==============================================================================
# ------------------------------------------------------------------------------
# Data Configuration
# ------------------------------------------------------------------------------
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
data_workers: 16
# ------------------------------------------------------------------------------
# Experiment Settings (to be overridden by CLI)
# ------------------------------------------------------------------------------
# These will be set via command line arguments
user_id: 5 # Target user (5 or 10)
shuffle_seed: 42 # Shuffle seed (42, 123, or 456)
# Queue and ICL settings
queue_size: 5
num_icl_shots: 5 # Number of ICL examples per agent
# Self-Consistency settings
num_sc_samples: 8 # Number of SC sampling agents
# Example pool selection: "out" (different users) or "in" (same user)
example_pool: "out"
# Process all samples (no sampling)
sample_rate: 1
# Tracking window size for rolling window accuracy
tracking_window: 20
# Model context window size
num_ctx: 15000
# Temperature for LLM
temperature: 0.0
# ------------------------------------------------------------------------------
# Queue Policy
# ------------------------------------------------------------------------------
queue_policy: "confidence"
# Options: "confidence", "consistency", "random"
# This config is for CONFIDENCE-based queue updates
# ------------------------------------------------------------------------------
# Model Configuration (8 agents)
# ------------------------------------------------------------------------------
models:
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
# ------------------------------------------------------------------------------
# Output Configuration
# ------------------------------------------------------------------------------
log_path: "/mnt/sting/ssum/sleepedf_sc_experiment/confidence"
# ------------------------------------------------------------------------------
# Sleep Stages
# ------------------------------------------------------------------------------
stages:
- W
- N1
- N2
- N3
- REM

View File

@@ -1,73 +0,0 @@
# ==============================================================================
# Sleep Stage Classification - Consistency-based Queue Policy Experiment
# ==============================================================================
# ------------------------------------------------------------------------------
# Data Configuration
# ------------------------------------------------------------------------------
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
data_workers: 16
# ------------------------------------------------------------------------------
# Experiment Settings (to be overridden by CLI)
# ------------------------------------------------------------------------------
# These will be set via command line arguments
user_id: 5 # Target user (5 or 10)
shuffle_seed: 42 # Shuffle seed (42, 123, or 456)
# Queue and ICL settings
queue_size: 5
num_icl_shots: 5 # Number of ICL examples per agent
# Self-Consistency settings
num_sc_samples: 8 # Number of SC sampling agents
# Example pool selection: "out" (different users) or "in" (same user)
example_pool: "out"
# Process all samples (no sampling)
sample_rate: 1
# Tracking window size for rolling window accuracy
tracking_window: 20
# Model context window size
num_ctx: 15000
# Temperature for LLM
temperature: 0.0
# ------------------------------------------------------------------------------
# Queue Policy
# ------------------------------------------------------------------------------
queue_policy: "consistency"
# Options: "confidence", "consistency", "random"
# This config is for CONSISTENCY-based queue updates
# ------------------------------------------------------------------------------
# Model Configuration (8 agents)
# ------------------------------------------------------------------------------
models:
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
# ------------------------------------------------------------------------------
# Output Configuration
# ------------------------------------------------------------------------------
log_path: "/mnt/sting/ssum/sleepedf_sc_experiment/consistency"
# ------------------------------------------------------------------------------
# Sleep Stages
# ------------------------------------------------------------------------------
stages:
- W
- N1
- N2
- N3
- REM

View File

@@ -1,86 +0,0 @@
# ==============================================================================
# Sleep Stage Classification - Queue Random Baseline Experiment
# ==============================================================================
#
# This is an ABLATION STUDY baseline:
# - Queue structure is maintained (size=5)
# - BUT all 5 elements are refreshed with random samples EVERY step
# - No cumulative learning or retention of good examples
#
# Purpose: Test whether performance gains come from:
# 1. Queue structure itself (using 5 ICL examples)
# 2. Cumulative learning (retaining high-quality examples over time)
#
# Expected Result:
# - If Queue Random ≈ Confidence/Consistency → Queue structure is the key
# - If Queue Random < Confidence/Consistency → Cumulative learning is the key
# ==============================================================================
# ------------------------------------------------------------------------------
# Data Configuration
# ------------------------------------------------------------------------------
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
data_workers: 16
# ------------------------------------------------------------------------------
# Experiment Settings (to be overridden by CLI)
# ------------------------------------------------------------------------------
user_id: 5 # Target user (5, 10, or 15)
shuffle_seed: 42 # Shuffle seed (42 or 123)
# Queue settings (same as other policies for fair comparison)
queue_size: 5 # Number of ICL example sets
num_icl_shots: 5 # Number of ICL examples per agent
# Self-Consistency settings
num_sc_samples: 8 # Number of SC sampling agents (same as other policies)
# Example pool selection: "out" (different users) or "in" (same user)
example_pool: "out"
# Process all samples (no sampling)
sample_rate: 1
# Tracking window size for rolling window accuracy
tracking_window: 20
# Model context window size
num_ctx: 15000
# Temperature for LLM
temperature: 0.0
# Sleep stages for classification
stages:
- "W"
- "N1"
- "N2"
- "N3"
- "REM"
# ------------------------------------------------------------------------------
# Queue Policy
# ------------------------------------------------------------------------------
queue_policy: "queue_random"
# This config is for QUEUE RANDOM baseline:
# - Queue structure exists (5 slots)
# - ALL slots are refreshed with random samples every step
# - No retention of good examples
# ------------------------------------------------------------------------------
# Output Configuration
# ------------------------------------------------------------------------------
log_path: "/mnt/sting/ssum/sleepedf_sc_experiment/queue_random"
# ------------------------------------------------------------------------------
# Model Configuration (8 agents for Self-Consistency)
# ------------------------------------------------------------------------------
models:
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b

View File

@@ -1,85 +0,0 @@
# ==============================================================================
# Sleep-EDF Self-Consistency Experiment Configuration
# ==============================================================================
# ------------------------------------------------------------------------------
# Data Configuration
# ------------------------------------------------------------------------------
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
data_workers: 16
# ------------------------------------------------------------------------------
# Experiment Settings
# ------------------------------------------------------------------------------
num_seeds: 1
num_examples: 1
# Sample rate: process every Nth sample for faster experiments
sample_rate: 10
# Example pool selection: "out" (different users) or "in" (same user)
example_pool: "out"
# Continuous mode: if True, process samples in order; if False, shuffle
continuous: true
# Queue size for example selection (capacity of the example queue)
queue_size: 5
# Model context window size
num_ctx: 15000
# ------------------------------------------------------------------------------
# Selection Criteria
# Available options:
# - out_random: Random selection from different users (baseline)
# - in_random: Random selection from same user (personalization baseline)
# - out_similar: Chronos-2 embedding similarity-based selection
# - out_metadata: Gower distance-based selection (gender, age)
# ------------------------------------------------------------------------------
selection_criteria: "out_random"
# ------------------------------------------------------------------------------
# out_similar Configuration (Chronos-2 Embedding)
# Uncomment when using out_similar selection criteria
# ------------------------------------------------------------------------------
# embedding_path: "/home/ssum/tsllm_personalization_icl/embeddings_full"
# ------------------------------------------------------------------------------
# out_metadata Configuration (Gower Distance)
# Uncomment when using out_metadata selection criteria
# ------------------------------------------------------------------------------
# metadata_path: "/home/ssum/tsllm_personalization_icl/preprocess/SC-subjects.xls"
# weight_gender: 1.0 # Gender distance weight (same=0, different=1)
# weight_age: 1.0 # Age distance weight (normalized: |age1-age2|/range)
# ------------------------------------------------------------------------------
# Self-Consistency Settings
# ------------------------------------------------------------------------------
# Number of sampling iterations for Self-Consistency
num_sc_samples: 5
temperature: 0.0
# ------------------------------------------------------------------------------
# Model Configuration
# ------------------------------------------------------------------------------
# Multiple Ollama instances provide model diversity even with T=0
models:
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
# ------------------------------------------------------------------------------
# Output Configuration
# ------------------------------------------------------------------------------
log_path: "/mnt/sting/ssum/sleepedf_sc_result_test"
# Previous experiment result paths (for reference):
# chronos_base: "/mnt/sting/ssum/sleepedf_chronos_base_result"
# out_random: "/mnt/sting/ssum/sleepedf_chronos_result_outrandom"
# in_random: "/mnt/sting/ssum/sleepedf_chronos_result"

View File

@@ -1,147 +0,0 @@
import os
import re
import json
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
class Agent:
def __init__(
self,
name,
model_pool,
log_path,
):
self.name = name
self.model_pool = model_pool
self.log_path = log_path
self.root_log_path = log_path
self.agent_log_path = os.path.join(log_path, name)
os.makedirs(self.agent_log_path, exist_ok=True)
self.long_term_memory = []
self.short_term_memory = []
self.volatile_memory = []
def log(self, message, local=True):
path = os.path.join(self.root_log_path, "log.txt")
with open(path, "a", encoding="utf-8") as f:
message_type = "UNKNOWN"
if isinstance(message, SystemMessage):
message_type = "SYSTEM"
if isinstance(message, HumanMessage):
message_type = "HUMAN"
if isinstance(message, AIMessage):
message_type = "AI"
content = message.content.strip()
name = self.name
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
if local:
local_path = os.path.join(self.agent_log_path, "log.txt")
with open(local_path, "a", encoding="utf-8") as f:
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
def update_memory(self):
self.long_term_memory.extend(self.short_term_memory)
self.clean_short_term_memory()
self.clean_volatile_memory()
def clean_short_term_memory(self):
self.short_term_memory = []
def clean_volatile_memory(self):
self.volatile_memory = []
def clean_long_term_memory(self):
self.long_term_memory = []
def clean_json_text(self, text):
text = text.strip()
text = text.replace("", "'").replace("", "'")
text = text.replace("", "'").replace("", "'")
text = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", text)
text = re.sub(r",\s*}", "}", text)
text = re.sub(r",\s*]", "]", text)
text = "".join(ch for ch in text if ch.isprintable())
text = text.replace("][", ",")
return text
def safe_parse_json(self, text):
if not text:
return None
text = text.strip()
match = re.search(r"\{.*\}", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
elif not text.endswith("}"):
text += "}"
match = re.search(r"\{.*\}", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
print("[!] JSON parse failed")
return None
async def validate_response(self, response, fields, volatile=False):
if (
not response
or not isinstance(response, dict)
or not all(field in response for field in fields)
):
print("[!] The JSON failed to be parsed. Trying again.")
content = (
"Failed to parse the JSON from the previous response. Please try again."
)
response = await self.invoke(content, volatile=volatile)
response = self.safe_parse_json(response)
if (
not response
or not isinstance(response, dict)
or not all(field in response for field in fields)
):
print("[!] Retry failed.")
return None
return response
def get_last_response(self):
if len(self.long_term_memory) >= 2:
last_msg = self.long_term_memory[-1]
if isinstance(last_msg, AIMessage):
return self.safe_parse_json(last_msg.content)
return None
def set_system_message(self, content, local=True):
system_message = SystemMessage(content=content)
self.log(system_message, local)
self.long_term_memory.append(system_message)
async def invoke(self, content, volatile=False, local=True):
messages = self.long_term_memory.copy()
if volatile:
messages.extend(self.volatile_memory)
else:
messages.extend(self.short_term_memory)
messages.append(HumanMessage(content=content))
try:
response = await self.model_pool.invoke(messages)
if volatile:
self.volatile_memory.extend([HumanMessage(content=content), response])
else:
self.short_term_memory.extend([HumanMessage(content=content), response])
local_ = not volatile and local
self.log(HumanMessage(content=content), local=local_)
self.log(response, local=local_)
return response.content.strip()
except Exception as e: # pylint: disable=broad-exception-caught
print(f"[Error] Error occurred while invoking LLM: {e}")

View File

@@ -1,137 +0,0 @@
import os
import asyncio
class AgentPool:
def __init__(self, log_path):
self.agents = {}
os.makedirs(log_path, exist_ok=True)
self.log_path = log_path
def add_agent(self, agent):
self.agents[agent.index] = agent
def log_summary(self, message, print_log=True):
path = os.path.join(self.log_path, "summary.txt")
with open(path, "a", encoding="utf-8") as f:
f.write(f"{message}\n")
if print_log:
print(message)
def get_last_responses(self):
responses = {}
for index, agent in self.agents.items():
response = agent.get_last_response()
if response:
responses[index] = response
return responses
def vote(self, responses, mode="majority_vote"):
"""
Vote for the final answer from agent responses.
Args:
responses: Dictionary mapping agent indices to response dicts
Format: {"agent_name": {"ANSWER": "...", "CONFIDENCE": 0.8, ...}, ...}
mode: Voting mode - "majority_vote", "highest_confidence", or "confidence_vote"
Returns:
The winning answer string
"""
if not responses:
return None
if mode == "highest_confidence":
# Find response with highest confidence
best_response = max(responses.values(), key=lambda x: x.get("CONFIDENCE", 0))
return best_response.get("ANSWER")
elif mode == "majority_vote":
# Count votes for each answer
answer_cnt = {}
for response in responses.values():
answer = response.get("ANSWER")
if answer:
answer_cnt[answer] = answer_cnt.get(answer, 0) + 1
if not answer_cnt:
return None
max_val = max(answer_cnt.values())
max_key = [k for k, v in answer_cnt.items() if v == max_val][0]
return max_key
elif mode == "confidence_vote":
# Weight votes by confidence
cnts = {}
for response in responses.values():
answer = response.get("ANSWER")
confidence = response.get("CONFIDENCE", 0)
if answer:
cnts[answer] = cnts.get(answer, 0) + confidence
if not cnts:
return None
return max(cnts, key=cnts.get)
else:
raise ValueError(f"Invalid mode: {mode}")
async def run_parallel_interpretation(self):
tasks = []
for _, agent in self.agents.items():
tasks.append(asyncio.create_task(agent.interpret()))
results = await asyncio.gather(*tasks)
print(results)
if None in results:
self.log_summary(f"[Error] Failed to interpret")
return None
for _, agent in self.agents.items():
agent.update_memory()
responses = self.get_last_responses()
for index, response in responses.items():
answer = response.get("ANSWER", "UNKNOWN")
confidence = response.get("CONFIDENCE", 0.0)
self.log_summary(f"[Interpretation] <{index}> provided answer: {answer} with confidence: {confidence}")
voted_result = self.vote(responses, mode="majority_vote")
queue_idcs = self.filter_examples(responses)
# Calculate confidence and consistency for the voted result
all_answers = [r.get("ANSWER") for r in responses.values()]
majority_responses = [r for r in responses.values() if r.get("ANSWER") == voted_result]
# Avg confidence of majority answer
avg_confidence = sum(r.get("CONFIDENCE", 0) for r in majority_responses) / len(majority_responses) if majority_responses else 0
# Consistency = ratio of majority votes
consistency = len(majority_responses) / len(responses) if responses else 0
return voted_result, queue_idcs, avg_confidence, consistency, responses
def filter_examples(self, responses):
"""
Filter examples based on highest confidence responses.
Args:
responses: Dictionary mapping agent indices to response dicts
Format: {"agent_name": {"ANSWER": "...", "CONFIDENCE": 0.8, "_example_idx": 0}, ...}
Returns:
List of queue indices with highest confidence
"""
queue_idcs = []
max_confidence = 0
for index, response in responses.items():
confidence = response.get("CONFIDENCE", 0)
if confidence > max_confidence:
max_confidence = confidence
queue_idcs = [index]
elif confidence == max_confidence:
queue_idcs.append(index)
# Debug: 선택 이유 출력
print(f"\n[Selection] Confidence Summary:")
for index, response in sorted(responses.items()):
conf = response.get("CONFIDENCE", 0)
ans = response.get("ANSWER", "?")
marker = " ★ SELECTED" if index in queue_idcs else ""
print(f" Case #{index}: {ans} (conf: {conf}){marker}")
print(f"[Selection] Max Confidence: {max_confidence} → Selected Case(s): {queue_idcs}")
return queue_idcs

View File

@@ -1,72 +0,0 @@
import os
import json
import datasets
import numpy as np
from glob import glob
from typing import Optional, TYPE_CHECKING
class DataLoader:
def __init__(
self,
data_path,
user_id,
example_pool="out",
continuous=True,
):
if not os.path.exists(os.path.join(data_path, "info.json")):
return
if not os.path.exists(os.path.join(data_path, f"{user_id}", "1")):
return
if not os.path.exists(os.path.join(data_path, f"{user_id}", "2")):
return
self.metadata = json.load(open(os.path.join(data_path, "info.json"), "r", encoding="utf-8"))
self.test_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user_id}", "2"))
self.example_dataset = datasets.Dataset.from_list([])
users = glob(os.path.join(data_path, "*"))
users = [path.split("/")[-1] for path in users]
if "info.json" in users:
users.remove("info.json")
for user in users:
if example_pool == "out" and user == user_id:
continue
if example_pool == "in" and user != user_id:
continue
user_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user}", "1"))
self.example_dataset = datasets.concatenate_datasets([self.example_dataset, user_dataset])
if not continuous:
self.test_dataset = self.test_dataset.shuffle(seed=0)
self.example_dataset = self.example_dataset.shuffle(seed=0)
def __len__(self):
return len(self.test_dataset)
def __getitem__(self, idx):
sample = self.test_dataset[idx]
return sample
def __iter__(self):
for sample in self.test_dataset:
yield sample
def get_examples(self):
return self.example_dataset
def get_metadata(self):
return self.metadata
def get_sensor_info(self):
return self.metadata["feature"]
def get_task_info(self):
task_info = f"**Task**:\n{self.metadata['task']}\n\n"
classes_info = [f" - {k}: {v}" for k, v in self.metadata["class"].items()]
classes_info = "\n".join(classes_info)
task_info += f"**Classes**:\n{classes_info}"
return task_info
def get_classes_info(self):
classes_info = [k for k in self.metadata["class"].keys()]
return classes_info

View File

@@ -1,162 +0,0 @@
"""
Self-Consistency Agent for Sleep Stage Classification
This module implements an agent that uses Self-Consistency methodology:
- Sample N times with the same prompt (configurable temperature)
- Output REASON, CONFIDENCE, ANSWER in JSON format
- Aggregate final answer via Majority Voting
The agent extends the base Agent class and adds Self-Consistency
specific functionality including multi-sampling and voting.
"""
import os
import sys
import json
import asyncio
from typing import List, Dict, Any, Optional
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
from sc.core.agent import Agent
class JudgeAgent(Agent):
def __init__(
self,
name: str,
index: int,
model_pool,
task_info: str,
classes_info: List[str],
sensor_info: str,
sample: Dict[str, Any],
examples: List[Dict[str, Any]],
log_path: str,
):
super().__init__(
name=name + f"_{index}",
model_pool=model_pool,
log_path=log_path,
)
self.task_info = task_info
self.classes_info = classes_info
self.sensor_info = sensor_info
self.sample = sample
self.examples = examples
self.index = index # Store index for later retrieval
self._init_system_message()
def _init_system_message(self) -> None:
content = (
f"You are a {self.name} agent that judges the answers of other agents.\n"
f"You have the following information about the task:\n"
f"{self.task_info}\n\n"
f"You have the following information about the sensor data:\n"
f"{self.sensor_info}\n\n"
"Your goal is to analyze the features and "
"provide a reasoned answer with your confidence level."
)
self.set_system_message(content)
def _format_feature(self, value: Any) -> str:
"""
Format a feature value for display.
Uses scientific notation for very large or very small numbers,
and two decimal places for regular floats.
Args:
value: Feature value to format
Returns:
Formatted string representation
"""
if isinstance(value, float):
if abs(value) >= 1e4 or (abs(value) < 1e-2 and value != 0):
return f"{value:.2e}"
return f"{value:.2f}"
return str(value)
def _gen_example_info(self) -> str:
"""
Generate string representation of ICL examples.
Returns:
Formatted string with example features and labels,
or empty string if no examples provided
"""
if not self.examples or len(self.examples) == 0:
return ""
example_info = (
"**Examples**\n"
"Sensor values might not always align with your inherent "
"knowledge due to differences in data collection or processing. "
"So, we included a few labeled examples to help your interpretation:\n"
)
for example in self.examples:
example_info += f"*Example of {example['label']}*:\n"
for k, v in example["features"].items():
example_info += f" - {k}: {self._format_feature(v)}\n"
example_info += "\n"
return example_info.strip()
def _gen_feature_info(self) -> str:
"""
Generate string representation of current sample features.
Combines ICL example information with current sample features
to provide full context for classification.
Returns:
Formatted string with example and sample features
"""
feature_info = f"{self.name} features:\n"
# Add ICL example information if available
example_info = self._gen_example_info()
if example_info:
feature_info += f"{example_info}\n\n"
# Add current sample features
feature_info += "**Current sample features**:\n"
for k, v in self.sample["features"].items():
feature_info += f" - {k}: {self._format_feature(v)}\n"
return feature_info.strip()
async def interpret(self) -> str:
feature_info = self._gen_feature_info()
prompt = (
f"You have received sensor features from {self.name} modality:\n"
f"{feature_info}\n\n"
f"Please provide your answer for the task among {self.classes_info} "
"and the reasoning for your answer.\n"
"Also, please provide your confidence level for the answer as a float between 0.0 and 1.0.\n"
"Note that the sensor features might be wrong due to the data collection or processing.\n"
"You can evaluate the quality of the features by checking the examples you have.\n\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<Your detailed reasoning for the classification>",\n'
' "CONFIDENCE": <Your confidence as a float between 0.0 and 1.0>,\n'
f' "ANSWER": "<Your answer among {self.classes_info}>"\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(prompt)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["REASON", "CONFIDENCE", "ANSWER"]
)
# Add example index to response for queue update
if parsed_response:
parsed_response["_example_idx"] = self.index
return parsed_response

View File

@@ -1,180 +0,0 @@
"""
Majority Voting Module for Self-Consistency
Aggregates multiple LLM responses to determine the final answer via majority voting.
"""
from collections import Counter
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
import numpy as np
@dataclass
class VotingResult:
answer: str # Final answer (majority vote)
reason: str # Representative reasoning (from response with highest confidence)
avg_confidence: float # Average confidence (for majority answer)
consistency: float # Consistency score (ratio of majority votes)
vote_distribution: Dict[str, int] # Vote distribution {answer: count}
num_samples: int # Total number of valid samples
all_responses: List[Dict[str, Any]] = field(default_factory=list) # All responses
@property
def is_unanimous(self) -> bool:
"""Check if all votes are unanimous."""
return self.consistency == 1.0
@property
def majority_count(self) -> int:
"""Return the count of majority votes."""
return self.vote_distribution.get(self.answer, 0)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"answer": self.answer,
"reason": self.reason,
"avg_confidence": self.avg_confidence,
"consistency": self.consistency,
"vote_distribution": self.vote_distribution,
"num_samples": self.num_samples,
"is_unanimous": self.is_unanimous,
"majority_count": self.majority_count,
}
class MajorityVoting:
def __init__(self, valid_classes: Optional[List[str]] = None):
self.valid_classes = valid_classes
def aggregate(self, responses: List[Dict[str, Any]]) -> VotingResult:
valid_responses = self._filter_valid_responses(responses) # Filter valid responses only
if not valid_responses:
return self._empty_result()
answers = [r["ANSWER"] for r in valid_responses]
vote_counter = Counter(answers)
majority_answer = self._resolve_majority(vote_counter, valid_responses)
majority_count = vote_counter[majority_answer]
consistency = majority_count / len(valid_responses)
majority_responses = [
r for r in valid_responses if r["ANSWER"] == majority_answer
]
avg_confidence = np.mean([r["CONFIDENCE"] for r in majority_responses])
best_response = max(majority_responses, key=lambda x: x["CONFIDENCE"])
best_reason = best_response.get("REASON", "")
return VotingResult(
answer=majority_answer,
reason=best_reason,
avg_confidence=float(avg_confidence),
consistency=float(consistency),
vote_distribution=dict(vote_counter),
num_samples=len(valid_responses),
all_responses=valid_responses,
)
def _filter_valid_responses(
self,
responses: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
valid = [] # Valid responses
for r in responses:
if not isinstance(r, dict):
continue
if "ANSWER" not in r:
continue
# Check valid class
if self.valid_classes and r["ANSWER"] not in self.valid_classes:
continue
# Normalize CONFIDENCE
conf = r.get("CONFIDENCE", 0.5)
if isinstance(conf, str):
try:
conf = float(conf)
except ValueError:
conf = 0.5
r["CONFIDENCE"] = max(0.0, min(1.0, conf))
# Set default REASON
if "REASON" not in r:
r["REASON"] = ""
valid.append(r)
return valid
def _resolve_majority(
self,
vote_counter: Counter,
responses: List[Dict[str, Any]]
) -> str:
"""
Determine majority answer (use average confidence for tie-breaking).
"""
max_count = vote_counter.most_common(1)[0][1]
tied_answers = [a for a, c in vote_counter.items() if c == max_count]
if len(tied_answers) == 1:
return tied_answers[0]
# Tie-breaking: select answer with highest average confidence
best_answer = None
best_avg_conf = -1.0
for answer in tied_answers:
answer_responses = [r for r in responses if r["ANSWER"] == answer]
avg_conf = np.mean([r["CONFIDENCE"] for r in answer_responses])
if avg_conf > best_avg_conf:
best_avg_conf = avg_conf
best_answer = answer
return best_answer
def _empty_result(self) -> VotingResult:
"""Return empty result when no valid responses exist."""
return VotingResult(
answer="UNKNOWN",
reason="No valid responses received",
avg_confidence=0.0,
consistency=0.0,
vote_distribution={},
num_samples=0,
all_responses=[],
)
@staticmethod
def compute_agreement_matrix(
responses: List[Dict[str, Any]]
) -> np.ndarray:
"""
Compute pairwise agreement matrix between responses (for analysis).
Creates an N x N matrix where entry (i, j) is 1.0 if responses i and j
have the same ANSWER, and 0.0 otherwise. Useful for analyzing
consistency patterns across samples.
Args:
responses: List of response dictionaries with ANSWER keys
"""
n = len(responses)
matrix = np.zeros((n, n))
for i in range(n):
for j in range(n):
if responses[i].get("ANSWER") == responses[j].get("ANSWER"):
matrix[i, j] = 1.0
return matrix

View File

@@ -1,153 +0,0 @@
import os
import asyncio
import requests
from langchain_ollama import ChatOllama
from langchain_together import ChatTogether
from langchain_openai import ChatOpenAI
from langchain.chat_models import init_chat_model
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
def load_models(models, temperature=0.0, num_ctx=15000):
model_pool = AsyncModelPool()
for model in models:
model_pool.add_model(Model(model, temperature=temperature, num_ctx=num_ctx))
model_pool.init_models()
return model_pool
class Model:
def __init__(self, model, temperature, num_ctx):
if model.startswith("ollama:"):
model = model.replace("ollama:", "")
if "url:" in model: # custom parsing for local ollama models
model = model.replace("url:", "")
base_url = model.split("/")[0]
model_type = model.split("/")[1]
# self.model = ChatOllama(
# model=model_type,
# base_url=f"http://{base_url}",
# temperature=temperature,
# num_ctx=num_ctx,
# )
self.model = None
self.base_url = f"http://{base_url}/api/chat"
self.model_type = model_type
self.temperature = temperature
self.num_ctx = num_ctx
else:
self.model = ChatOllama(
model=model.replace("ollama:", ""),
temperature=temperature,
num_ctx=num_ctx,
)
elif model.startswith("together"):
if "TOGETHER_API_KEY" not in os.environ:
print("[!] TOGETHER_API_KEY is not set")
assert 0
self.model = ChatTogether(
model=model.replace("together:", ""),
temperature=temperature,
max_tokens=num_ctx,
max_retries=3,
)
elif model.startswith("openai"):
if "OPENAI_API_KEY" not in os.environ:
print("[!] OPENAI_API_KEY is not set")
assert 0
self.model = ChatOpenAI(
model=model.replace("openai:", ""),
temperature=temperature,
)
else:
self.model = init_chat_model(
model=model,
temperature=temperature,
)
def invoke(self, messages, logprobs=False, top_logprobs=0):
try:
if self.model:
response = self.model.invoke(messages)
return response
else:
converted_messages = []
for msg in messages:
role = msg.type
role = "user" if role == "human" else "assistant"
content = msg.content
converted_messages.append({"role": role, "content": content})
response = requests.post(self.base_url, json={
"model": self.model_type,
"messages": converted_messages,
"stream": False,
"options": {
"temperature": self.temperature,
"num_ctx": self.num_ctx,
},
"logprobs": logprobs,
"top_logprobs": top_logprobs,
})
response = response.json()
resp_msg = AIMessage(content=response["message"]["content"])
if logprobs:
return resp_msg, response["logprobs"]
else:
return resp_msg
return resp_msg, response["logprobs"]
except Exception as e:
print(f"[Error] Error occurred while invoking LLM: {e}")
return e
class AsyncModel:
def __init__(self, model):
self.model = model
async def invoke(self, content, logprobs=False, top_logprobs=0):
loop = asyncio.get_event_loop()
if logprobs:
response, logprobs = await loop.run_in_executor(
None,
lambda: self.model.invoke(content, logprobs=logprobs, top_logprobs=top_logprobs),
)
return response, logprobs
else:
response = await loop.run_in_executor(
None,
lambda: self.model.invoke(content),
)
return response
class AsyncModelPool:
def __init__(self):
self.models = []
self._available_models = None
self._model_semaphore = None
def add_model(self, model):
self.models.append(model)
def init_models(self):
print(f"Initializing {len(self.models)} models...")
self._available_models = asyncio.Queue()
for model in self.models:
async_model = AsyncModel(model)
self._available_models.put_nowait(async_model)
self._model_semaphore = asyncio.Semaphore(len(self.models))
async def invoke(self, content, logprobs=False, top_logprobs=0):
if self._available_models is None:
raise RuntimeError("Model pool not initialized. Call init_models() first.")
async_model = await self._available_models.get()
try:
if logprobs:
response, logprobs = await async_model.invoke(content, logprobs=logprobs, top_logprobs=top_logprobs)
return response, logprobs
else:
response = await async_model.invoke(content)
return response
finally:
self._available_models.put_nowait(async_model)

View File

@@ -1,188 +0,0 @@
"""
Model Utilities for Self-Consistency Experiments
This module provides model loading and management utilities with
configurable temperature support for Self-Consistency experiments.
- Temperature-configurable model loading
- Async model pool for parallel inference
- Support for Ollama and other LangChain-compatible models
"""
import asyncio
from typing import List, Any
from langchain_ollama import ChatOllama
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, BaseMessage
def load_models_with_temperature(
models: List[str],
temperature: float = 0.0
) -> "AsyncModelPoolSC":
"""
Load models with specified temperature into an async pool.
Creates a pool of models that can be used for parallel async inference.
Each model in the pool is initialized with the same temperature setting.
Args:
models: List of model specification strings.
Supported formats:
- "ollama:url:host:port/model_name" for remote Ollama
- "ollama:model_name" for local Ollama
- Standard LangChain model strings
temperature: LLM sampling temperature (default: 0.0)
"""
model_pool = AsyncModelPoolSC()
for model_str in models:
model_pool.add_model(ModelSC(model_str, temperature=temperature))
model_pool.init_models()
return model_pool
class ModelSC:
"""
Self-Consistency model wrapper with temperature support.
Wraps LangChain chat models with configurable temperature for
Self-Consistency experiments. Supports Ollama (local and remote)
and other LangChain-compatible models.
Attributes:
model: Underlying LangChain chat model
temperature: Configured sampling temperature
"""
def __init__(self, model: str, temperature: float = 0.0):
"""
Initialize model with specified temperature.
Args:
model: Model specification string
Formats:
- "ollama:url:host:port/model" - Remote Ollama instance
- "ollama:model" - Local Ollama instance
- Other strings - Passed to LangChain init_chat_model
temperature: Sampling temperature (0.0 = deterministic)
Raises:
ValueError: If model string format is invalid
"""
self.temperature = temperature
self._model_str = model
if model.startswith("ollama:"):
self._init_ollama_model(model, temperature)
else:
self._init_langchain_model(model, temperature)
def _init_ollama_model(self, model: str, temperature: float) -> None:
"""
Initialize Ollama model (local or remote).
Args:
model: Ollama model string (ollama:url:host:port/model or ollama:model)
temperature: Sampling temperature
"""
model = model.replace("ollama:", "")
if "url:" in model:
model = model.replace("url:", "")
parts = model.split("/")
base_url = parts[0]
if not base_url.startswith("http"):
base_url = "http://" + base_url
model_type = parts[1] if len(parts) > 1 else "llama2"
self.model = ChatOllama(
model=model_type,
base_url=base_url,
temperature=temperature,
num_ctx=12000,
)
else:
self.model = ChatOllama(
model=model,
temperature=temperature,
num_ctx=12000,
)
def _init_langchain_model(self, model: str, temperature: float) -> None:
self.model = init_chat_model(
model=model,
temperature=temperature,
)
def invoke(self, messages: List[BaseMessage]) -> Any:
response = self.model.invoke(messages)
return response
def __repr__(self) -> str:
"""String representation of the model."""
return f"ModelSC(model={self._model_str}, temperature={self.temperature})"
class AsyncModelSC:
def __init__(self, model: ModelSC):
self.model = model
async def invoke(self, messages: List[BaseMessage]) -> Any:
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(
None,
lambda: self.model.invoke(messages),
)
return response
class AsyncModelPoolSC:
def __init__(self):
self.models: List[ModelSC] = []
self._available_models: asyncio.Queue = None
self._model_semaphore: asyncio.Semaphore = None
def add_model(self, model: ModelSC) -> None:
self.models.append(model)
def init_models(self) -> None:
if not self.models:
raise RuntimeError("No models added. Call add_model() first.")
self._available_models = asyncio.Queue()
for model in self.models:
async_model = AsyncModelSC(model)
self._available_models.put_nowait(async_model)
self._model_semaphore = asyncio.Semaphore(len(self.models))
def warmup(self) -> None:
print(f"[ModelPool] Warming up {len(self.models)} models...")
for i, model in enumerate(self.models):
try:
model.invoke([HumanMessage(content="Hello world!")])
print(f"[ModelPool] Model {i+1}/{len(self.models)} warmed up")
except Exception as e:
print(f"[ModelPool] Model {i+1} warmup failed: {e}")
print("[ModelPool] All models warmed up")
async def invoke(self, messages: List[BaseMessage]) -> Any:
if self._available_models is None:
raise RuntimeError("Model pool not initialized. Call init_models() first.")
async_model = await self._available_models.get()
try:
response = await async_model.invoke(messages)
return response
finally:
# Always return model to pool
self._available_models.put_nowait(async_model)
@property
def size(self) -> int:
return len(self.models)
def __repr__(self) -> str:
return f"AsyncModelPoolSC(size={self.size})"

View File

@@ -1,154 +0,0 @@
from collections import deque
import random
class Queue:
def __init__(self, dataset_size, capacity=5):
self.dataset_size = dataset_size
self.classes = list(dataset_size.keys())
initial_cases = [
[random.choice(dataset_size[cls]) for cls in self.classes]
for _ in range(capacity)
]
self._queue = deque(initial_cases, maxlen=capacity)
self._input_time = {tuple(case): 0 for case in initial_cases}
self._current_time = 0
self._eviction_history = []
self._usage_count = {tuple(case): 0 for case in initial_cases}
def set_current_time(self, sample_idx: int):
self._current_time = sample_idx
def push(self, index):
if list(index) in [list(c) for c in self._queue]:
return None
evicted = None
if len(self._queue) == self._queue.maxlen:
evicted = self._queue.popleft()
self._record_eviction_stats(evicted)
self._queue.append(list(index))
self._register_stats(index)
return evicted
def pop(self):
return self._queue.popleft() if self._queue else None
def __iter__(self):
"""
Make Queue iterable by yielding all elements in the queue.
This allows the queue to be used in for loops:
for ex_index in ex_queue:
# process ex_index
"""
for idx in self._queue:
yield idx
def _register_stats(self, case):
key = tuple(case)
self._input_time[key] = self._current_time
self._usage_count[key] = 0
def _record_eviction_stats(self, case):
key = tuple(case)
if key in self._input_time:
duration = self._current_time - self._input_time[key]
self._eviction_history.append({
"case": case,
"duration": duration,
"usage": self._usage_count.get(key, 0),
"evicted_at": self._current_time
})
del self._input_time[key]
if key in self._usage_count:
del self._usage_count[key]
def update_by_confidence(self, confidence_map):
if not confidence_map:
return None
current_items = list(self._queue)
items_with_score = []
for i, item in enumerate(current_items):
items_with_score.append({
"item": item,
"score": confidence_map.get(i, -1.0)
})
items_with_score.sort(key=lambda x: x["score"], reverse=True)
self._queue = deque([x["item"] for x in items_with_score], maxlen=self._queue.maxlen)
evicted = None
if len(self._queue) == self._queue.maxlen:
evicted = self._queue.pop()
self._record_eviction_stats(evicted)
new_case = [random.choice(self.dataset_size[cls]) for cls in self.classes]
self._queue.append(new_case)
self._register_stats(new_case)
return evicted
def update_by_consistency(self, consistency_map):
"""Update queue based on consistency scores (agreement ratio among agents)."""
if not consistency_map:
return None
current_items = list(self._queue)
items_with_score = []
for i, item in enumerate(current_items):
items_with_score.append({
"item": item,
"score": consistency_map.get(i, -1.0)
})
items_with_score.sort(key=lambda x: x["score"], reverse=True)
self._queue = deque([x["item"] for x in items_with_score], maxlen=self._queue.maxlen)
evicted = None
if len(self._queue) == self._queue.maxlen:
evicted = self._queue.pop()
self._record_eviction_stats(evicted)
new_case = [random.choice(self.dataset_size[cls]) for cls in self.classes]
self._queue.append(new_case)
self._register_stats(new_case)
return evicted
def increment_usage(self, queue_indices):
"""Increment usage count for specified queue indices."""
for idx in queue_indices:
if 0 <= idx < len(self._queue):
key = tuple(self._queue[idx])
self._usage_count[key] = self._usage_count.get(key, 0) + 1
def get_instance_id(self):
return id(self)
def get_state_with_stats(self):
result = []
for case in self._queue:
key = tuple(case)
age = self._current_time - self._input_time.get(key, self._current_time)
result.append({
"case": case,
"age": age,
"usage": self._usage_count.get(key, 0)
})
return result
def get_survival_stats(self):
return self._eviction_history
def get_survival_summary(self):
if not self._eviction_history:
return {
"total_evicted": 0,
"avg_survival": 0,
"max_survival": 0,
"avg_usage": 0
}
durations = [x["duration"] for x in self._eviction_history]
usages = [x["usage"] for x in self._eviction_history]
return {
"total_evicted": len(durations),
"avg_survival": sum(durations) / len(durations),
"max_survival": max(durations),
"avg_usage": sum(usages) / len(usages) if usages else 0
}

View File

@@ -1,193 +0,0 @@
"""
Self-Consistency Agent for Sleep Stage Classification
This module implements an agent that uses Self-Consistency methodology:
- Sample N times with the same prompt (configurable temperature)
- Output REASON, CONFIDENCE, ANSWER in JSON format
- Aggregate final answer via Majority Voting
The agent extends the base Agent class and adds Self-Consistency
specific functionality including multi-sampling and voting.
"""
import os
import sys
import json
import asyncio
from typing import List, Dict, Any, Optional
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
from sc.core.agent import Agent
class SCAgent(Agent):
"""
Self-Consistency based sleep stage classification agent.
This agent samples N times for the same input and uses majority voting
to determine the final answer. It supports confidence-based tie-breaking
and provides detailed metrics about consistency.
Attributes:
task_info: Description of the classification task
classes_info: List of valid class labels
sensor_info: Information about sensor data
sample: Current sample being classified
examples: ICL examples for context
voter: MajorityVoting instance for aggregation
"""
def __init__(
self,
name: str,
index: int,
model_pool,
task_info: str,
classes_info: List[str],
sensor_info: str,
sample: Dict[str, Any],
examples: List[Dict[str, Any]],
log_path: str,
):
"""
Initialize Self-Consistency Agent.
Args:
name: Agent identifier name
index: Index of the agent in the dataset
model_pool: Async model pool for LLM inference
task_info: Description of the classification task
classes_info: List of valid class labels
sensor_info: Information about sensor data format
sample: Sample to be classified
examples: ICL examples for few-shot learning
log_path: Directory path for saving logs
"""
super().__init__(
name=name + f"_{index}",
model_pool=model_pool,
log_path=log_path,
)
self.task_info = task_info
self.classes_info = classes_info
self.sensor_info = sensor_info
self.sample = sample
self.examples = examples
self.index = index # Store index for later retrieval
self._init_system_message()
def _init_system_message(self) -> None:
"""Initialize the system message that defines the agent's role."""
content = (
f"You are a {self.name} agent that interprets sensor data to solve a task.\n"
f"You have the following information about the task:\n"
f"{self.task_info}\n\n"
f"You have the following information about the sensor data:\n"
f"{self.sensor_info}\n\n"
"Your goal is to analyze the features and "
"provide a reasoned answer with your confidence level."
)
self.set_system_message(content)
def _format_feature(self, value: Any) -> str:
"""
Format a feature value for display.
Uses scientific notation for very large or very small numbers,
and two decimal places for regular floats.
Args:
value: Feature value to format
Returns:
Formatted string representation
"""
if isinstance(value, float):
if abs(value) >= 1e4 or (abs(value) < 1e-2 and value != 0):
return f"{value:.2e}"
return f"{value:.2f}"
return str(value)
def _gen_example_info(self) -> str:
"""
Generate string representation of ICL examples.
Returns:
Formatted string with example features and labels,
or empty string if no examples provided
"""
if not self.examples or len(self.examples) == 0:
return ""
example_info = (
"**Examples**\n"
"Sensor values might not always align with your inherent "
"knowledge due to differences in data collection or processing. "
"So, we included a few labeled examples to help your interpretation:\n"
)
for example in self.examples:
example_info += f"*Example of {example['label']}*:\n"
for k, v in example["features"].items():
example_info += f" - {k}: {self._format_feature(v)}\n"
example_info += "\n"
return example_info.strip()
def _gen_feature_info(self) -> str:
"""
Generate string representation of current sample features.
Combines ICL example information with current sample features
to provide full context for classification.
Returns:
Formatted string with example and sample features
"""
feature_info = f"{self.name} features:\n"
# Add ICL example information if available
example_info = self._gen_example_info()
if example_info:
feature_info += f"{example_info}\n\n"
# Add current sample features
feature_info += "**Current sample features**:\n"
for k, v in self.sample["features"].items():
feature_info += f" - {k}: {self._format_feature(v)}\n"
return feature_info.strip()
async def interpret(self) -> str:
feature_info = self._gen_feature_info()
prompt = (
f"You have received sensor features from {self.name} modality:\n"
f"{feature_info}\n\n"
f"Please provide your answer for the task among {self.classes_info} "
"and the reasoning for your answer.\n"
"Also, please provide your confidence level for the answer as a float between 0.0 and 1.0.\n"
"Note that the sensor features might be wrong due to the data collection or processing.\n"
"You can evaluate the quality of the features by checking the examples you have.\n\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<Your detailed reasoning for the classification>",\n'
' "CONFIDENCE": <Your confidence as a float between 0.0 and 1.0>,\n'
f' "ANSWER": "<Your answer among {self.classes_info}>"\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(prompt)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["REASON", "CONFIDENCE", "ANSWER"]
)
# Add example index to response for queue update
if parsed_response:
parsed_response["_example_idx"] = self.index
return parsed_response

View File

@@ -1,426 +0,0 @@
def _enabled(config=None):
if config is None:
return True
return config.get("debug", True)
def log(message, config=None):
if not _enabled(config):
return
print(message)
def warn_no_examples():
print("[WARN] No examples found for dataloader. Skipping task.")
def log_queue_state_before(processed_count, user_info, ex_queue, config=None):
if not _enabled(config):
return
print(f"\n{'#'*60}")
print(
f"[Sample {processed_count}] {user_info} - Queue State BEFORE Processing "
f"(Instance ID: {ex_queue.get_instance_id()}):"
)
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
print(
f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})"
)
print(f"{'#'*60}\n")
def warn_no_agents(processed_count):
print(f"[WARN] No agents added for sample {processed_count}. Skipping.")
def warn_interpretation_failed(processed_count):
print(f"[WARN] Interpretation failed for sample {processed_count}. Skipping.")
def log_tracking(
processed_count,
user_info,
answer,
ground_truth,
is_correct,
avg_confidence,
consistency,
cumulative_accuracy,
cumulative_correct,
window_accuracy,
recent_results,
avg_confidence_so_far,
config=None,
):
if not _enabled(config):
return
print(f"\n{'='*60}")
print(f"[TRACKING] Sample {processed_count} | {user_info}")
print(
f" Answer: {answer} | GT: {ground_truth} | "
f"{'✓ CORRECT' if is_correct else '✗ WRONG'}"
)
print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
print(" ─────────────────────────────────────────────────────")
print(
f" Cumulative Accuracy: {cumulative_accuracy:.4f} "
f"({cumulative_correct}/{processed_count + 1})"
)
print(
f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}"
)
print(f" Avg Confidence (so far): {avg_confidence_so_far:.4f}")
print(f"{'='*60}\n")
def log_queue_state_after(processed_count, ex_queue, config=None):
if not _enabled(config):
return
print(f"\n[Sample {processed_count}] Queue State AFTER Update:")
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
print(
f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})"
)
print()
def warn_no_responses(processed_count):
print(f"[WARN] No responses returned, falling back to basic update.")
def log_final_queue_stats(user_info, survival_summary, config=None):
if not _enabled(config):
return
print(f"\n{'#'*60}")
print(f"[FINAL] Queue Survival Statistics for {user_info}")
print(f" Total Evicted Cases: {survival_summary['total_evicted']}")
print(f" Avg Survival: {survival_summary['avg_survival']:.2f} samples")
print(f" Max Survival: {survival_summary['max_survival']} samples")
print(f" Min Survival: {survival_summary['min_survival']} samples")
print(f" Avg Usage Count: {survival_summary['avg_usage']:.2f}")
print(f" Max Usage Count: {survival_summary['max_usage']}")
print(f"{'#'*60}\n")
def warn_no_user_dirs(data_path):
print(f"[WARN] No user directories found in {data_path}")
def log_found_users(users, config=None):
if not _enabled(config):
return
print(
f"[INFO] Found {len(users)} users: {users[:5]}"
f"{'...' if len(users) > 5 else ''}"
)
def warn_skip_user_no_test_data(user):
print(f"[WARN] Skipping user {user} - no test data available")
def warn_skip_user_no_example_data(user):
print(f"[WARN] Skipping user {user} - no example data available")
def log_main_loading_config(config_path, config=None):
if not _enabled(config):
return
print(f"[MAIN] Loading config: {config_path}")
def log_main_config(config, config_enabled=True):
if not config_enabled:
return
print("=" * 60)
print("SELF-CONSISTENCY EXPERIMENT CONFIGURATION")
print("=" * 60)
print(f" Data path: {config.get('data_path', 'N/A')}")
print(f" Log path: {config.get('log_path', 'N/A')}")
print(f" Num ICL examples: {config.get('num_examples', 1)}")
print(f" Num seeds: {config.get('num_seeds', 1)}")
print(f" Num SC samples: {config.get('num_sc_samples', 5)}")
print(f" Temperature: {config.get('temperature', 0.0)}")
print(f" Sample rate: 1/{config.get('sample_rate', 10)}")
print(f" Num models: {len(config.get('models', []))}")
print("=" * 60)
def log_main_start(config=None):
if not _enabled(config):
return
print("[MAIN] Starting experiments...")
def error_task_failed(result):
print(f"[ERROR] Task failed with exception: {result}")
def log_total_results(count, config=None):
if not _enabled(config):
return
print(f"[MAIN] Total results collected: {count}")
def log_experiment_results(stats, config=None):
if not _enabled(config):
return
print("\n" + "=" * 60)
print("EXPERIMENT RESULTS")
print("=" * 60)
print(f" Total samples: {stats.get('total_samples', 0)}")
print(f" Accuracy: {stats.get('accuracy', 0):.4f}")
print(f" Avg Confidence: {stats.get('avg_confidence', 0):.4f}")
print(f" Avg Consistency: {stats.get('avg_consistency', 0):.4f}")
print(
" High Consistency (>=0.8) Accuracy: "
f"{stats.get('high_consistency_accuracy', 0):.4f}"
)
print(f" High Consistency Samples: {stats.get('high_consistency_samples', 0)}")
print("\n Class-wise Accuracy:")
for cls, acc in stats.get("class_accuracy", {}).items():
print(f" {cls}: {acc:.4f}")
def log_temporal_analysis(temporal, config=None):
if not _enabled(config):
return
if not temporal:
return
print("\n" + "-" * 60)
print(" TEMPORAL ANALYSIS (Caching Effect)")
print("-" * 60)
print(f" First Half Accuracy: {temporal.get('first_half_accuracy', 0):.4f}")
print(f" Second Half Accuracy: {temporal.get('second_half_accuracy', 0):.4f}")
improvement = temporal.get("accuracy_improvement", 0)
improvement_sign = "+" if improvement >= 0 else ""
print(f" Improvement: {improvement_sign}{improvement:.4f}")
quartiles = temporal.get("quartile_accuracies", [])
if quartiles:
print(
f" Quartile Accuracies: Q1={quartiles[0]:.4f}"
+ (f", Q2={quartiles[1]:.4f}" if len(quartiles) > 1 else "")
+ (f", Q3={quartiles[2]:.4f}" if len(quartiles) > 2 else "")
+ (f", Q4={quartiles[3]:.4f}" if len(quartiles) > 3 else "")
)
print(f"\n First Half Confidence: {temporal.get('first_half_confidence', 0):.4f}")
print(f" Second Half Confidence: {temporal.get('second_half_confidence', 0):.4f}")
conf_improvement = temporal.get("confidence_improvement", 0)
conf_sign = "+" if conf_improvement >= 0 else ""
print(f" Confidence Change: {conf_sign}{conf_improvement:.4f}")
def log_queue_stats(queue_stats, config=None):
if not _enabled(config):
return
if not queue_stats:
return
print("\n" + "-" * 60)
print(" QUEUE SURVIVAL STATISTICS")
print("-" * 60)
print(f" Total Evicted Cases: {queue_stats.get('total_evicted', 0)}")
print(f" Avg Survival: {queue_stats.get('avg_survival', 0):.2f} samples")
print(f" Max Survival: {queue_stats.get('max_survival', 0)} samples")
print(f" Min Survival: {queue_stats.get('min_survival', 0)} samples")
print(f" Avg Usage Count: {queue_stats.get('avg_usage', 0):.2f}")
print(f" Max Usage Count: {queue_stats.get('max_usage', 0)}")
print("=" * 60)
def log_save_statistics(stats_path, config=None):
if not _enabled(config):
return
print(f"[SAVE] Statistics saved to: {stats_path}")
def log_save_results(results_path, config=None):
if not _enabled(config):
return
print(f"[SAVE] Results saved to: {results_path}")
def log_save_config(config_path, config=None):
if not _enabled(config):
return
print(f"[SAVE] Config saved to: {config_path}")
def log_results_saved(log_path, config=None):
if not _enabled(config):
return
print(f"[MAIN] Results saved to: {log_path}")
def log_policy_experiment_header(policy_label, user_id, shuffle_seed, total_samples, queue_size, config=None, policy_note=None):
if not _enabled(config):
return
print(f"\n{'='*80}")
print(f"{policy_label} QUEUE POLICY EXPERIMENT")
print(f"User: {user_id} | Shuffle Seed: {shuffle_seed}")
print(f"Total samples: {total_samples} | Queue size: {queue_size}")
if policy_note:
print(policy_note)
print(f"{'='*80}\n")
def log_policy_queue_state(policy_label, processed_count, user_id, ex_queue, config=None):
if not _enabled(config):
return
print(f"\n{'#'*60}")
print(f"[Sample {processed_count}] User {user_id} | {policy_label} Policy")
print("Queue State BEFORE Processing:")
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
print(f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})")
print(f"{'#'*60}\n")
def log_policy_result(policy_label, processed_count, answer, ground_truth, is_correct, avg_confidence, consistency, cumulative_accuracy, cumulative_correct, window_accuracy, recent_results, config=None):
if not _enabled(config):
return
print(f"\n{'='*60}")
print(f"[RESULT] Sample {processed_count} | {policy_label} Policy")
print(f" Answer: {answer} | GT: {ground_truth} | {'✓ CORRECT' if is_correct else '✗ WRONG'}")
print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
print(f" Cumulative Accuracy: {cumulative_accuracy:.4f} ({cumulative_correct}/{processed_count + 1})")
print(f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}")
print(f"{'='*60}\n")
def log_confidence_map(responses, confidence_map, config=None):
if not _enabled(config):
return
print("\n[CONFIDENCE MAP] Per-agent confidence scores:")
for idx, conf in sorted(confidence_map.items()):
ans = responses[idx].get("ANSWER", "?")
print(f" Queue[{idx}]: answer={ans}, confidence={conf:.4f}")
def log_consistency_map(responses, consistency_map, config=None):
if not _enabled(config):
return
print("\n[CONSISTENCY MAP] Per-agent consistency scores:")
for idx, cons in sorted(consistency_map.items()):
ans = responses[idx].get("ANSWER", "?")
print(f" Queue[{idx}]: answer={ans}, consistency={cons:.4f}")
def log_policy_queue_state_after(policy_label, processed_count, ex_queue, config=None):
if not _enabled(config):
return
print(f"\n[Sample {processed_count}] Queue State AFTER {policy_label} Update:")
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
print(f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})")
def log_final_policy_survival(policy_label, user_id, shuffle_seed, survival_summary, config=None):
if not _enabled(config):
return
print(f"\n{'#'*60}")
print(f"[FINAL] Queue Survival Statistics - {policy_label} Policy")
print(f" User: {user_id} | Seed: {shuffle_seed}")
print(f" Total Evicted: {survival_summary['total_evicted']}")
print(f" Avg Survival: {survival_summary['avg_survival']:.2f} samples")
print(f" Avg Usage: {survival_summary['avg_usage']:.2f}")
print(f"{'#'*60}\n")
def log_policy_main_header(policy_label, user_id, shuffle_seed, queue_size, num_models, log_path, config=None, policy_note=None):
if not _enabled(config):
return
print("=" * 80)
print(f"{policy_label} QUEUE POLICY EXPERIMENT")
print("=" * 80)
print(f" User ID: {user_id}")
print(f" Shuffle Seed: {shuffle_seed}")
print(f" Queue Size: {queue_size}")
print(f" SC Samples (Agents): {num_models}")
print(f" Log Path: {log_path}")
if policy_note:
print(policy_note)
print("=" * 80)
def log_policy_loading_data(user_id, shuffle_seed, config=None):
if not _enabled(config):
return
print(f"\n[MAIN] Loading shuffled data for user {user_id}, seed {shuffle_seed}...")
def log_policy_loading_models(config=None):
if not _enabled(config):
return
print("\n[MAIN] Loading models...")
def log_policy_start(config=None, label="experiment"):
if not _enabled(config):
return
print(f"\n[MAIN] Starting {label}...")
def log_policy_complete_summary(policy_label, user_id, shuffle_seed, stats, stage_accuracy, stage_counts, temporal, config=None, expected_note=None):
if not _enabled(config):
return
print("\n" + "=" * 80)
print(f"EXPERIMENT COMPLETE - {policy_label} POLICY")
print("=" * 80)
print(f" User: {user_id} | Seed: {shuffle_seed}")
print(f" Total Samples: {stats.get('total_samples', 0)}")
print(f" Overall Accuracy: {stats.get('accuracy', 0):.4f}")
print(f" Macro F1: {stats.get('macro_f1', 0):.4f}")
print("\n Per-Stage Accuracy:")
for stage, acc in stage_accuracy.items():
count = stage_counts.get(stage, 0)
print(f" {stage}: {acc:.4f} (n={count})")
print("\n Temporal Analysis:")
print(f" First Half: {temporal.get('first_half_accuracy', 0):.4f}")
print(f" Second Half: {temporal.get('second_half_accuracy', 0):.4f}")
print(f" Improvement: {temporal.get('improvement', 0):+.4f}")
if expected_note:
print(expected_note)
print("=" * 80)
def log_queue_random_stats(user_id, shuffle_seed, sampler_stats, config=None):
if not _enabled(config):
return
print(f"\n{'#'*60}")
print("[FINAL] Queue Random Statistics")
print(f" User: {user_id} | Seed: {shuffle_seed}")
print(f" Total Steps: {sampler_stats['total_steps']}")
print(f" Total Refreshed: {sampler_stats['total_refreshed']} example sets")
print(f" Avg Refresh per Step: {sampler_stats['avg_refresh_per_step']} (always full)")
print(f"{'#'*60}\n")
def log_queue_random_sampler_init(example_count, classes, queue_size, config=None):
if not _enabled(config):
return
print(f"[QueueRandomSampler] Initialized with {example_count} examples")
print(f" Classes: {classes}")
print(f" Queue size: {queue_size}")
print(" Policy: ALL elements refreshed every step")
def log_queue_random_queue_state(processed_count, user_id, queue_sampler, config=None):
if not _enabled(config):
return
print(f"\n{'#'*60}")
print(f"[Sample {processed_count}] User {user_id} | QUEUE RANDOM Policy")
print("Queue State (ALL FRESH RANDOM samples):")
for idx, ex_idcs in enumerate(queue_sampler):
print(f" [{idx}] Example indices: {ex_idcs}")
print(f"{'#'*60}\n")
def log_queue_random_result(processed_count, answer, ground_truth, is_correct, avg_confidence, consistency, cumulative_accuracy, cumulative_correct, window_accuracy, recent_results, config=None):
if not _enabled(config):
return
print(f"\n{'='*60}")
print(f"[RESULT] Sample {processed_count} | QUEUE RANDOM Policy")
print(f" Answer: {answer} | GT: {ground_truth} | {'✓ CORRECT' if is_correct else '✗ WRONG'}")
print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
print(f" Cumulative Accuracy: {cumulative_accuracy:.4f} ({cumulative_correct}/{processed_count + 1})")
print(f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}")
print(" [NOTE] Queue will be FULLY REFRESHED for next sample")
print(f"{'='*60}\n")

View File

@@ -1,222 +0,0 @@
"""
HTTP API for HuggingFace causal LM inference (answer + logits).
This module exposes a small FastAPI service that:
- Loads a local HuggingFace CausalLM once at startup
- Accepts chat-style messages via HTTP
- Returns generated text
- Returns *logits-derived* information in a practical size:
- top-k logprobs for each generated token (recommended)
- optional prompt logits top-k for the final prompt position
Why not return full logits?
Full logits are extremely large (seq_len x vocab_size) and will quickly
overwhelm network/memory. This API defaults to returning top-k logprobs.
Run:
MODEL_DIR=/path/to/model \\
uvicorn sc.hf_api:app --host 0.0.0.0 --port 8000
Example:
curl -X POST http://localhost:8000/generate \\
-H 'Content-Type: application/json' \\
-d '{"messages":[{"role":"user","content":"Hello!"}],"max_new_tokens":64,"top_k":10}'
"""
from __future__ import annotations
import os
from typing import Any, Dict, List, Literal, Optional
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
DEFAULT_MODEL_DIR = os.environ.get(
"MODEL_DIR", "/mnt/sting/hjyoon/projects/llm/huggingface/gptoss20b"
)
app = FastAPI(title="HF LLM API", version="0.1.0")
class ChatMessage(BaseModel):
role: Literal["system", "user", "assistant"]
content: str
class GenerateRequest(BaseModel):
messages: List[ChatMessage]
max_new_tokens: int = Field(default=128, ge=1, le=1024)
temperature: float = Field(default=0.0, ge=0.0, le=2.0)
top_p: float = Field(default=1.0, ge=0.0, le=1.0)
do_sample: bool = False
top_k: int = Field(default=20, ge=1, le=200)
# If True, also returns prompt last-position top-k logits (not full matrix)
include_prompt_topk: bool = False
class TokenTopK(BaseModel):
token_id: int
token: str
logprob: float
class GeneratedStep(BaseModel):
token_id: int
token: str
logprob: float
topk: List[TokenTopK]
class PromptTopK(BaseModel):
position: int
topk: List[TokenTopK]
class GenerateResponse(BaseModel):
prompt: str
generated_text: str
generated_token_ids: List[int]
steps: List[GeneratedStep]
prompt_topk: Optional[PromptTopK] = None
def _get_device() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"
def _load_model_and_tokenizer(model_dir: str):
device = _get_device()
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_dir,
torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
device_map="auto" if device == "cuda" else None,
).eval()
return tokenizer, model, device
@app.on_event("startup")
def _startup_load():
# Load once; shared across requests.
global tokenizer, model, device
tokenizer, model, device = _load_model_and_tokenizer(DEFAULT_MODEL_DIR)
app.state.model_dir = DEFAULT_MODEL_DIR
app.state.device = device
@app.get("/health")
def health() -> Dict[str, Any]:
return {
"ok": True,
"model_dir": getattr(app.state, "model_dir", None),
"device": getattr(app.state, "device", None),
}
def _apply_chat_template(messages: List[ChatMessage]) -> str:
# Convert pydantic objects to plain dicts compatible with HF template.
msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
try:
prompt = tokenizer.apply_chat_template(
msg_dicts,
tokenize=False,
add_generation_prompt=True,
)
except Exception:
# Fallback: naive concatenation
prompt = "\n".join([f"{m['role']}: {m['content']}" for m in msg_dicts]) + "\nassistant:"
return prompt
def _topk_from_logits(logits_1d: torch.Tensor, top_k: int) -> List[TokenTopK]:
# logits_1d: (vocab,)
top_vals, top_ids = torch.topk(logits_1d, k=top_k)
# Convert to logprobs for interpretability
logprobs = torch.log_softmax(logits_1d, dim=-1)
out: List[TokenTopK] = []
for tid in top_ids.tolist():
tok = tokenizer.decode([tid])
out.append(
TokenTopK(
token_id=int(tid),
token=tok,
logprob=float(logprobs[tid].detach().cpu().item()),
)
)
# Sort in descending logprob (topk preserves order, but be explicit)
out.sort(key=lambda x: x.logprob, reverse=True)
return out
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest) -> GenerateResponse:
if not hasattr(app.state, "model_dir"):
raise HTTPException(status_code=503, detail="Model not loaded yet")
prompt = _apply_chat_template(req.messages)
inputs = tokenizer(prompt, return_tensors="pt")
if device == "cuda":
inputs = {k: v.to("cuda") for k, v in inputs.items()}
# Use generate() so we can get per-step scores (logits).
# output.scores is a list with length = generated_tokens
# each element shape: (batch, vocab)
with torch.no_grad():
out = model.generate(
**inputs,
max_new_tokens=req.max_new_tokens,
do_sample=req.do_sample if req.temperature > 0 else False,
temperature=req.temperature if req.temperature > 0 else 1.0,
top_p=req.top_p,
return_dict_in_generate=True,
output_scores=True,
)
# Generated token ids include prompt + new tokens
seq = out.sequences[0]
prompt_len = int(inputs["input_ids"].shape[1])
gen_token_ids = seq[prompt_len:].tolist()
generated_text = tokenizer.decode(gen_token_ids, skip_special_tokens=True)
# Build per-step top-k + chosen token logprob
steps: List[GeneratedStep] = []
if out.scores is None:
raise HTTPException(status_code=500, detail="Model did not return scores")
for step_idx, step_logits in enumerate(out.scores):
# step_logits: (1, vocab)
step_logits_1d = step_logits[0]
chosen_id = int(gen_token_ids[step_idx]) if step_idx < len(gen_token_ids) else None
logprobs_1d = torch.log_softmax(step_logits_1d, dim=-1)
chosen_logprob = float(logprobs_1d[chosen_id].detach().cpu().item()) if chosen_id is not None else float("nan")
steps.append(
GeneratedStep(
token_id=chosen_id,
token=tokenizer.decode([chosen_id]),
logprob=chosen_logprob,
topk=_topk_from_logits(step_logits_1d, req.top_k),
)
)
prompt_topk: Optional[PromptTopK] = None
if req.include_prompt_topk:
with torch.no_grad():
forward = model(**inputs)
# forward.logits: (1, seq_len, vocab)
last_pos = int(forward.logits.shape[1] - 1)
last_logits = forward.logits[0, -1, :]
prompt_topk = PromptTopK(position=last_pos, topk=_topk_from_logits(last_logits, req.top_k))
return GenerateResponse(
prompt=prompt,
generated_text=generated_text,
generated_token_ids=[int(t) for t in gen_token_ids],
steps=steps,
prompt_topk=prompt_topk,
)

View File

@@ -1,31 +0,0 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_dir = "/mnt/sting/hjyoon/projects/llm/huggingface/gptoss20b"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_dir,
torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
device_map="auto" if device == "cuda" else None,
).eval()
messages = [
{"role": "user", "content": "Explain quantum mechanics clearly and concisely."},
]
# Convert chat messages -> a single prompt string using the model's chat template
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
out = model(**inputs)
logits = out.logits # shape: (batch=1, seq_len, vocab_size)
print("logits shape:", logits.shape)

View File

@@ -1 +0,0 @@
# SC Preprocess Module

View File

@@ -1,409 +0,0 @@
"""
Data Shuffling Module for Sleep Stage Classification Experiment
This module provides functionality to shuffle user data for fair comparison
across different Queue policies (Confidence, Consistency, Random).
Features:
- Shuffle user data with fixed random seed for reproducibility
- Preserve original indices for tracking
- Save shuffled data to JSON for reuse
- Ensure all 3 experiments use identical shuffled order
Usage:
from sc.preprocess.shuffle_data import shuffle_user_data, load_shuffled_data
# Shuffle and save
shuffled_data = shuffle_user_data(user_id=5, seed=42, data_path="...")
# Load existing shuffled data
shuffled_data = load_shuffled_data(user_id=5, seed=42, output_dir="...")
"""
import os
import sys
import json
import random
import numpy as np
from typing import List, Dict, Any, Optional
from datetime import datetime
import datasets
from glob import glob
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
def shuffle_user_data(
user_id: int,
seed: int,
data_path: str,
output_dir: str = None,
save_to_file: bool = True,
) -> List[Dict[str, Any]]:
"""
Shuffle user data and optionally save to JSON file.
Args:
user_id: User ID to process (e.g., 5 or 10)
seed: Random seed for reproducibility
data_path: Path to SleepEDF data directory
output_dir: Directory to save shuffled data (default: data_path/shuffled)
save_to_file: Whether to save shuffled data to JSON
Returns:
List of shuffled samples with original_idx preserved
"""
# Set random seeds for reproducibility
random.seed(seed)
np.random.seed(seed)
# Format user_id with leading zeros (e.g., "05", "10")
user_str = f"{user_id:02d}" if isinstance(user_id, int) else str(user_id)
# Load user test data
test_path = os.path.join(data_path, user_str, "2")
if not os.path.exists(test_path):
raise FileNotFoundError(f"Test data not found: {test_path}")
test_dataset = datasets.load_from_disk(test_path)
# Create list with original indices
samples_with_idx = []
for idx, sample in enumerate(test_dataset):
sample_dict = {
"original_idx": idx,
"label": sample["label"],
"features": sample["features"],
}
samples_with_idx.append(sample_dict)
# Shuffle the samples
shuffled_samples = samples_with_idx.copy()
random.shuffle(shuffled_samples)
# Add shuffled index
for shuffled_idx, sample in enumerate(shuffled_samples):
sample["shuffled_idx"] = shuffled_idx
# Save to file if requested
if save_to_file:
if output_dir is None:
output_dir = os.path.join(data_path, "shuffled")
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f"user{user_str}_seed{seed}.json")
metadata = {
"user_id": user_str,
"seed": seed,
"total_samples": len(shuffled_samples),
"created_at": datetime.now().isoformat(),
"data_path": data_path,
}
# Count samples per stage
stage_counts = {}
for sample in shuffled_samples:
label = sample["label"]
stage_counts[label] = stage_counts.get(label, 0) + 1
metadata["stage_distribution"] = stage_counts
output_data = {
"metadata": metadata,
"samples": shuffled_samples,
}
with open(output_path, "w", encoding="utf-8") as f:
json.dump(output_data, f, indent=2, ensure_ascii=False)
print(f"[SHUFFLE] Saved shuffled data to: {output_path}")
print(f" User: {user_str}, Seed: {seed}")
print(f" Total samples: {len(shuffled_samples)}")
print(f" Stage distribution: {stage_counts}")
return shuffled_samples
def load_shuffled_data(
user_id: int,
seed: int,
output_dir: str = None,
data_path: str = None,
) -> List[Dict[str, Any]]:
"""
Load existing shuffled data from JSON file.
If file doesn't exist, create it by shuffling the data.
Args:
user_id: User ID to load
seed: Random seed used for shuffling
output_dir: Directory containing shuffled data files
data_path: Path to original data (used if shuffled file doesn't exist)
Returns:
List of shuffled samples
"""
user_str = f"{user_id:02d}" if isinstance(user_id, int) else str(user_id)
if output_dir is None and data_path is not None:
output_dir = os.path.join(data_path, "shuffled")
if output_dir is None:
raise ValueError("Either output_dir or data_path must be provided")
file_path = os.path.join(output_dir, f"user{user_str}_seed{seed}.json")
if os.path.exists(file_path):
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
print(f"[SHUFFLE] Loaded existing shuffled data: {file_path}")
print(f" Total samples: {data['metadata']['total_samples']}")
return data["samples"]
else:
if data_path is None:
raise FileNotFoundError(f"Shuffled data not found: {file_path}")
print(f"[SHUFFLE] Shuffled data not found, creating new...")
return shuffle_user_data(
user_id=user_id,
seed=seed,
data_path=data_path,
output_dir=output_dir,
save_to_file=True,
)
def get_shuffled_file_path(
user_id: int,
seed: int,
data_path: str,
) -> str:
"""
Get the path to the shuffled data file.
Args:
user_id: User ID
seed: Random seed
data_path: Base data path
Returns:
Path to the shuffled data JSON file
"""
user_str = f"{user_id:02d}" if isinstance(user_id, int) else str(user_id)
output_dir = os.path.join(data_path, "shuffled")
return os.path.join(output_dir, f"user{user_str}_seed{seed}.json")
def verify_shuffle_consistency(
user_id: int,
seeds: List[int],
data_path: str,
) -> bool:
"""
Verify that shuffled data files exist and are consistent.
Args:
user_id: User ID to verify
seeds: List of seeds to verify
data_path: Base data path
Returns:
True if all files exist and have same sample count
"""
user_str = f"{user_id:02d}" if isinstance(user_id, int) else str(user_id)
output_dir = os.path.join(data_path, "shuffled")
sample_counts = []
for seed in seeds:
file_path = os.path.join(output_dir, f"user{user_str}_seed{seed}.json")
if not os.path.exists(file_path):
print(f"[VERIFY] Missing: {file_path}")
return False
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
sample_counts.append(data["metadata"]["total_samples"])
# All should have same count
if len(set(sample_counts)) != 1:
print(f"[VERIFY] Inconsistent sample counts: {sample_counts}")
return False
print(f"[VERIFY] User {user_str}: All {len(seeds)} seed files verified ({sample_counts[0]} samples each)")
return True
class ShuffledDataLoader:
"""
DataLoader that uses pre-shuffled data for consistent experiment comparison.
"""
def __init__(
self,
data_path: str,
user_id: int,
seed: int,
example_pool: str = "out",
):
"""
Initialize shuffled data loader.
Args:
data_path: Path to SleepEDF data
user_id: User ID (e.g., 5 or 10)
seed: Shuffle seed
example_pool: "out" (different users) or "in" (same user)
"""
self.data_path = data_path
self.user_id = user_id
self.seed = seed
self.example_pool = example_pool
user_str = f"{user_id:02d}" if isinstance(user_id, int) else str(user_id)
self.user_str = user_str
# Load metadata
info_path = os.path.join(data_path, "info.json")
if not os.path.exists(info_path):
raise FileNotFoundError(f"Info file not found: {info_path}")
with open(info_path, "r", encoding="utf-8") as f:
self.metadata = json.load(f)
# Load shuffled test data
self.shuffled_samples = load_shuffled_data(
user_id=user_id,
seed=seed,
data_path=data_path,
)
# Load example dataset (same as original DataLoader)
self.example_dataset = datasets.Dataset.from_list([])
users = glob(os.path.join(data_path, "*"))
users = [os.path.basename(p) for p in users if os.path.isdir(p)]
users = [u for u in users if u not in ["info.json", "shuffled"]]
for user in users:
if example_pool == "out" and user == user_str:
continue
if example_pool == "in" and user != user_str:
continue
example_path = os.path.join(data_path, user, "1")
if os.path.exists(example_path):
user_dataset = datasets.load_from_disk(example_path)
self.example_dataset = datasets.concatenate_datasets([self.example_dataset, user_dataset])
# Shuffle example dataset with fixed seed
self.example_dataset = self.example_dataset.shuffle(seed=0)
print(f"[ShuffledDataLoader] User: {user_str}, Seed: {seed}")
print(f" Test samples: {len(self.shuffled_samples)}")
print(f" Example samples: {len(self.example_dataset)}")
def __len__(self):
return len(self.shuffled_samples)
def __getitem__(self, idx):
return self.shuffled_samples[idx]
def __iter__(self):
for sample in self.shuffled_samples:
yield sample
def get_examples(self):
return self.example_dataset
def get_metadata(self):
return self.metadata
def get_sensor_info(self):
return self.metadata["feature"]
def get_task_info(self):
task_info = f"**Task**:\n{self.metadata['task']}\n\n"
classes_info = [f" - {k}: {v}" for k, v in self.metadata["class"].items()]
classes_info = "\n".join(classes_info)
task_info += f"**Classes**:\n{classes_info}"
return task_info
def get_classes_info(self):
return list(self.metadata["class"].keys())
def prepare_all_shuffled_data(
data_path: str,
users: List[int] = [5, 10],
seeds: List[int] = [42, 123, 456],
) -> None:
"""
Prepare all shuffled data files for the experiment.
Args:
data_path: Path to SleepEDF data
users: List of user IDs
seeds: List of shuffle seeds
"""
print("=" * 60)
print("Preparing Shuffled Data for Experiment")
print("=" * 60)
for user_id in users:
for seed in seeds:
print(f"\nProcessing User {user_id}, Seed {seed}...")
shuffle_user_data(
user_id=user_id,
seed=seed,
data_path=data_path,
save_to_file=True,
)
print("\n" + "=" * 60)
print("Verification")
print("=" * 60)
for user_id in users:
verify_shuffle_consistency(user_id, seeds, data_path)
print("\n[DONE] All shuffled data prepared.")
if __name__ == "__main__":
import fire
def main(
data_path: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF",
users = "5,10",
seeds = "42,123,456",
):
"""
Prepare shuffled data for experiments.
Args:
data_path: Path to SleepEDF data
users: Comma-separated user IDs or list/tuple
seeds: Comma-separated shuffle seeds or list/tuple
"""
# Handle both string and tuple/list inputs
if isinstance(users, str):
user_list = [int(u.strip()) for u in users.split(",")]
elif isinstance(users, (list, tuple)):
user_list = [int(u) for u in users]
else:
user_list = [int(users)]
if isinstance(seeds, str):
seed_list = [int(s.strip()) for s in seeds.split(",")]
elif isinstance(seeds, (list, tuple)):
seed_list = [int(s) for s in seeds]
else:
seed_list = [int(seeds)]
prepare_all_shuffled_data(
data_path=data_path,
users=user_list,
seeds=seed_list,
)
fire.Fire(main)

View File

@@ -1,463 +0,0 @@
"""
Confidence-based Queue Policy Experiment Runner
This module implements the CONFIDENCE-based queue update policy:
- Queue is updated based on model confidence scores
- Higher confidence examples are retained in the queue
- Lower confidence examples are evicted
Usage:
python -m sc.run_confidence --user_id=5 --shuffle_seed=42
python -m sc.run_confidence --user_id=10 --shuffle_seed=123
"""
import os
import sys
import asyncio
import yaml
import json
import random
import numpy as np
from datetime import datetime
from typing import List, Dict, Any, Optional
from sklearn.metrics import f1_score, precision_score, recall_score
from fire import Fire
# Add project root to path for relative imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
from sc.core.scagent import SCAgent
from sc.core.model import load_models
from sc.core.queue import Queue
from sc.core.agent_pool import AgentPool
from sc import debug_log
log = debug_log.log
async def run_confidence_experiment(
dataloader: ShuffledDataLoader,
model_pool,
config: Dict[str, Any],
user_id: int,
shuffle_seed: int,
) -> List[Dict[str, Any]]:
"""
Run confidence-based queue policy experiment.
Queue Update Policy:
- After each inference, rank queue elements by confidence score
- Evict the lowest confidence element
- Add a new random ICL example set
Args:
dataloader: ShuffledDataLoader with pre-shuffled test data
model_pool: Async model pool for LLM inference
config: Experiment configuration
user_id: User ID being processed
shuffle_seed: Shuffle seed for reproducibility
Returns:
List of result dictionaries
"""
# Set seeds for reproducibility within experiment
random.seed(shuffle_seed)
np.random.seed(shuffle_seed)
example_dataset = dataloader.get_examples()
if len(example_dataset) == 0:
log(f"[ERROR] No examples found for user {user_id}")
return []
# Build class_indices for Queue initialization
class_indices = {}
for idx, example in enumerate(example_dataset):
label = example["label"]
if label not in class_indices:
class_indices[label] = []
class_indices[label].append(idx)
# Initialize Queue
queue_size = config.get("queue_size", 5)
ex_queue = Queue(class_indices, queue_size)
# Tracking variables
results = []
cumulative_correct = 0
window_size = config.get("tracking_window", 20)
recent_results = []
confidence_history = []
consistency_history = []
all_predictions = []
all_ground_truths = []
debug_log.log_policy_experiment_header(
"CONFIDENCE-BASED",
user_id,
shuffle_seed,
len(dataloader),
queue_size,
)
for processed_count, sample in enumerate(dataloader):
ex_queue.set_current_time(processed_count)
# Log queue state before processing
debug_log.log_policy_queue_state("CONFIDENCE", processed_count, user_id, ex_queue)
# Create agent pool
agent_pool = AgentPool(log_path=config["log_path"])
try:
for queue_idx, ex_idcs in enumerate(ex_queue):
examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]
agent = SCAgent(
name="EEG sensing",
index=queue_idx,
model_pool=model_pool,
task_info=dataloader.get_task_info(),
classes_info=dataloader.get_classes_info(),
sensor_info=dataloader.get_sensor_info(),
sample=sample,
examples=examples,
log_path=config["log_path"],
)
agent_pool.add_agent(agent)
except Exception as e:
log(f"[ERROR] Failed to create agents: {e}")
import traceback
traceback.print_exc()
continue
if len(agent_pool.agents) == 0:
log(f"[WARN] No agents created for sample {processed_count}")
continue
# Run parallel interpretation
try:
interpretation_result = await agent_pool.run_parallel_interpretation()
except Exception as e:
log(f"[ERROR] Interpretation failed: {e}")
import traceback
traceback.print_exc()
continue
if interpretation_result is None:
log(f"[WARN] Interpretation failed for sample {processed_count}")
continue
answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
# Evaluate result
ground_truth = sample["label"]
is_correct = answer == ground_truth
cumulative_correct += 1 if is_correct else 0
cumulative_accuracy = cumulative_correct / (processed_count + 1)
recent_results.append(1 if is_correct else 0)
if len(recent_results) > window_size:
recent_results.pop(0)
window_accuracy = sum(recent_results) / len(recent_results)
confidence_history.append(avg_confidence)
consistency_history.append(consistency)
all_predictions.append(answer)
all_ground_truths.append(ground_truth)
# Performance logging
debug_log.log_policy_result(
"CONFIDENCE",
processed_count,
answer,
ground_truth,
is_correct,
avg_confidence,
consistency,
cumulative_accuracy,
cumulative_correct,
window_accuracy,
recent_results,
)
# CONFIDENCE-BASED Queue Update
if responses:
# Build confidence map from agent responses
confidence_map = {}
for idx, response in responses.items():
confidence_map[idx] = response.get("CONFIDENCE", 0.0)
debug_log.log_confidence_map(responses, confidence_map)
# Update queue by confidence (evict lowest, add new random)
ex_queue.update_by_confidence(confidence_map)
ex_queue.increment_usage(list(responses.keys()))
debug_log.log_policy_queue_state_after("Confidence", processed_count, ex_queue)
# Store result
result = {
"sample_idx": processed_count,
"original_idx": sample.get("original_idx", processed_count),
"shuffled_idx": sample.get("shuffled_idx", processed_count),
"answer": answer,
"ground_truth": ground_truth,
"is_correct": is_correct,
"confidence": avg_confidence,
"consistency": consistency,
"cumulative_accuracy": cumulative_accuracy,
"window_accuracy": window_accuracy,
"experiment_type": "confidence",
"user_id": user_id,
"shuffle_seed": shuffle_seed,
}
results.append(result)
# Final statistics
survival_summary = ex_queue.get_survival_summary()
debug_log.log_final_policy_survival("CONFIDENCE", user_id, shuffle_seed, survival_summary)
if results:
results[-1]["queue_survival_stats"] = survival_summary
results[-1]["queue_survival_details"] = ex_queue.get_survival_stats()
return results
def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict[str, Any]:
"""Compute comprehensive experiment statistics."""
if not results:
return {}
# Overall metrics
correct = sum(1 for r in results if r.get("is_correct", False))
total = len(results)
accuracy = correct / total if total > 0 else 0
# Confidence and consistency averages
confidences = [r.get("confidence", 0) for r in results]
consistencies = [r.get("consistency", 0) for r in results]
avg_confidence = np.mean(confidences) if confidences else 0
avg_consistency = np.mean(consistencies) if consistencies else 0
# Per-stage accuracy
stage_correct = {}
stage_total = {}
for r in results:
gt = r.get("ground_truth", "UNKNOWN")
stage_total[gt] = stage_total.get(gt, 0) + 1
if r.get("is_correct", False):
stage_correct[gt] = stage_correct.get(gt, 0) + 1
stage_accuracy = {}
for stage in stages:
if stage in stage_total:
stage_accuracy[stage] = stage_correct.get(stage, 0) / stage_total[stage]
else:
stage_accuracy[stage] = 0.0
# F1 Score and Macro metrics
predictions = [r.get("answer", "") for r in results]
ground_truths = [r.get("ground_truth", "") for r in results]
try:
macro_f1 = f1_score(ground_truths, predictions, average='macro', zero_division=0)
macro_precision = precision_score(ground_truths, predictions, average='macro', zero_division=0)
macro_recall = recall_score(ground_truths, predictions, average='macro', zero_division=0)
except Exception:
macro_f1 = macro_precision = macro_recall = 0.0
# Temporal analysis
mid_point = len(results) // 2
if mid_point > 0:
first_half = results[:mid_point]
second_half = results[mid_point:]
first_half_acc = sum(1 for r in first_half if r.get("is_correct", False)) / len(first_half)
second_half_acc = sum(1 for r in second_half if r.get("is_correct", False)) / len(second_half)
improvement = second_half_acc - first_half_acc
else:
first_half_acc = second_half_acc = accuracy
improvement = 0
# Learning curve (every 10 samples)
learning_curve = []
for i in range(0, len(results), 10):
chunk = results[:i+10]
chunk_acc = sum(1 for r in chunk if r.get("is_correct", False)) / len(chunk)
learning_curve.append({
"sample_idx": min(i+10, len(results)),
"cumulative_accuracy": chunk_acc,
})
# Convergence speed (90% of final accuracy)
final_accuracy = accuracy
convergence_threshold = 0.9 * final_accuracy
convergence_idx = None
running_correct = 0
for i, r in enumerate(results):
running_correct += 1 if r.get("is_correct", False) else 0
running_acc = running_correct / (i + 1)
if running_acc >= convergence_threshold:
convergence_idx = i
break
# Queue statistics
queue_stats = {}
for r in results:
if "queue_survival_stats" in r:
queue_stats = r["queue_survival_stats"]
break
return {
"experiment_type": "confidence",
"user_id": results[0].get("user_id") if results else None,
"shuffle_seed": results[0].get("shuffle_seed") if results else None,
"total_samples": total,
"correct": correct,
"accuracy": accuracy,
"avg_confidence": float(avg_confidence),
"avg_consistency": float(avg_consistency),
"stage_accuracy": stage_accuracy,
"stage_sample_counts": stage_total,
"macro_f1": float(macro_f1),
"macro_precision": float(macro_precision),
"macro_recall": float(macro_recall),
"temporal_analysis": {
"first_half_accuracy": first_half_acc,
"second_half_accuracy": second_half_acc,
"improvement": improvement,
},
"learning_curve": learning_curve,
"convergence_idx": convergence_idx,
"queue_stats": queue_stats,
}
def save_results(
results: List[Dict[str, Any]],
stats: Dict[str, Any],
config: Dict[str, Any],
user_id: int,
shuffle_seed: int,
) -> None:
"""Save experiment results and statistics."""
log_path = config["log_path"]
# Create user/seed specific directory
output_dir = os.path.join(log_path, f"user{user_id:02d}_seed{shuffle_seed}")
os.makedirs(output_dir, exist_ok=True)
# Save statistics
stats_path = os.path.join(output_dir, "statistics.json")
with open(stats_path, "w", encoding="utf-8") as f:
json.dump(stats, f, indent=2, ensure_ascii=False)
log(f"[SAVE] Statistics: {stats_path}")
# Save results
results_path = os.path.join(output_dir, "results.json")
with open(results_path, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
log(f"[SAVE] Results: {results_path}")
# Save config
config_path = os.path.join(output_dir, "config.yaml")
with open(config_path, "w", encoding="utf-8") as f:
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
log(f"[SAVE] Config: {config_path}")
def main(
config_path: str = "sc/config/experiment_confidence.yaml",
user_id: int = 5,
shuffle_seed: int = 42,
) -> None:
"""
Run Confidence-based Queue Policy experiment.
Args:
config_path: Path to YAML configuration file
user_id: User ID to process (5 or 10)
shuffle_seed: Shuffle seed for data order (42, 123, or 456)
Example:
python -m sc.run_confidence --user_id=5 --shuffle_seed=42
python -m sc.run_confidence --user_id=10 --shuffle_seed=123
"""
log(f"[MAIN] Loading config: {config_path}")
with open(config_path, "r", encoding="utf-8") as f:
config = yaml.load(f, Loader=yaml.SafeLoader)
# Override with CLI arguments
config["user_id"] = user_id
config["shuffle_seed"] = shuffle_seed
# Create unique log path
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
config["log_path"] = f"{config['log_path']}/user{user_id:02d}_seed{shuffle_seed}_{timestamp}"
os.makedirs(config["log_path"], exist_ok=True)
# Print experiment info
debug_log.log_policy_main_header(
"CONFIDENCE-BASED",
user_id,
shuffle_seed,
config.get("queue_size", 5),
len(config.get("models", [])),
config["log_path"],
)
# Load shuffled data
debug_log.log_policy_loading_data(user_id, shuffle_seed)
dataloader = ShuffledDataLoader(
data_path=config["data_path"],
user_id=user_id,
seed=shuffle_seed,
example_pool=config.get("example_pool", "out"),
)
# Load models
debug_log.log_policy_loading_models()
model_pool = load_models(
config["models"],
temperature=config.get("temperature", 0.0),
num_ctx=config.get("num_ctx", 15000),
)
# Run experiment
debug_log.log_policy_start(label="experiment")
results = asyncio.run(run_confidence_experiment(
dataloader=dataloader,
model_pool=model_pool,
config=config,
user_id=user_id,
shuffle_seed=shuffle_seed,
))
# Compute statistics
stages = config.get("stages", ["W", "N1", "N2", "N3", "REM"])
stats = compute_statistics(results, stages)
# Print final summary
temporal = stats.get("temporal_analysis", {})
debug_log.log_policy_complete_summary(
"CONFIDENCE",
user_id,
shuffle_seed,
stats,
stats.get("stage_accuracy", {}),
stats.get("stage_sample_counts", {}),
temporal,
)
# Save results
save_results(results, stats, config, user_id, shuffle_seed)
log(f"\n[MAIN] Results saved to: {config['log_path']}")
if __name__ == "__main__":
Fire(main)

View File

@@ -1,473 +0,0 @@
"""
Consistency-based Queue Policy Experiment Runner
This module implements the CONSISTENCY-based queue update policy:
- Queue is updated based on SC consensus/agreement scores
- Higher consistency (more agents agree) examples are retained
- Lower consistency examples are evicted
Consistency Score = (Number of agents agreeing with majority) / (Total agents)
Usage:
python -m sc.run_consistency --user_id=5 --shuffle_seed=42
python -m sc.run_consistency --user_id=10 --shuffle_seed=123
"""
import os
import sys
import asyncio
import yaml
import json
import random
import numpy as np
from datetime import datetime
from typing import List, Dict, Any, Optional
from sklearn.metrics import f1_score, precision_score, recall_score
from fire import Fire
# Add project root to path for relative imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
from sc.core.scagent import SCAgent
from sc.core.model import load_models
from sc.core.queue import Queue
from sc.core.agent_pool import AgentPool
from sc import debug_log
log = debug_log.log
async def run_consistency_experiment(
dataloader: ShuffledDataLoader,
model_pool,
config: Dict[str, Any],
user_id: int,
shuffle_seed: int,
) -> List[Dict[str, Any]]:
"""
Run consistency-based queue policy experiment.
Queue Update Policy:
- After each inference, calculate per-agent consistency
- Consistency = how many other agents agree with this agent's answer
- Rank queue elements by consistency score
- Evict the lowest consistency element
- Add a new random ICL example set
Args:
dataloader: ShuffledDataLoader with pre-shuffled test data
model_pool: Async model pool for LLM inference
config: Experiment configuration
user_id: User ID being processed
shuffle_seed: Shuffle seed for reproducibility
Returns:
List of result dictionaries
"""
# Set seeds for reproducibility within experiment
random.seed(shuffle_seed)
np.random.seed(shuffle_seed)
example_dataset = dataloader.get_examples()
if len(example_dataset) == 0:
log(f"[ERROR] No examples found for user {user_id}")
return []
# Build class_indices for Queue initialization
class_indices = {}
for idx, example in enumerate(example_dataset):
label = example["label"]
if label not in class_indices:
class_indices[label] = []
class_indices[label].append(idx)
# Initialize Queue
queue_size = config.get("queue_size", 5)
ex_queue = Queue(class_indices, queue_size)
# Tracking variables
results = []
cumulative_correct = 0
window_size = config.get("tracking_window", 20)
recent_results = []
confidence_history = []
consistency_history = []
all_predictions = []
all_ground_truths = []
debug_log.log_policy_experiment_header(
"CONSISTENCY-BASED",
user_id,
shuffle_seed,
len(dataloader),
queue_size,
)
for processed_count, sample in enumerate(dataloader):
ex_queue.set_current_time(processed_count)
# Log queue state before processing
debug_log.log_policy_queue_state("CONSISTENCY", processed_count, user_id, ex_queue)
# Create agent pool
agent_pool = AgentPool(log_path=config["log_path"])
try:
for queue_idx, ex_idcs in enumerate(ex_queue):
examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]
agent = SCAgent(
name="EEG sensing",
index=queue_idx,
model_pool=model_pool,
task_info=dataloader.get_task_info(),
classes_info=dataloader.get_classes_info(),
sensor_info=dataloader.get_sensor_info(),
sample=sample,
examples=examples,
log_path=config["log_path"],
)
agent_pool.add_agent(agent)
except Exception as e:
log(f"[ERROR] Failed to create agents: {e}")
import traceback
traceback.print_exc()
continue
if len(agent_pool.agents) == 0:
log(f"[WARN] No agents created for sample {processed_count}")
continue
# Run parallel interpretation
try:
interpretation_result = await agent_pool.run_parallel_interpretation()
except Exception as e:
log(f"[ERROR] Interpretation failed: {e}")
import traceback
traceback.print_exc()
continue
if interpretation_result is None:
log(f"[WARN] Interpretation failed for sample {processed_count}")
continue
answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
# Evaluate result
ground_truth = sample["label"]
is_correct = answer == ground_truth
cumulative_correct += 1 if is_correct else 0
cumulative_accuracy = cumulative_correct / (processed_count + 1)
recent_results.append(1 if is_correct else 0)
if len(recent_results) > window_size:
recent_results.pop(0)
window_accuracy = sum(recent_results) / len(recent_results)
confidence_history.append(avg_confidence)
consistency_history.append(consistency)
all_predictions.append(answer)
all_ground_truths.append(ground_truth)
# Performance logging
debug_log.log_policy_result(
"CONSISTENCY",
processed_count,
answer,
ground_truth,
is_correct,
avg_confidence,
consistency,
cumulative_accuracy,
cumulative_correct,
window_accuracy,
recent_results,
)
# CONSISTENCY-BASED Queue Update
if responses:
# Calculate per-agent consistency: how many other agents agree
all_answers = [r.get("ANSWER") for r in responses.values()]
consistency_map = {}
for idx, response in responses.items():
agent_answer = response.get("ANSWER")
# Consistency = ratio of agents (including self) that agree
agreement_count = all_answers.count(agent_answer)
agent_consistency = agreement_count / len(all_answers) if all_answers else 0
consistency_map[idx] = agent_consistency
debug_log.log_consistency_map(responses, consistency_map)
# Update queue by consistency (reuses confidence method with consistency scores)
ex_queue.update_by_confidence(consistency_map)
ex_queue.increment_usage(list(responses.keys()))
debug_log.log_policy_queue_state_after("Consistency", processed_count, ex_queue)
# Store result
result = {
"sample_idx": processed_count,
"original_idx": sample.get("original_idx", processed_count),
"shuffled_idx": sample.get("shuffled_idx", processed_count),
"answer": answer,
"ground_truth": ground_truth,
"is_correct": is_correct,
"confidence": avg_confidence,
"consistency": consistency,
"cumulative_accuracy": cumulative_accuracy,
"window_accuracy": window_accuracy,
"experiment_type": "consistency",
"user_id": user_id,
"shuffle_seed": shuffle_seed,
}
results.append(result)
# Final statistics
survival_summary = ex_queue.get_survival_summary()
debug_log.log_final_policy_survival("CONSISTENCY", user_id, shuffle_seed, survival_summary)
if results:
results[-1]["queue_survival_stats"] = survival_summary
results[-1]["queue_survival_details"] = ex_queue.get_survival_stats()
return results
def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict[str, Any]:
"""Compute comprehensive experiment statistics."""
if not results:
return {}
# Overall metrics
correct = sum(1 for r in results if r.get("is_correct", False))
total = len(results)
accuracy = correct / total if total > 0 else 0
# Confidence and consistency averages
confidences = [r.get("confidence", 0) for r in results]
consistencies = [r.get("consistency", 0) for r in results]
avg_confidence = np.mean(confidences) if confidences else 0
avg_consistency = np.mean(consistencies) if consistencies else 0
# Per-stage accuracy
stage_correct = {}
stage_total = {}
for r in results:
gt = r.get("ground_truth", "UNKNOWN")
stage_total[gt] = stage_total.get(gt, 0) + 1
if r.get("is_correct", False):
stage_correct[gt] = stage_correct.get(gt, 0) + 1
stage_accuracy = {}
for stage in stages:
if stage in stage_total:
stage_accuracy[stage] = stage_correct.get(stage, 0) / stage_total[stage]
else:
stage_accuracy[stage] = 0.0
# F1 Score and Macro metrics
predictions = [r.get("answer", "") for r in results]
ground_truths = [r.get("ground_truth", "") for r in results]
try:
macro_f1 = f1_score(ground_truths, predictions, average='macro', zero_division=0)
macro_precision = precision_score(ground_truths, predictions, average='macro', zero_division=0)
macro_recall = recall_score(ground_truths, predictions, average='macro', zero_division=0)
except Exception:
macro_f1 = macro_precision = macro_recall = 0.0
# Temporal analysis
mid_point = len(results) // 2
if mid_point > 0:
first_half = results[:mid_point]
second_half = results[mid_point:]
first_half_acc = sum(1 for r in first_half if r.get("is_correct", False)) / len(first_half)
second_half_acc = sum(1 for r in second_half if r.get("is_correct", False)) / len(second_half)
improvement = second_half_acc - first_half_acc
else:
first_half_acc = second_half_acc = accuracy
improvement = 0
# Learning curve (every 10 samples)
learning_curve = []
for i in range(0, len(results), 10):
chunk = results[:i+10]
chunk_acc = sum(1 for r in chunk if r.get("is_correct", False)) / len(chunk)
learning_curve.append({
"sample_idx": min(i+10, len(results)),
"cumulative_accuracy": chunk_acc,
})
# Convergence speed (90% of final accuracy)
final_accuracy = accuracy
convergence_threshold = 0.9 * final_accuracy
convergence_idx = None
running_correct = 0
for i, r in enumerate(results):
running_correct += 1 if r.get("is_correct", False) else 0
running_acc = running_correct / (i + 1)
if running_acc >= convergence_threshold:
convergence_idx = i
break
# Queue statistics
queue_stats = {}
for r in results:
if "queue_survival_stats" in r:
queue_stats = r["queue_survival_stats"]
break
return {
"experiment_type": "consistency",
"user_id": results[0].get("user_id") if results else None,
"shuffle_seed": results[0].get("shuffle_seed") if results else None,
"total_samples": total,
"correct": correct,
"accuracy": accuracy,
"avg_confidence": float(avg_confidence),
"avg_consistency": float(avg_consistency),
"stage_accuracy": stage_accuracy,
"stage_sample_counts": stage_total,
"macro_f1": float(macro_f1),
"macro_precision": float(macro_precision),
"macro_recall": float(macro_recall),
"temporal_analysis": {
"first_half_accuracy": first_half_acc,
"second_half_accuracy": second_half_acc,
"improvement": improvement,
},
"learning_curve": learning_curve,
"convergence_idx": convergence_idx,
"queue_stats": queue_stats,
}
def save_results(
results: List[Dict[str, Any]],
stats: Dict[str, Any],
config: Dict[str, Any],
user_id: int,
shuffle_seed: int,
) -> None:
"""Save experiment results and statistics."""
log_path = config["log_path"]
# Create user/seed specific directory
output_dir = os.path.join(log_path, f"user{user_id:02d}_seed{shuffle_seed}")
os.makedirs(output_dir, exist_ok=True)
# Save statistics
stats_path = os.path.join(output_dir, "statistics.json")
with open(stats_path, "w", encoding="utf-8") as f:
json.dump(stats, f, indent=2, ensure_ascii=False)
log(f"[SAVE] Statistics: {stats_path}")
# Save results
results_path = os.path.join(output_dir, "results.json")
with open(results_path, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
log(f"[SAVE] Results: {results_path}")
# Save config
config_path = os.path.join(output_dir, "config.yaml")
with open(config_path, "w", encoding="utf-8") as f:
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
log(f"[SAVE] Config: {config_path}")
def main(
config_path: str = "sc/config/experiment_consistency.yaml",
user_id: int = 5,
shuffle_seed: int = 42,
) -> None:
"""
Run Consistency-based Queue Policy experiment.
Args:
config_path: Path to YAML configuration file
user_id: User ID to process (5 or 10)
shuffle_seed: Shuffle seed for data order (42, 123, or 456)
Example:
python -m sc.run_consistency --user_id=5 --shuffle_seed=42
python -m sc.run_consistency --user_id=10 --shuffle_seed=123
"""
log(f"[MAIN] Loading config: {config_path}")
with open(config_path, "r", encoding="utf-8") as f:
config = yaml.load(f, Loader=yaml.SafeLoader)
# Override with CLI arguments
config["user_id"] = user_id
config["shuffle_seed"] = shuffle_seed
# Create unique log path
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
config["log_path"] = f"{config['log_path']}/user{user_id:02d}_seed{shuffle_seed}_{timestamp}"
os.makedirs(config["log_path"], exist_ok=True)
# Print experiment info
debug_log.log_policy_main_header(
"CONSISTENCY-BASED",
user_id,
shuffle_seed,
config.get("queue_size", 5),
len(config.get("models", [])),
config["log_path"],
)
# Load shuffled data
debug_log.log_policy_loading_data(user_id, shuffle_seed)
dataloader = ShuffledDataLoader(
data_path=config["data_path"],
user_id=user_id,
seed=shuffle_seed,
example_pool=config.get("example_pool", "out"),
)
# Load models
debug_log.log_policy_loading_models()
model_pool = load_models(
config["models"],
temperature=config.get("temperature", 0.0),
num_ctx=config.get("num_ctx", 15000),
)
# Run experiment
debug_log.log_policy_start(label="experiment")
results = asyncio.run(run_consistency_experiment(
dataloader=dataloader,
model_pool=model_pool,
config=config,
user_id=user_id,
shuffle_seed=shuffle_seed,
))
# Compute statistics
stages = config.get("stages", ["W", "N1", "N2", "N3", "REM"])
stats = compute_statistics(results, stages)
# Print final summary
temporal = stats.get("temporal_analysis", {})
debug_log.log_policy_complete_summary(
"CONSISTENCY",
user_id,
shuffle_seed,
stats,
stats.get("stage_accuracy", {}),
stats.get("stage_sample_counts", {}),
temporal,
)
# Save results
save_results(results, stats, config, user_id, shuffle_seed)
log(f"\n[MAIN] Results saved to: {config['log_path']}")
if __name__ == "__main__":
Fire(main)

View File

@@ -1,467 +0,0 @@
"""
Self-Consistency Experiment Runner for Sleep Stage Classification
This module implements Self-Consistency methodology for sleep stage classification:
- Sample N times with the same prompt
- Use majority voting for final answer
- Support confidence-based tie-breaking
Usage:
python -m sc.run_sc sc/config/sleepedf_sc.yaml
"""
import os
import sys
import asyncio
import yaml
import json
import numpy as np
from datetime import datetime
from typing import List, Dict, Any
import time
from glob import glob
from fire import Fire
# Add project root to path for relative imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.core.data_loader import DataLoader
from sc.core.scagent import SCAgent
from sc.core.model import load_models
from sc.core.queue import Queue
from sc.core.agent_pool import AgentPool
from sc import debug_log
async def run_single_task(
seed: int,
dataloader: DataLoader,
model_pool,
config: Dict[str, Any],
user_id: str = None,
) -> Dict[str, Any]:
"""
Execute a single classification task with Self-Consistency.
Args:
task: Task dictionary containing sample, examples, and metadata
model_pool: Async model pool for LLM inference
num_sc_samples: Number of Self-Consistency samples
Returns:
Result dictionary with answer, confidence, consistency, and metrics
"""
np.random.seed(seed)
example_dataset = dataloader.get_examples()
if len(example_dataset) == 0:
debug_log.warn_no_examples()
return []
# Build class_indices: Dict[str, List[int]] for Queue initialization
class_indices = {}
for idx, example in enumerate(example_dataset):
label = example["label"]
if label not in class_indices:
class_indices[label] = []
class_indices[label].append(idx)
ex_queue = Queue(class_indices, config["queue_size"])
sample_rate = config.get("sample_rate", 1)
results = []
cumulative_correct = 0
window_size = config.get("tracking_window", 10) # Recent window size for accuracy tracking
recent_results = [] # Windowed recent results for accuracy calculation
confidence_history = [] # Confidence history for tracking
processed_count = 0
for i, sample in enumerate(dataloader):
if sample_rate > 1 and i % sample_rate != 0:
continue
ex_queue.set_current_time(processed_count) # Track current sample index for survival time
user_info = f"User: {user_id}" if user_id else "Unknown User"
debug_log.log_queue_state_before(processed_count, user_info, ex_queue, config)
agent_pool = AgentPool(log_path=config["log_path"])
for queue_idx, ex_idcs in enumerate(ex_queue):
examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]
task_info = dataloader.get_task_info()
sensor_info = dataloader.get_sensor_info()
classes_info = dataloader.get_classes_info()
agent = SCAgent(
name="EEG sensing",
index=queue_idx,
model_pool=model_pool,
task_info=task_info,
classes_info=classes_info,
sensor_info=sensor_info,
sample=sample,
examples=examples,
log_path=config["log_path"],
)
agent_pool.add_agent(agent)
# Check if any agents were added
if len(agent_pool.agents) == 0:
debug_log.warn_no_agents(processed_count)
continue
interpretation_result = await agent_pool.run_parallel_interpretation()
# Handle case where interpretation failed
if interpretation_result is None:
debug_log.warn_interpretation_failed(processed_count)
continue
answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
ground_truth = sample["label"]
is_correct = answer == ground_truth
cumulative_correct += 1 if is_correct else 0
cumulative_accuracy = cumulative_correct / (processed_count + 1)
recent_results.append(1 if is_correct else 0)
if len(recent_results) > window_size:
recent_results.pop(0)
window_accuracy = sum(recent_results) / len(recent_results)
# Confidence history
confidence_history.append(avg_confidence)
avg_confidence_so_far = sum(confidence_history) / len(confidence_history)
debug_log.log_tracking(
processed_count,
user_info,
answer,
ground_truth,
is_correct,
avg_confidence,
consistency,
cumulative_accuracy,
cumulative_correct,
window_accuracy,
recent_results,
avg_confidence_so_far,
config,
)
# Update queue based on Confidence (Priority Queue)
if responses:
# Create confidence map {queue_idx: confidence}
confidence_map = {idx: r.get("CONFIDENCE", 0) for idx, r in responses.items()}
ex_queue.update_by_confidence(confidence_map)
ex_queue.increment_usage(list(responses.keys()))
debug_log.log_queue_state_after(processed_count, ex_queue, config)
elif queue_idcs:
debug_log.warn_no_responses(processed_count)
result = {
"sample_idx": processed_count,
"answer": answer,
"ground_truth": ground_truth,
"is_correct": is_correct,
"confidence": avg_confidence,
"consistency": consistency,
"cumulative_accuracy": cumulative_accuracy,
"window_accuracy": window_accuracy,
"avg_confidence_so_far": avg_confidence_so_far,
}
results.append(result)
processed_count += 1
# Final Queue survival statistics
survival_summary = ex_queue.get_survival_summary()
debug_log.log_final_queue_stats(user_info, survival_summary, config)
# Add Queue statistics to results (in the last result)
if results:
results[-1]["queue_survival_stats"] = survival_summary
results[-1]["queue_survival_details"] = ex_queue.get_survival_stats()
return results
async def run_parallel(
config: Dict[str, Any],
model_pool,
) -> List[Dict[str, Any]]:
"""
Execute all classification tasks in parallel.
Args:
tasks: List of task dictionaries
model_pool: Async model pool for LLM inference
config: Experiment configuration
Returns:
List of result dictionaries
"""
data_path = config["data_path"]
user_paths = glob(os.path.join(data_path, "*"))
# Filter to only include directories (exclude files like info.json)
users = [os.path.basename(p) for p in user_paths if os.path.isdir(p) and os.path.basename(p) != "info.json"]
if not users:
debug_log.warn_no_user_dirs(data_path)
return []
debug_log.log_found_users(users, config)
seeds = range(config.get("num_seeds", 1))
tasks = []
for user in users[:1]: # <- User Selection for testing
# for user in users:
for seed in seeds:
np.random.seed(seed)
dataloader = DataLoader(
data_path=data_path,
user_id=user,
example_pool=config.get("example_pool", "out"),
continuous=config.get("continuous", True),
)
# Check if dataloader was properly initialized
if not hasattr(dataloader, 'test_dataset') or len(dataloader) == 0:
debug_log.warn_skip_user_no_test_data(user)
continue
if len(dataloader.get_examples()) == 0:
debug_log.warn_skip_user_no_example_data(user)
continue
task = asyncio.create_task(
run_single_task(
seed,
dataloader,
model_pool,
config,
user_id=user,
)
)
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
return results
# =============================================================================
# Statistics and Results
# =============================================================================
def compute_statistics(results: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Compute experiment statistics from results.
Args:
results: List of result dictionaries
Returns:
Statistics dictionary containing accuracy, confidence, consistency metrics
"""
if not results:
return {}
# Overall accuracy
correct = sum(1 for r in results if r.get("is_correct", False))
total = len(results)
accuracy = correct / total if total > 0 else 0
# Average confidence
confidences = [r.get("confidence", 0) for r in results]
avg_confidence = np.mean(confidences) if confidences else 0
# Average consistency
consistencies = [r.get("consistency", 0) for r in results]
avg_consistency = np.mean(consistencies) if consistencies else 0
# Per-class accuracy
class_correct = {}
class_total = {}
for r in results:
gt = r.get("ground_truth", "UNKNOWN")
class_total[gt] = class_total.get(gt, 0) + 1
if r.get("is_correct", False):
class_correct[gt] = class_correct.get(gt, 0) + 1
class_accuracy = {
cls: class_correct.get(cls, 0) / class_total[cls]
for cls in class_total
}
# High consistency accuracy analysis
high_consistency_results = [r for r in results if r.get("consistency", 0) >= 0.8]
high_consistency_accuracy = (
sum(1 for r in high_consistency_results if r.get("is_correct", False))
/ len(high_consistency_results)
) if high_consistency_results else 0
# ================================================================
# TIME DEBUGGING
mid_point = len(results) // 2
if mid_point > 0:
first_half = results[:mid_point]
second_half = results[mid_point:]
first_half_acc = sum(1 for r in first_half if r.get("is_correct", False)) / len(first_half)
second_half_acc = sum(1 for r in second_half if r.get("is_correct", False)) / len(second_half)
improvement = second_half_acc - first_half_acc
else:
first_half_acc = 0
second_half_acc = 0
improvement = 0
quartile_accs = []
quartile_size = len(results) // 4 if len(results) >= 4 else len(results)
if quartile_size > 0:
for q in range(4):
start = q * quartile_size
end = start + quartile_size if q < 3 else len(results)
quartile = results[start:end]
if quartile:
q_acc = sum(1 for r in quartile if r.get("is_correct", False)) / len(quartile)
quartile_accs.append(q_acc)
if len(confidences) > 1:
first_half_conf = np.mean(confidences[:mid_point]) if mid_point > 0 else 0
second_half_conf = np.mean(confidences[mid_point:]) if mid_point > 0 else 0
conf_improvement = second_half_conf - first_half_conf
else:
first_half_conf = avg_confidence
second_half_conf = avg_confidence
conf_improvement = 0
queue_survival_stats = {}
for r in results:
if "queue_survival_stats" in r:
queue_survival_stats = r["queue_survival_stats"]
break
# ================================================================
return {
"total_samples": total,
"correct": correct,
"accuracy": accuracy,
"avg_confidence": float(avg_confidence),
"avg_consistency": float(avg_consistency),
"class_accuracy": class_accuracy,
"high_consistency_accuracy": high_consistency_accuracy,
"high_consistency_samples": len(high_consistency_results),
"temporal_analysis": {
"first_half_accuracy": first_half_acc,
"second_half_accuracy": second_half_acc,
"accuracy_improvement": improvement,
"quartile_accuracies": quartile_accs,
"first_half_confidence": float(first_half_conf),
"second_half_confidence": float(second_half_conf),
"confidence_improvement": float(conf_improvement),
},
"queue_stats": queue_survival_stats,
}
def save_results(
results: List[Dict[str, Any]],
stats: Dict[str, Any],
config: Dict[str, Any]
) -> None:
"""
Save experiment results and statistics to files.
Args:
results: List of result dictionaries
stats: Statistics dictionary
config: Experiment configuration
"""
log_path = config["log_path"]
os.makedirs(log_path, exist_ok=True)
# Save statistics
stats_path = os.path.join(log_path, "statistics.json")
with open(stats_path, "w", encoding="utf-8") as f:
json.dump(stats, f, indent=2, ensure_ascii=False)
debug_log.log_save_statistics(stats_path, config)
# Save all results
results_to_save = []
for r in results:
r_copy = {k: v for k, v in r.items() if k != "all_responses"}
results_to_save.append(r_copy)
results_path = os.path.join(log_path, "all_results.json")
with open(results_path, "w", encoding="utf-8") as f:
json.dump(results_to_save, f, indent=2, ensure_ascii=False)
debug_log.log_save_results(results_path, config)
# Save configuration for reproducibility
config_path = os.path.join(log_path, "config.yaml")
with open(config_path, "w", encoding="utf-8") as f:
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
debug_log.log_save_config(config_path, config)
# =============================================================================
# CLI Commands
# =============================================================================
def main(config_path: str) -> None:
"""
Run experiment.
Args:
config_path: Path to YAML configuration file
Example:
python -m sc.run_sc sc/config/sleepedf_sc.yaml
"""
debug_log.log_main_loading_config(config_path)
config = yaml.load(open(config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
# Add timestamp to log path for unique experiment runs
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
if "log_path" in config:
config["log_path"] = f"{config['log_path']}_{timestamp}"
# Print experiment configuration
debug_log.log_main_config(config, config.get("debug", True))
model_pool = load_models(
config["models"],
temperature=config.get("temperature", 0.0),
num_ctx=config.get("num_ctx", 15000),
)
# Run experiments
debug_log.log_main_start(config)
all_results = asyncio.run(run_parallel(config, model_pool))
# Flatten results: run_parallel returns list of lists (one per user/seed)
# Each element is either a list of results or an exception
flattened_results = []
for result in all_results:
if isinstance(result, Exception):
debug_log.error_task_failed(result)
continue
if isinstance(result, list):
flattened_results.extend(result)
else:
flattened_results.append(result)
debug_log.log_total_results(len(flattened_results), config)
# Compute and display statistics
stats = compute_statistics(flattened_results)
debug_log.log_experiment_results(stats, config)
# Time-based analysis output
temporal = stats.get("temporal_analysis", {})
debug_log.log_temporal_analysis(temporal, config)
queue_stats = stats.get("queue_stats", {})
debug_log.log_queue_stats(queue_stats, config)
# Save results
save_results(flattened_results, stats, config)
debug_log.log_results_saved(config["log_path"], config)
if __name__ == "__main__":
Fire(main)

View File

@@ -1,557 +0,0 @@
"""
Queue Random Baseline Experiment Runner
This module implements the QUEUE RANDOM BASELINE policy:
- Queue structure is maintained (size=5)
- But every step, ALL 5 queue elements are replaced with fresh random samples
- This tests whether improvements come from queue structure itself
or from the cumulative learning effect
Purpose:
- Ablation study to distinguish:
1. Benefits from using 5 ICL examples (queue structure)
2. Benefits from retaining good examples over time (cumulative learning)
Key Difference from Other Policies:
- Confidence/Consistency: Queue updated by evicting lowest scoring examples
- Pure Random (no queue): Samples fresh examples each time, no structure
- Queue Random (this): Queue structure exists, but fully refreshed each step
Usage:
python -m sc.run_sc_queue_random --user_id=5 --shuffle_seed=42
python -m sc.run_sc_queue_random --user_id=10 --shuffle_seed=123
"""
import os
import sys
import asyncio
import yaml
import json
import random
import numpy as np
from datetime import datetime
from typing import List, Dict, Any, Optional
from sklearn.metrics import f1_score, precision_score, recall_score
from fire import Fire
# Add project root to path for relative imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
from sc.core.scagent import SCAgent
from sc.core.model import load_models
from sc.core.agent_pool import AgentPool
from sc import debug_log
log = debug_log.log
class QueueRandomSampler:
"""
Queue-based Random Sampler - maintains queue structure but refreshes all elements each step.
This sampler:
- Maintains a queue of size N (same as Confidence/Consistency policies)
- BUT replaces ALL N elements with fresh random samples at each step
- No memory/accumulation between steps
- Provides baseline to test if queue structure itself helps
"""
def __init__(self, example_dataset, queue_size: int = 5, seed: int = None):
"""
Initialize Queue Random sampler.
Args:
example_dataset: Dataset containing all available examples
queue_size: Size of the queue (number of ICL example sets)
seed: Random seed for reproducibility
"""
self.example_dataset = example_dataset
self.queue_size = queue_size
self.base_seed = seed
# Build class indices for balanced sampling
self.class_indices = {}
for idx, example in enumerate(example_dataset):
label = example["label"]
if label not in self.class_indices:
self.class_indices[label] = []
self.class_indices[label].append(idx)
self.classes = list(self.class_indices.keys())
self._step_counter = 0
# Initialize queue with random samples
self._queue = self._sample_all_new()
# Tracking for statistics (mimics Queue class interface)
self._total_refreshed = 0
debug_log.log_queue_random_sampler_init(
len(example_dataset),
self.classes,
queue_size,
)
def _sample_one_set(self) -> List[int]:
"""Sample one ICL example set (one example per class)."""
example_set = []
for cls in self.classes:
if self.class_indices[cls]:
idx = random.choice(self.class_indices[cls])
example_set.append(idx)
return example_set
def _sample_all_new(self) -> List[List[int]]:
"""Sample completely new queue contents."""
return [self._sample_one_set() for _ in range(self.queue_size)]
def refresh_all(self) -> None:
"""Replace ALL queue elements with fresh random samples."""
# Set seed for reproducibility but with variety per step
if self.base_seed is not None:
random.seed(self.base_seed + self._step_counter * 1000)
self._queue = self._sample_all_new()
self._total_refreshed += self.queue_size
self._step_counter += 1
def __iter__(self):
"""Iterate over queue contents (mimics Queue interface)."""
return iter(self._queue)
def __len__(self):
"""Return queue size."""
return len(self._queue)
def get_queue_state(self) -> List[List[int]]:
"""Get current queue state."""
return self._queue.copy()
def get_statistics(self) -> Dict[str, Any]:
"""Get sampling statistics."""
return {
"total_steps": self._step_counter,
"total_refreshed": self._total_refreshed,
"queue_size": self.queue_size,
"avg_refresh_per_step": self.queue_size, # Always full refresh
"policy": "queue_random",
}
async def run_queue_random_experiment(
dataloader: ShuffledDataLoader,
model_pool,
config: Dict[str, Any],
user_id: int,
shuffle_seed: int,
) -> List[Dict[str, Any]]:
"""
Run Queue Random baseline experiment.
Queue Policy:
- Every step, ALL queue elements are replaced with fresh random samples
- No cumulative learning or retention of good examples
- Tests whether queue structure alone provides benefits
Args:
dataloader: ShuffledDataLoader with pre-shuffled test data
model_pool: Async model pool for LLM inference
config: Experiment configuration
user_id: User ID being processed
shuffle_seed: Shuffle seed for reproducibility
Returns:
List of result dictionaries
"""
# Set seeds for reproducibility
random.seed(shuffle_seed)
np.random.seed(shuffle_seed)
example_dataset = dataloader.get_examples()
if len(example_dataset) == 0:
log(f"[ERROR] No examples found for user {user_id}")
return []
# Initialize Queue Random Sampler (instead of regular Queue)
queue_size = config.get("queue_size", 5)
queue_sampler = QueueRandomSampler(
example_dataset=example_dataset,
queue_size=queue_size,
seed=shuffle_seed,
)
# Tracking variables
results = []
cumulative_correct = 0
window_size = config.get("tracking_window", 20)
recent_results = []
confidence_history = []
consistency_history = []
all_predictions = []
all_ground_truths = []
debug_log.log_policy_experiment_header(
"QUEUE RANDOM BASELINE",
user_id,
shuffle_seed,
len(dataloader),
queue_size,
policy_note=f"Policy: ALL {queue_size} queue elements refreshed EVERY step",
)
for processed_count, sample in enumerate(dataloader):
# CRITICAL: Refresh ALL queue elements before each inference
queue_sampler.refresh_all()
# Log queue state (all new random samples)
debug_log.log_queue_random_queue_state(processed_count, user_id, queue_sampler)
# Create agent pool
agent_pool = AgentPool(log_path=config["log_path"])
try:
for queue_idx, ex_idcs in enumerate(queue_sampler):
examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]
agent = SCAgent(
name="EEG sensing",
index=queue_idx,
model_pool=model_pool,
task_info=dataloader.get_task_info(),
classes_info=dataloader.get_classes_info(),
sensor_info=dataloader.get_sensor_info(),
sample=sample,
examples=examples,
log_path=config["log_path"],
)
agent_pool.add_agent(agent)
except Exception as e:
log(f"[ERROR] Failed to create agents: {e}")
import traceback
traceback.print_exc()
continue
if len(agent_pool.agents) == 0:
log(f"[WARN] No agents created for sample {processed_count}")
continue
# Run parallel interpretation
try:
interpretation_result = await agent_pool.run_parallel_interpretation()
except Exception as e:
log(f"[ERROR] Interpretation failed: {e}")
import traceback
traceback.print_exc()
continue
if interpretation_result is None:
log(f"[WARN] Interpretation failed for sample {processed_count}")
continue
answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
# Evaluate result
ground_truth = sample["label"]
is_correct = answer == ground_truth
cumulative_correct += 1 if is_correct else 0
cumulative_accuracy = cumulative_correct / (processed_count + 1)
recent_results.append(1 if is_correct else 0)
if len(recent_results) > window_size:
recent_results.pop(0)
window_accuracy = sum(recent_results) / len(recent_results)
confidence_history.append(avg_confidence)
consistency_history.append(consistency)
all_predictions.append(answer)
all_ground_truths.append(ground_truth)
# Performance logging
debug_log.log_queue_random_result(
processed_count,
answer,
ground_truth,
is_correct,
avg_confidence,
consistency,
cumulative_accuracy,
cumulative_correct,
window_accuracy,
recent_results,
)
# NO QUEUE UPDATE based on scores - just fresh random next time
# (This is the key difference from Confidence/Consistency policies)
# Store result
result = {
"sample_idx": processed_count,
"original_idx": sample.get("original_idx", processed_count),
"shuffled_idx": sample.get("shuffled_idx", processed_count),
"answer": answer,
"ground_truth": ground_truth,
"is_correct": is_correct,
"confidence": avg_confidence,
"consistency": consistency,
"cumulative_accuracy": cumulative_accuracy,
"window_accuracy": window_accuracy,
"experiment_type": "queue_random",
"user_id": user_id,
"shuffle_seed": shuffle_seed,
}
results.append(result)
# Final statistics
sampler_stats = queue_sampler.get_statistics()
debug_log.log_queue_random_stats(user_id, shuffle_seed, sampler_stats)
if results:
results[-1]["queue_random_stats"] = sampler_stats
return results
def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict[str, Any]:
"""Compute comprehensive experiment statistics."""
if not results:
return {}
# Overall metrics
correct = sum(1 for r in results if r.get("is_correct", False))
total = len(results)
accuracy = correct / total if total > 0 else 0
# Confidence and consistency averages
confidences = [r.get("confidence", 0) for r in results]
consistencies = [r.get("consistency", 0) for r in results]
avg_confidence = np.mean(confidences) if confidences else 0
avg_consistency = np.mean(consistencies) if consistencies else 0
# Per-stage accuracy
stage_correct = {}
stage_total = {}
for r in results:
gt = r.get("ground_truth", "UNKNOWN")
stage_total[gt] = stage_total.get(gt, 0) + 1
if r.get("is_correct", False):
stage_correct[gt] = stage_correct.get(gt, 0) + 1
stage_accuracy = {}
for stage in stages:
if stage in stage_total:
stage_accuracy[stage] = stage_correct.get(stage, 0) / stage_total[stage]
else:
stage_accuracy[stage] = 0.0
# F1 Score and Macro metrics
predictions = [r.get("answer", "") for r in results]
ground_truths = [r.get("ground_truth", "") for r in results]
try:
macro_f1 = f1_score(ground_truths, predictions, average='macro', zero_division=0)
macro_precision = precision_score(ground_truths, predictions, average='macro', zero_division=0)
macro_recall = recall_score(ground_truths, predictions, average='macro', zero_division=0)
except Exception:
macro_f1 = macro_precision = macro_recall = 0.0
# Temporal analysis
mid_point = len(results) // 2
if mid_point > 0:
first_half = results[:mid_point]
second_half = results[mid_point:]
first_half_acc = sum(1 for r in first_half if r.get("is_correct", False)) / len(first_half)
second_half_acc = sum(1 for r in second_half if r.get("is_correct", False)) / len(second_half)
improvement = second_half_acc - first_half_acc
else:
first_half_acc = second_half_acc = accuracy
improvement = 0
# Learning curve (every 10 samples)
learning_curve = []
for i in range(0, len(results), 10):
chunk = results[:i+10]
chunk_acc = sum(1 for r in chunk if r.get("is_correct", False)) / len(chunk)
learning_curve.append({
"sample_idx": min(i+10, len(results)),
"cumulative_accuracy": chunk_acc,
})
# Convergence speed
final_accuracy = accuracy
convergence_threshold = 0.9 * final_accuracy if final_accuracy > 0 else 0.5
convergence_idx = None
running_correct = 0
for i, r in enumerate(results):
running_correct += 1 if r.get("is_correct", False) else 0
running_acc = running_correct / (i + 1)
if running_acc >= convergence_threshold:
convergence_idx = i
break
# Queue random statistics
queue_stats = {}
for r in results:
if "queue_random_stats" in r:
queue_stats = r["queue_random_stats"]
break
return {
"experiment_type": "queue_random",
"user_id": results[0].get("user_id") if results else None,
"shuffle_seed": results[0].get("shuffle_seed") if results else None,
"total_samples": total,
"correct": correct,
"accuracy": accuracy,
"avg_confidence": float(avg_confidence),
"avg_consistency": float(avg_consistency),
"stage_accuracy": stage_accuracy,
"stage_sample_counts": stage_total,
"macro_f1": float(macro_f1),
"macro_precision": float(macro_precision),
"macro_recall": float(macro_recall),
"temporal_analysis": {
"first_half_accuracy": first_half_acc,
"second_half_accuracy": second_half_acc,
"improvement": improvement,
},
"learning_curve": learning_curve,
"convergence_idx": convergence_idx,
"queue_stats": queue_stats,
}
def save_results(
results: List[Dict[str, Any]],
stats: Dict[str, Any],
config: Dict[str, Any],
user_id: int,
shuffle_seed: int,
) -> None:
"""Save experiment results and statistics."""
log_path = config["log_path"]
# Create user/seed specific directory
output_dir = os.path.join(log_path, f"user{user_id:02d}_seed{shuffle_seed}")
os.makedirs(output_dir, exist_ok=True)
# Save statistics
stats_path = os.path.join(output_dir, "statistics.json")
with open(stats_path, "w", encoding="utf-8") as f:
json.dump(stats, f, indent=2, ensure_ascii=False)
log(f"[SAVE] Statistics: {stats_path}")
# Save results
results_path = os.path.join(output_dir, "results.json")
with open(results_path, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
log(f"[SAVE] Results: {results_path}")
# Save config
config_path = os.path.join(output_dir, "config.yaml")
with open(config_path, "w", encoding="utf-8") as f:
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
log(f"[SAVE] Config: {config_path}")
def main(
config_path: str = "sc/config/experiment_queue_random.yaml",
user_id: int = 5,
shuffle_seed: int = 42,
) -> None:
"""
Run Queue Random baseline experiment.
This baseline maintains queue structure but refreshes ALL elements each step.
Tests whether performance gains come from:
1. Queue structure itself (5 ICL examples)
2. Cumulative learning (retaining good examples over time)
Args:
config_path: Path to YAML configuration file
user_id: User ID to process (5, 10, or 15)
shuffle_seed: Shuffle seed for data order (42 or 123)
Example:
python -m sc.run_sc_queue_random --user_id=5 --shuffle_seed=42
python -m sc.run_sc_queue_random --user_id=15 --shuffle_seed=123
"""
log(f"[MAIN] Loading config: {config_path}")
with open(config_path, "r", encoding="utf-8") as f:
config = yaml.load(f, Loader=yaml.SafeLoader)
# Override with CLI arguments
config["user_id"] = user_id
config["shuffle_seed"] = shuffle_seed
# Create unique log path
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
config["log_path"] = f"{config['log_path']}/user{user_id:02d}_seed{shuffle_seed}_{timestamp}"
os.makedirs(config["log_path"], exist_ok=True)
# Print experiment info
debug_log.log_policy_main_header(
"QUEUE RANDOM BASELINE",
user_id,
shuffle_seed,
config.get("queue_size", 5),
len(config.get("models", [])),
config["log_path"],
policy_note=" Policy: Queue Random (full refresh every step)",
)
# Load shuffled data
debug_log.log_policy_loading_data(user_id, shuffle_seed)
dataloader = ShuffledDataLoader(
data_path=config["data_path"],
user_id=user_id,
seed=shuffle_seed,
example_pool=config.get("example_pool", "out"),
)
# Load models
debug_log.log_policy_loading_models()
model_pool = load_models(
config["models"],
temperature=config.get("temperature", 0.0),
num_ctx=config.get("num_ctx", 15000),
)
# Run experiment
debug_log.log_policy_start(label="Queue Random experiment")
results = asyncio.run(run_queue_random_experiment(
dataloader=dataloader,
model_pool=model_pool,
config=config,
user_id=user_id,
shuffle_seed=shuffle_seed,
))
# Compute statistics
stages = config.get("stages", ["W", "N1", "N2", "N3", "REM"])
stats = compute_statistics(results, stages)
# Print final summary
temporal = stats.get("temporal_analysis", {})
debug_log.log_policy_complete_summary(
"QUEUE RANDOM BASELINE",
user_id,
shuffle_seed,
stats,
stats.get("stage_accuracy", {}),
stats.get("stage_sample_counts", {}),
temporal,
expected_note="\n [EXPECTED] Improvement should be ~0 (no cumulative learning)",
)
# Save results
save_results(results, stats, config, user_id, shuffle_seed)
log(f"\n[MAIN] Results saved to: {config['log_path']}")
if __name__ == "__main__":
Fire(main)

View File

@@ -1,473 +0,0 @@
"""
Consistency-based Queue Policy Experiment Runner
This module implements the CONSISTENCY-based queue update policy:
- Queue is updated based on SC consensus/agreement scores
- Higher consistency (more agents agree) examples are retained
- Lower consistency examples are evicted
Consistency Score = (Number of agents agreeing with majority) / (Total agents)
Usage:
python -m sc.run_consistency --user_id=5 --shuffle_seed=42
python -m sc.run_consistency --user_id=10 --shuffle_seed=123
"""
import os
import sys
import asyncio
import yaml
import json
import random
import numpy as np
from datetime import datetime
from typing import List, Dict, Any, Optional
from sklearn.metrics import f1_score, precision_score, recall_score
from fire import Fire
# Add project root to path for relative imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
from sc.core.scagent import SCAgent
from sc.core.model import load_models
from sc.core.queue import Queue
from sc.core.agent_pool import AgentPool
from sc import debug_log
log = debug_log.log
async def run_consistency_experiment(
dataloader: ShuffledDataLoader,
model_pool,
config: Dict[str, Any],
user_id: int,
shuffle_seed: int,
) -> List[Dict[str, Any]]:
"""
Run consistency-based queue policy experiment.
Queue Update Policy:
- After each inference, calculate per-agent consistency
- Consistency = how many other agents agree with this agent's answer
- Rank queue elements by consistency score
- Evict the lowest consistency element
- Add a new random ICL example set
Args:
dataloader: ShuffledDataLoader with pre-shuffled test data
model_pool: Async model pool for LLM inference
config: Experiment configuration
user_id: User ID being processed
shuffle_seed: Shuffle seed for reproducibility
Returns:
List of result dictionaries
"""
# Set seeds for reproducibility within experiment
random.seed(shuffle_seed)
np.random.seed(shuffle_seed)
example_dataset = dataloader.get_examples()
if len(example_dataset) == 0:
log(f"[ERROR] No examples found for user {user_id}")
return []
# Build class_indices for Queue initialization
class_indices = {}
for idx, example in enumerate(example_dataset):
label = example["label"]
if label not in class_indices:
class_indices[label] = []
class_indices[label].append(idx)
# Initialize Queue
queue_size = config.get("queue_size", 5)
ex_queue = Queue(class_indices, queue_size)
# Tracking variables
results = []
cumulative_correct = 0
window_size = config.get("tracking_window", 20)
recent_results = []
confidence_history = []
consistency_history = []
all_predictions = []
all_ground_truths = []
debug_log.log_policy_experiment_header(
"CONSISTENCY-BASED",
user_id,
shuffle_seed,
len(dataloader),
queue_size,
)
for processed_count, sample in enumerate(dataloader):
ex_queue.set_current_time(processed_count)
# Log queue state before processing
debug_log.log_policy_queue_state("CONSISTENCY", processed_count, user_id, ex_queue)
# Create agent pool
agent_pool = AgentPool(log_path=config["log_path"])
try:
for queue_idx, ex_idcs in enumerate(ex_queue):
examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]
agent = SCAgent(
name="EEG sensing",
index=queue_idx,
model_pool=model_pool,
task_info=dataloader.get_task_info(),
classes_info=dataloader.get_classes_info(),
sensor_info=dataloader.get_sensor_info(),
sample=sample,
examples=examples,
log_path=config["log_path"],
)
agent_pool.add_agent(agent)
except Exception as e:
log(f"[ERROR] Failed to create agents: {e}")
import traceback
traceback.print_exc()
continue
if len(agent_pool.agents) == 0:
log(f"[WARN] No agents created for sample {processed_count}")
continue
# Run parallel interpretation
try:
interpretation_result = await agent_pool.run_parallel_interpretation()
except Exception as e:
log(f"[ERROR] Interpretation failed: {e}")
import traceback
traceback.print_exc()
continue
if interpretation_result is None:
log(f"[WARN] Interpretation failed for sample {processed_count}")
continue
answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
# Evaluate result
ground_truth = sample["label"]
is_correct = answer == ground_truth
cumulative_correct += 1 if is_correct else 0
cumulative_accuracy = cumulative_correct / (processed_count + 1)
recent_results.append(1 if is_correct else 0)
if len(recent_results) > window_size:
recent_results.pop(0)
window_accuracy = sum(recent_results) / len(recent_results)
confidence_history.append(avg_confidence)
consistency_history.append(consistency)
all_predictions.append(answer)
all_ground_truths.append(ground_truth)
# Performance logging
debug_log.log_policy_result(
"CONSISTENCY",
processed_count,
answer,
ground_truth,
is_correct,
avg_confidence,
consistency,
cumulative_accuracy,
cumulative_correct,
window_accuracy,
recent_results,
)
# CONSISTENCY-BASED Queue Update
if responses:
# Calculate per-agent consistency: how many other agents agree
all_answers = [r.get("ANSWER") for r in responses.values()]
consistency_map = {}
for idx, response in responses.items():
agent_answer = response.get("ANSWER")
# Consistency = ratio of agents (including self) that agree
agreement_count = all_answers.count(agent_answer)
agent_consistency = agreement_count / len(all_answers) if all_answers else 0
consistency_map[idx] = agent_consistency
debug_log.log_consistency_map(responses, consistency_map)
# Update queue by consistency (reuses confidence method with consistency scores)
ex_queue.update_by_confidence(consistency_map)
ex_queue.increment_usage(list(responses.keys()))
debug_log.log_policy_queue_state_after("Consistency", processed_count, ex_queue)
# Store result
result = {
"sample_idx": processed_count,
"original_idx": sample.get("original_idx", processed_count),
"shuffled_idx": sample.get("shuffled_idx", processed_count),
"answer": answer,
"ground_truth": ground_truth,
"is_correct": is_correct,
"confidence": avg_confidence,
"consistency": consistency,
"cumulative_accuracy": cumulative_accuracy,
"window_accuracy": window_accuracy,
"experiment_type": "usc",
"user_id": user_id,
"shuffle_seed": shuffle_seed,
}
results.append(result)
# Final statistics
survival_summary = ex_queue.get_survival_summary()
debug_log.log_final_policy_survival("usc", user_id, shuffle_seed, survival_summary)
if results:
results[-1]["queue_survival_stats"] = survival_summary
results[-1]["queue_survival_details"] = ex_queue.get_survival_stats()
return results
def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict[str, Any]:
"""Compute comprehensive experiment statistics."""
if not results:
return {}
# Overall metrics
correct = sum(1 for r in results if r.get("is_correct", False))
total = len(results)
accuracy = correct / total if total > 0 else 0
# Confidence and consistency averages
confidences = [r.get("confidence", 0) for r in results]
consistencies = [r.get("consistency", 0) for r in results]
avg_confidence = np.mean(confidences) if confidences else 0
avg_consistency = np.mean(consistencies) if consistencies else 0
# Per-stage accuracy
stage_correct = {}
stage_total = {}
for r in results:
gt = r.get("ground_truth", "UNKNOWN")
stage_total[gt] = stage_total.get(gt, 0) + 1
if r.get("is_correct", False):
stage_correct[gt] = stage_correct.get(gt, 0) + 1
stage_accuracy = {}
for stage in stages:
if stage in stage_total:
stage_accuracy[stage] = stage_correct.get(stage, 0) / stage_total[stage]
else:
stage_accuracy[stage] = 0.0
# F1 Score and Macro metrics
predictions = [r.get("answer", "") for r in results]
ground_truths = [r.get("ground_truth", "") for r in results]
try:
macro_f1 = f1_score(ground_truths, predictions, average='macro', zero_division=0)
macro_precision = precision_score(ground_truths, predictions, average='macro', zero_division=0)
macro_recall = recall_score(ground_truths, predictions, average='macro', zero_division=0)
except Exception:
macro_f1 = macro_precision = macro_recall = 0.0
# Temporal analysis
mid_point = len(results) // 2
if mid_point > 0:
first_half = results[:mid_point]
second_half = results[mid_point:]
first_half_acc = sum(1 for r in first_half if r.get("is_correct", False)) / len(first_half)
second_half_acc = sum(1 for r in second_half if r.get("is_correct", False)) / len(second_half)
improvement = second_half_acc - first_half_acc
else:
first_half_acc = second_half_acc = accuracy
improvement = 0
# Learning curve (every 10 samples)
learning_curve = []
for i in range(0, len(results), 10):
chunk = results[:i+10]
chunk_acc = sum(1 for r in chunk if r.get("is_correct", False)) / len(chunk)
learning_curve.append({
"sample_idx": min(i+10, len(results)),
"cumulative_accuracy": chunk_acc,
})
# Convergence speed (90% of final accuracy)
final_accuracy = accuracy
convergence_threshold = 0.9 * final_accuracy
convergence_idx = None
running_correct = 0
for i, r in enumerate(results):
running_correct += 1 if r.get("is_correct", False) else 0
running_acc = running_correct / (i + 1)
if running_acc >= convergence_threshold:
convergence_idx = i
break
# Queue statistics
queue_stats = {}
for r in results:
if "queue_survival_stats" in r:
queue_stats = r["queue_survival_stats"]
break
return {
"experiment_type": "consistency",
"user_id": results[0].get("user_id") if results else None,
"shuffle_seed": results[0].get("shuffle_seed") if results else None,
"total_samples": total,
"correct": correct,
"accuracy": accuracy,
"avg_confidence": float(avg_confidence),
"avg_consistency": float(avg_consistency),
"stage_accuracy": stage_accuracy,
"stage_sample_counts": stage_total,
"macro_f1": float(macro_f1),
"macro_precision": float(macro_precision),
"macro_recall": float(macro_recall),
"temporal_analysis": {
"first_half_accuracy": first_half_acc,
"second_half_accuracy": second_half_acc,
"improvement": improvement,
},
"learning_curve": learning_curve,
"convergence_idx": convergence_idx,
"queue_stats": queue_stats,
}
def save_results(
results: List[Dict[str, Any]],
stats: Dict[str, Any],
config: Dict[str, Any],
user_id: int,
shuffle_seed: int,
) -> None:
"""Save experiment results and statistics."""
log_path = config["log_path"]
# Create user/seed specific directory
output_dir = os.path.join(log_path, f"user{user_id:02d}_seed{shuffle_seed}")
os.makedirs(output_dir, exist_ok=True)
# Save statistics
stats_path = os.path.join(output_dir, "statistics.json")
with open(stats_path, "w", encoding="utf-8") as f:
json.dump(stats, f, indent=2, ensure_ascii=False)
log(f"[SAVE] Statistics: {stats_path}")
# Save results
results_path = os.path.join(output_dir, "results.json")
with open(results_path, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
log(f"[SAVE] Results: {results_path}")
# Save config
config_path = os.path.join(output_dir, "config.yaml")
with open(config_path, "w", encoding="utf-8") as f:
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
log(f"[SAVE] Config: {config_path}")
def main(
config_path: str = "sc/config/experiment_consistency.yaml",
user_id: int = 5,
shuffle_seed: int = 42,
) -> None:
"""
Run Consistency-based Queue Policy experiment.
Args:
config_path: Path to YAML configuration file
user_id: User ID to process (5 or 10)
shuffle_seed: Shuffle seed for data order (42, 123, or 456)
Example:
python -m sc.run_consistency --user_id=5 --shuffle_seed=42
python -m sc.run_consistency --user_id=10 --shuffle_seed=123
"""
log(f"[MAIN] Loading config: {config_path}")
with open(config_path, "r", encoding="utf-8") as f:
config = yaml.load(f, Loader=yaml.SafeLoader)
# Override with CLI arguments
config["user_id"] = user_id
config["shuffle_seed"] = shuffle_seed
# Create unique log path
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
config["log_path"] = f"{config['log_path']}/user{user_id:02d}_seed{shuffle_seed}_{timestamp}"
os.makedirs(config["log_path"], exist_ok=True)
# Print experiment info
debug_log.log_policy_main_header(
"CONSISTENCY-BASED",
user_id,
shuffle_seed,
config.get("queue_size", 5),
len(config.get("models", [])),
config["log_path"],
)
# Load shuffled data
debug_log.log_policy_loading_data(user_id, shuffle_seed)
dataloader = ShuffledDataLoader(
data_path=config["data_path"],
user_id=user_id,
seed=shuffle_seed,
example_pool=config.get("example_pool", "out"),
)
# Load models
debug_log.log_policy_loading_models()
model_pool = load_models(
config["models"],
temperature=config.get("temperature", 0.0),
num_ctx=config.get("num_ctx", 15000),
)
# Run experiment
debug_log.log_policy_start(label="experiment")
results = asyncio.run(run_consistency_experiment(
dataloader=dataloader,
model_pool=model_pool,
config=config,
user_id=user_id,
shuffle_seed=shuffle_seed,
))
# Compute statistics
stages = config.get("stages", ["W", "N1", "N2", "N3", "REM"])
stats = compute_statistics(results, stages)
# Print final summary
temporal = stats.get("temporal_analysis", {})
debug_log.log_policy_complete_summary(
"CONSISTENCY",
user_id,
shuffle_seed,
stats,
stats.get("stage_accuracy", {}),
stats.get("stage_sample_counts", {}),
temporal,
)
# Save results
save_results(results, stats, config, user_id, shuffle_seed)
log(f"\n[MAIN] Results saved to: {config['log_path']}")
if __name__ == "__main__":
Fire(main)

View File

@@ -1,124 +0,0 @@
"""
Call transformers serve (chat) and our HF API (logprobs) from Python with requests.
Summary:
--------
- transformers serve (transformers chat localhost:8000 --model-name-or-path ...):
- Exposes OpenAI-compatible endpoints: /v1/chat/completions, /v1/responses, /v1/models.
- It does NOT return logits or logprobs; the response chunks have "logprobs": null.
- You can still use it for chat from Python via requests (see chat_with_transformers_serve).
- For answer + logprobs (top-k per token):
- Use our custom API in sc/hf_api.py:
MODEL_DIR=/path/to/model uvicorn sc.hf_api:app --host 0.0.0.0 --port 8000
- Then call POST /generate with requests (see get_logprobs_via_hf_api).
Usage:
------
# Chat only (transformers serve on 8000):
python -c "
from sc.transformers_serve_client_example import chat_with_transformers_serve
print(chat_with_transformers_serve('Hello!'))
"
# Answer + logprobs (sc/hf_api on 8000):
python -c "
from sc.transformers_serve_client_example import get_logprobs_via_hf_api
r = get_logprobs_via_hf_api([{'role':'user','content':'Hello!'}])
print('text:', r['generated_text'])
print('first step top-k:', r['steps'][0]['topk'])
"
"""
from __future__ import annotations
import json
from typing import Any, Dict, List
import requests
# Default base URL for transformers serve (OpenAI-compatible)
TRANSFORMERS_SERVE_URL = "http://localhost:8000/v1"
# Default base URL for our sc/hf_api (generate + logprobs)
HF_API_URL = "http://localhost:8000"
def chat_with_transformers_serve(
user_message: str,
*,
base_url: str = TRANSFORMERS_SERVE_URL,
model: str = "openai/gpt-oss-20b",
max_tokens: int = 256,
stream: bool = False,
) -> str:
"""
Send a chat message to a server running `transformers serve`.
Returns the assistant reply text. No logits/logprobs (server does not provide them).
"""
url = f"{base_url.rstrip('/')}/chat/completions"
payload = {
"model": model,
"messages": [{"role": "user", "content": user_message}],
"max_tokens": max_tokens,
"stream": stream,
}
resp = requests.post(url, json=payload, timeout=60)
resp.raise_for_status()
data = resp.json()
# Non-stream: choices[0].message.content
choices = data.get("choices", [])
if not choices:
return ""
msg = choices[0].get("message", {})
return msg.get("content", "") or ""
def get_logprobs_via_hf_api(
messages: List[Dict[str, str]],
*,
base_url: str = HF_API_URL,
max_new_tokens: int = 64,
top_k: int = 10,
) -> Dict[str, Any]:
"""
Call our sc/hf_api POST /generate endpoint.
Returns generated text and per-token top-k logprobs (no raw logits over the wire).
"""
url = f"{base_url.rstrip('/')}/generate"
payload = {
"messages": messages,
"max_new_tokens": max_new_tokens,
"top_k": top_k,
"temperature": 0.0,
"do_sample": False,
}
resp = requests.post(url, json=payload, timeout=120)
resp.raise_for_status()
return resp.json()
if __name__ == "__main__":
import sys
if len(sys.argv) < 2:
print("Usage: python -m sc.transformers_serve_client_example chat|logprobs [message]")
print(" chat -> call transformers serve /v1/chat/completions (no logprobs)")
print(" logprobs -> call sc/hf_api /generate (returns top-k logprobs)")
sys.exit(0)
cmd = sys.argv[1].lower()
message = (sys.argv[2] if len(sys.argv) > 2 else "Hello, how are you?").strip()
if cmd == "chat":
text = chat_with_transformers_serve(message)
print("Reply:", text)
elif cmd == "logprobs":
out = get_logprobs_via_hf_api([{"role": "user", "content": message}])
print("Generated:", out.get("generated_text", ""))
print("Steps (first 3):")
for s in out.get("steps", [])[:3]:
print(" token:", repr(s.get("token")), "logprob:", s.get("logprob"), "topk:", [t.get("token") for t in s.get("topk", [])[:5]])
else:
print("Unknown command. Use 'chat' or 'logprobs'.")
sys.exit(1)