5 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Hyungjun Yoon | c4bbbf702d | implemented globem processing | 2026-03-18 21:30:42 +09:00 |
| Hyungjun Yoon | ccb88d4eef | implemented 6 random baselines | 2026-03-12 20:41:34 +09:00 |
| Hyungjun Yoon | 95ddd935f7 | fixed some bugs in run | 2026-03-10 19:54:37 +09:00 |
| Hyungjun Yoon | d10b70dc20 | cleaned code | 2026-03-09 00:04:01 +09:00 |
| Hyungjun Yoon | 3f3db31d25 | working code with queue update | 2026-03-08 23:55:32 +09:00 |
98 changed files with 2000 additions and 16414 deletions


@@ -1,198 +0,0 @@
# Project Memory: tsllm_personalization_icl
## Project Overview
Research project on **personalized time-series classification using LLMs with In-Context Learning (ICL)**. The core idea is to use LLMs (via Ollama or HuggingFace) to classify sensor/physiological time-series data (e.g., EEG sleep stages, PPG blood pressure), where ICL examples are selected via different strategies (random vs. similarity-based) to study personalization.
Current branch: `case` (main branch: `main`)
---
## Repository Structure
```
tsllm_personalization_icl/
├── run.py  # Main entry point (ICL experiment runner)
├── config/
│   └── sleepedf.yaml  # Config: data path, models, selection criteria, log path
├── core/  # Base pipeline (ICL experiments)
│   ├── model.py  # Model/AsyncModelPool (Ollama or LangChain init_chat_model)
│   ├── agent.py  # Base Agent class (memory, logging, JSON parsing)
│   ├── sensing_agent.py  # SensingAgent: ICL classify + evaluate + reflect
│   ├── data_loader.py  # DataLoader: test/example splits + similarity-based selection
│   └── embedding_index.py  # EmbeddingIndex: cosine similarity over Chronos-2 embeddings
├── sc/  # Self-Consistency (SC) variant pipeline
│   ├── run_sc.py  # SC experiment runner (main entry: `python -m sc.run_sc config.yaml`)
│   ├── run_confidence_based.py
│   ├── run_consistency_based.py
│   ├── run_sc_queue_random.py
│   ├── run_usc.py
│   ├── core/
│   │   ├── model.py  # HuggingFace CausalLM wrapper (returns text + logits)
│   │   ├── agent.py  # SC base agent
│   │   ├── scagent.py  # SCAgent: single-pass interpret() with REASON/CONFIDENCE/ANSWER
│   │   ├── agent_pool.py  # AgentPool: parallel interpret + majority voting
│   │   ├── example_queue.py  # Queue: priority queue of example sets, updated by confidence
│   │   ├── data_loader.py  # DataLoader + InMemoryDataLoader + PPGBPLoader
│   │   ├── judge_agent.py  # JudgeAgent
│   │   ├── majority_voting.py  # MajorityVoting utilities
│   │   └── model_utils.py
│   ├── analysis/
│   │   └── analyze_sc_results.py
│   ├── preprocess/
│   │   └── shuffle_data.py
│   ├── debug_log.py  # Structured debug logging helpers
│   ├── logger.py
│   ├── hf_api.py
│   └── ollama_test.py
├── analysis/
│   ├── ppgbp_loader.py
│   ├── user_similarity/  # Embedding analysis scripts
│   │   ├── chronos2/  # Chronos-2 embeddings for SleepEDF
│   │   ├── chronos2_ppgbp/  # Chronos-2 embeddings for PPGBP
│   │   ├── labram/  # LaBraM embeddings
│   │   ├── sbert/
│   │   ├── sbert_metadata/
│   │   └── sbert_metadata_ppgbp/  # SBERT metadata embeddings for PPGBP
│   ├── analyze_data.ipynb
│   └── analyze_preliminary.ipynb
├── preprocess/
│   ├── dhedfreader.py
│   └── preprocess_SleepEDF.py
├── utils/
│   ├── kill_ollamas.sh
│   └── launch_ollamas.sh
└── requirements.txt
```
---
## Key Concepts
### Datasets
- **SleepEDF**: EEG-based sleep stage classification (W, N1, N2, N3, REM).
- **PPGBP**: PPG-based blood pressure prediction.
- Data format: `<data_path>/<user_id>/1/` (train/examples, HuggingFace dataset on disk) and `<data_path>/<user_id>/2/` (test split). Metadata in `info.json` (keys: `task`, `class`, `feature`).
### ICL Selection Strategies (core/data_loader.py)
- `out_random`: Random examples from OTHER users (cross-user, random)
- `in_random`: Random examples from SAME user (personalized, random)
- `out_similar`: Most similar examples from OTHER users (Chronos-2 embedding similarity)
- `in_similar`: Most similar examples from SAME user (embedding similarity)
Similarity uses cosine similarity over Chronos-2 embeddings, one example per class (balanced).
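The balanced, similarity-based selection can be illustrated with a minimal sketch. It is not the code in `core/data_loader.py` or `core/embedding_index.py`; the function names and the `(label, embedding, idx)` candidate layout are assumptions made for illustration, and plain Python lists stand in for Chronos-2 embeddings.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def select_similar_per_class(
    query: Sequence[float],
    candidates: List[Tuple[str, Sequence[float], int]],
) -> Dict[str, int]:
    """Return, for each class label, the idx of the candidate most similar to `query`.

    `candidates` holds hypothetical (label, embedding, idx) triples; the caller is
    assumed to have already filtered them by user (same user or other users).
    """
    best: Dict[str, Tuple[float, int]] = {}
    for label, embedding, idx in candidates:
        score = cosine_similarity(query, embedding)
        if label not in best or score > best[label][0]:
            best[label] = (score, idx)
    return {label: idx for label, (_, idx) in best.items()}


candidates = [
    ("W", [1.0, 0.0], 0),
    ("W", [0.0, 1.0], 1),
    ("N1", [0.5, 0.5], 2),
    ("N1", [-1.0, 0.0], 3),
]
picked = select_similar_per_class([1.0, 0.1], candidates)
print(picked)
```

Restricting `candidates` to the target user would correspond to `in_similar`, and restricting them to all other users to `out_similar`.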
### Models
- **core/model.py**: Supports Ollama (local/remote via `ollama:url:<host>/<model>`) and any LangChain-supported model (OpenAI, Together, etc.)
- **sc/core/model.py**: HuggingFace CausalLM loaded via `transformers`, returns text + logits. Used for SC experiments.
### Self-Consistency (sc/) Pipeline
The SC pipeline uses a **priority queue of example sets**:
- `Queue` (sc/core/example_queue.py): holds `capacity` example-sets (one set = one example per class). Updated each step by agent confidence scores — highest-confidence sets are kept, lowest evicted and replaced with a random new set.
- `AgentPool` (sc/core/agent_pool.py): runs one `SCAgent` per queue slot in parallel, aggregates via majority vote, tracks confidence and consistency.
- `SCAgent` (sc/core/scagent.py): single LLM call, returns `{REASON, CONFIDENCE, ANSWER}` JSON.
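The queue update and vote aggregation described above might look roughly like the sketch below. It is an illustration only, not the contents of `sc/core/example_queue.py` or `sc/core/agent_pool.py`: the function names, the `num_evict` parameter, and the `sample_new_set` callback are hypothetical, and how many sets the real queue evicts per step is not stated in these notes.

```python
from collections import Counter
from typing import Callable, List, Sequence, TypeVar

ExampleSet = TypeVar("ExampleSet")


def update_queue(
    queue: Sequence[ExampleSet],
    confidences: Sequence[float],
    sample_new_set: Callable[[], ExampleSet],
    num_evict: int = 1,
) -> List[ExampleSet]:
    """Keep the highest-confidence example sets and refill the evicted slots.

    `confidences[i]` is the confidence reported by the agent that used `queue[i]`.
    The `num_evict` lowest-confidence sets are dropped (an assumed policy) and
    replaced by freshly sampled sets, so the queue keeps its capacity.
    """
    if len(queue) != len(confidences):
        raise ValueError("queue and confidences must have the same length")
    order = sorted(range(len(queue)), key=lambda i: confidences[i], reverse=True)
    keep = max(0, len(queue) - num_evict)
    survivors = [queue[i] for i in order[:keep]]
    while len(survivors) < len(queue):
        survivors.append(sample_new_set())
    return survivors


def majority_vote(answers: Sequence[str]) -> str:
    """Return the most frequent answer; ties go to the answer seen first."""
    return Counter(answers).most_common(1)[0][0]


new_queue = update_queue(["a", "b", "c"], [0.2, 0.9, 0.5], lambda: "new")
print(new_queue, majority_vote(["W", "N2", "W"]))
```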
### Base Agent (core/agent.py)
- Three memory tiers: `long_term_memory`, `short_term_memory`, `volatile_memory`
- `invoke()`: async, appends messages to memory, calls model pool
- `safe_parse_json()` / `safe_parse_json_list()`: robust JSON parsing with cleanup
- Token counting via `tiktoken` (gpt-3.5-turbo encoding as proxy)
### SensingAgent (core/sensing_agent.py)
Extends Agent. Methods:
- `solve()`: classify a sample with ICL examples → `{REASON, ANSWER}`
- `interpret()`: same as solve but without logging ground truth
- `evaluate()`: evaluate another agent's answer (for multi-agent debate)
- `reflect()`: refine answer based on peer evaluations
---
## Running Experiments
### Basic ICL run
```bash
python run.py run config/sleepedf.yaml
```
### Compare multiple selection criteria
```bash
python run.py compare config/sleepedf.yaml \
--criteria_list="out_random,in_random,out_similar,in_similar" \
--embedding_path="./embeddings_full"
```
### Self-Consistency run
```bash
python -m sc.run_sc sc/config/sleepedf_sc.yaml
```
---
## Config (sleepedf.yaml) Key Fields
- `data_path`: root of dataset (contains `info.json` and per-user dirs)
- `log_path`: output directory for results and logs
- `num_seeds`: number of random seeds to run
- `num_examples`: ICL examples per class (default 1)
- `selection_criteria`: `out_random` | `in_random` | `out_similar` | `in_similar`
- `embedding_path`: path to pre-computed Chronos-2 embeddings (required for `*_similar`)
- `models`: list of model specs (Ollama URLs or model IDs)
In SC configs, additional fields:
- `queue_size`: number of example sets to maintain in the priority queue
- `temperature`: LLM sampling temperature
- `max_new_tokens`: max tokens for generation
- `example_pool`: `"out"` (other users) or `"in"` (same user)
- `continuous`: whether queue persists across test samples
---
## Data Format Details
Each HuggingFace dataset sample:
```python
{
"user_id": str,
"session_id": str, # "1" = train, "2" = test
"idx": int,
"label": str, # class name from info.json["class"]
"features": dict, # str -> float/str (sensor features formatted for prompt)
"data": dict, # optional raw data
}
```
`info.json`:
```json
{
"task": "Sleep stage classification from EEG",
"class": {"W": "Wake", "N1": "...", ...},
"feature": "EEG channel description for the LLM..."
}
```
---
## PPGBP Dataset (sc/core/data_loader.py)
`PPGBPLoader`: reads xlsx metadata + signal `.txt` files from `0_subject/`. 80/20 train/test split at subject level. Metadata embeddings (SBERT) auto-generated and cached per subject as `.npy` files under `PPGBP_METADATA_EMBEDDINGS_ROOT` (set via env var or `.env`).
---
## Dependencies (requirements.txt)
- `langchain`, `langchain_ollama`, `langchain_openai`, `langchain_together` — LLM backends
- `datasets` — HuggingFace datasets (on-disk storage)
- `chronos` — Chronos-2 time-series embeddings
- `sentence_transformers` — SBERT for metadata embeddings
- `transformers`, `torch` — HuggingFace model loading (SC pipeline)
- `tiktoken` — token counting
- `fire` — CLI argument parsing
- `mne`, `neurokit2` — EEG/biosignal preprocessing
- `numpy`, `pandas`, `scikit_learn`, `scipy`, `matplotlib`
---
## Notes / Patterns
- Experiments sample every 10th test item (`if idx % 10 != 0: continue` in `run.py`).
- SC runner currently hardcoded to first user only for testing (`users[:1]`).
- Log structure: `<log_path>/<user_id>/<sample_idx>/<seed>/` with `summary.txt`, `log.txt`, `tokens.txt`.
- Embedding index uses cosine similarity (L2-normalized dot product), filtered by user and session.
- `DataLoader` in `sc/` is simpler (no similarity selection); similarity lives in `core/`.
- `InMemoryDataLoader` and `prepare_dataset_for_sc()` allow integrating new datasets without writing to disk.


.gitignore (vendored), 1 line changed

@@ -3,6 +3,7 @@
*.csv
*.arrow
*.json
temp*/
# Byte-compiled / optimized / DLL files
__pycache__/

README.md (normal file), 80 lines changed

@@ -0,0 +1,80 @@
## Setup
```bash
conda create -n tsllmpers python=3.12
conda activate tsllmpers
python -m pip install -r requirements.txt
```
## Preprocessing
Raw data must be preprocessed into HuggingFace `datasets` format before running.
A preprocessing script is provided for the SleepEDF dataset:
```bash
python preprocess/preprocess_SleepEDF.py \
    --path /path/to/SleepEDF/raw/sleep-cassette/ \
    --out_dir /path/to/output/processed_SleepEDF \
    --num_workers 32
```
The output directory will have the following structure:
```
SleepEDF_new/
    task_metadata.json   # task description, class definitions, data/feature info
    user_metadata.json   # per-user metadata (age, sex)
    00/                  # user folder (HuggingFace Dataset saved with save_to_disk)
    01/
    ...
```
Each user dataset contains the columns: `user_id` (str), `label` (str), `features` (dict of floats), and `data` (dict of raw signal arrays).
To preprocess a different dataset, write a similar script that produces the same output structure.
## Running
```bash
python run.py --config_path config/test.yaml
```
## Config
All configuration is in a single YAML file. Example (`config/test.yaml`):
```yaml
log_path: ./temp
data_path: /path/to/data/processed_SleepEDF
target_user: "00"
queue_size: 5
num_shot: 1
model_paths:
- ollama:url:hostname:11437/gpt-oss:20b
- ollama:url:hostname:11438/gpt-oss:20b
vocab_size: 200064
```
| Key | Description |
|---|---|
| `log_path` | Directory for logs and results (timestamped subfolder created automatically) |
| `data_path` | Path to the preprocessed dataset directory |
| `target_user` | User ID to evaluate on; all other users become source (ICL examples) |
| `queue_size` | Number of example sets to maintain in the queue |
| `num_shot` | Number of examples per class in each example set |
| `model_paths` | List of Ollama model endpoints in `ollama:url:host:port/model` format |
| `vocab_size` | Vocabulary size of the model (used for self-certainty scoring) |
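The README does not show how self-certainty is computed. One common definition, assumed here purely for illustration, is the KL divergence from the uniform distribution over the vocabulary to the model's next-token distribution, averaged over generated tokens; this is where a vocabulary size enters. The function below is a sketch under that assumption, not the scoring code used by `run.py`, and it takes full per-token probability vectors, which a real backend may not expose.

```python
import math
from typing import Sequence


def self_certainty(token_distributions: Sequence[Sequence[float]]) -> float:
    """Average KL(uniform || p) over generated tokens.

    Each inner sequence is a full next-token probability distribution
    (strictly positive entries that sum to 1); its length plays the role
    of `vocab_size`. Higher values mean more peaked, more confident output.
    """
    if not token_distributions:
        raise ValueError("need at least one token distribution")
    total = 0.0
    for probs in token_distributions:
        vocab_size = len(probs)
        # KL(U || p) = -(1/V) * sum_j log(V * p_j)
        total += -sum(math.log(vocab_size * p) for p in probs) / vocab_size
    return total / len(token_distributions)


uniform_score = self_certainty([[0.25, 0.25, 0.25, 0.25]])
peaked_score = self_certainty([[0.97, 0.01, 0.01, 0.01]])
print(uniform_score, peaked_score)
```

Under this definition a uniform distribution scores zero and the score grows as probability mass concentrates on a few tokens.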
## Ollama
The model backend is [Ollama](https://ollama.com). You need one or more Ollama servers running and accessible over HTTP.
Example scripts for managing multiple Ollama instances on a multi-GPU machine are provided in `utils/` for reference. These are environment-specific -- adapt the ports, model paths, and GPU assignments to your own setup:
```bash
# Reference only -- edit before using
bash utils/launch_ollamas.sh # start servers in tmux sessions
bash utils/kill_ollamas.sh # stop all servers
```
The model path format in the config is `ollama:url:<host>:<port>/<model_name>`.
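As a rough illustration of this format, the sketch below splits a model path into a server address and a model name. The helper name `parse_model_path`, the `OllamaEndpoint` type, and the `http://` scheme are assumptions for this example; the repository's own parsing code is not shown here.

```python
from typing import NamedTuple


class OllamaEndpoint(NamedTuple):
    base_url: str  # assumed to be plain HTTP, e.g. "http://hostname:11437"
    model: str     # Ollama model name, e.g. "gpt-oss:20b"


def parse_model_path(spec: str) -> OllamaEndpoint:
    """Split `ollama:url:<host>:<port>/<model_name>` into address and model."""
    prefix = "ollama:url:"
    if not spec.startswith(prefix):
        raise ValueError(f"expected a path starting with {prefix!r}, got {spec!r}")
    # Split on the first "/" only: the model name may itself contain ":" (a tag).
    address, separator, model = spec[len(prefix):].partition("/")
    if not separator or not address or not model:
        raise ValueError(f"expected <host>:<port>/<model_name> in {spec!r}")
    return OllamaEndpoint(base_url=f"http://{address}", model=model)


endpoint = parse_model_path("ollama:url:hostname:11437/gpt-oss:20b")
print(endpoint)
```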

File diff suppressed because one or more lines are too long


@@ -1,96 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"id": "a0874e1b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Correct: 1038, Total: 2149, Accuracy: 0.4830153559795254\n"
]
}
],
"source": [
"from glob import glob\n",
"import os\n",
"\n",
"correct = 0\n",
"total = 0\n",
"summary_paths = glob(\"/mnt/sting/hjyoon/projects/tsllm_personalization_icl/logs/SleepEDF/*/*/*/summary.txt\")\n",
"for summary_path in summary_paths:\n",
" with open(summary_path, \"r\") as f:\n",
" summary = f.read()\n",
" if summary:\n",
" answer = summary.split(\"Answer: \")[-1].split(\" (Ground truth: \")[0]\n",
" ground_truth = summary.split(\" (Ground truth: \")[-1].split(\")\")[0]\n",
" if answer == ground_truth:\n",
" correct += 1\n",
" total += 1\n",
"print(f\"Correct: {correct}, Total: {total}, Accuracy: {correct/total}\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f78ffc6f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Correct: 932, Total: 2260, Accuracy: 0.41238938053097346\n"
]
}
],
"source": [
"correct = 0\n",
"total = 0\n",
"summary_paths = glob(\"/mnt/sting/hjyoon/projects/tsllm_personalization_icl/logs/SleepEDF_out_random/*/*/*/summary.txt\")\n",
"for summary_path in summary_paths:\n",
" with open(summary_path, \"r\") as f:\n",
" summary = f.read()\n",
" if summary:\n",
" answer = summary.split(\"Answer: \")[-1].split(\" (Ground truth: \")[0]\n",
" ground_truth = summary.split(\" (Ground truth: \")[-1].split(\")\")[0]\n",
" if answer == ground_truth:\n",
" correct += 1\n",
" total += 1\n",
"print(f\"Correct: {correct}, Total: {total}, Accuracy: {correct/total}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6872381f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tsllmpers",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
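Both notebook cells repeat the same string-splitting logic to score `summary.txt` files. A reusable version might look like the sketch below; it assumes each summary contains text of the form `Answer: <answer> (Ground truth: <label>)`, as the notebook's splits imply, and the function name `score_summaries` is hypothetical. Unlike the notebook, it skips summaries that do not match the pattern instead of counting them.

```python
import re
from typing import Iterable, Tuple

# Matches e.g. "Answer: N2 (Ground truth: N2)"; the last match in a file wins,
# mirroring the notebook's use of split(...)[-1].
SUMMARY_PATTERN = re.compile(r"Answer: (?P<answer>.*?) \(Ground truth: (?P<truth>.*?)\)")


def score_summaries(summaries: Iterable[str]) -> Tuple[int, int]:
    """Return (correct, total) over summary texts that contain an answer line."""
    correct = 0
    total = 0
    for text in summaries:
        matches = list(SUMMARY_PATTERN.finditer(text))
        if not matches:
            continue  # empty or malformed summary: not counted
        last = matches[-1]
        total += 1
        if last["answer"] == last["truth"]:
            correct += 1
    return correct, total


correct, total = score_summaries([
    "Answer: W (Ground truth: W)",
    "",
    "Reason: low-amplitude mixed frequency\nAnswer: N1 (Ground truth: N2)",
])
print(f"Correct: {correct}, Total: {total}")
```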


@@ -1,235 +0,0 @@
"""
PPGBP Dataset Loader
A data loader/iterator for the PPG-BP dataset that reads xlsx metadata
and corresponding signal text files, with optional batching support.
"""
import os
import numpy as np
import pandas as pd
from typing import Iterator, Tuple, Dict, List, Optional, Union


class PPGBPLoader:
    """
    Data loader for the PPG-BP dataset.

    Reads metadata from xlsx file and corresponding PPG signal text files.
    Supports iteration over single samples or batches.
    Subject list is built from 0_subject/ and matched with metadata;
    train/test is 80/20 at subject level (random, fixed by seed).
    """

    COLUMN_MAPPING = {
        "Sex(M/F)": "sex",
        "Age(year)": "age",
        "Systolic Blood Pressure(mmHg)": "sysbp",
        "Diastolic Blood Pressure(mmHg)": "diasbp",
        "Heart Rate(b/m)": "hr",
        "BMI(kg/m^2)": "bmi"
    }

    def __init__(
        self,
        base_dir: str,
        split: str = 'all',
        batch_size: Optional[int] = None,
        shuffle: bool = False,
        num_segments: int = 3,
        seed: int = 42
    ):
        self.base_dir = base_dir
        self.split = split
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_segments = num_segments
        self.seed = seed
        self.rng = np.random.default_rng(seed)
        self.signal_dir = os.path.join(base_dir, "0_subject")
        self.xlsx_path = self._find_xlsx_file()
        self.metadata = self._load_metadata()
        self._all_subject_ids = self._get_unified_subject_ids()
        self._train_ids, self._test_ids = self._get_train_test_split(self._all_subject_ids)
        self.subject_ids = self._get_split_ids()
        self.metadata = self.metadata[self.metadata['subject_ID'].isin(self.subject_ids)]
        self._current_idx = 0
        self._indices = np.arange(len(self.subject_ids))
        if self.shuffle:
            self.rng.shuffle(self._indices)

    def _find_xlsx_file(self) -> str:
        for f in os.listdir(self.base_dir):
            if f.endswith('.xlsx'):
                return os.path.join(self.base_dir, f)
        raise FileNotFoundError(f"No xlsx file found in {self.base_dir}")

    def _load_metadata(self) -> pd.DataFrame:
        # Use header=1: xlsx has a title row, then row with column names (subject_ID, etc.)
        df = pd.read_excel(self.xlsx_path, header=1)
        df = df.rename(columns=self.COLUMN_MAPPING)
        df = df.fillna(0)
        return df

    def _get_unified_subject_ids(self) -> np.ndarray:
        """Subject IDs present in both 0_subject/ (from listdir) and metadata."""
        if not os.path.isdir(self.signal_dir):
            return np.array([], dtype=np.int64)
        # Files are named {subject_id}_{segment}.txt
        ids_from_files = set()
        for f in os.listdir(self.signal_dir):
            if f.endswith(".txt"):
                try:
                    sid = int(f.replace(".txt", "").split("_")[0])
                    ids_from_files.add(sid)
                except (ValueError, IndexError):
                    continue
        meta_ids = set(self.metadata["subject_ID"].astype(int).values)
        unified = sorted(ids_from_files & meta_ids)
        return np.array(unified)

    def _get_train_test_split(self, ids: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """80% train, 20% test at subject level; shuffle fixed by seed."""
        if len(ids) == 0:
            return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
        perm = self.rng.permutation(len(ids))
        ids_perm = ids[perm]
        n_train = max(1, int(round(0.8 * len(ids))))
        train_ids = ids_perm[:n_train]
        test_ids = ids_perm[n_train:]
        return train_ids, test_ids

    def _get_split_ids(self) -> np.ndarray:
        if self.split == "train":
            return self._train_ids
        elif self.split == "test":
            return self._test_ids
        elif self.split == "all":
            return self._all_subject_ids
        else:
            raise ValueError(f"Unknown split: {self.split}. Use 'train', 'test', or 'all'.")

    @property
    def train_ids(self) -> np.ndarray:
        """Subject IDs in the train split (80%)."""
        return self._train_ids

    @property
    def test_ids(self) -> np.ndarray:
        """Subject IDs in the test split (20%)."""
        return self._test_ids

    def _load_signal(self, subject_id: int) -> np.ndarray:
        segments = []
        for s in range(1, self.num_segments + 1):
            filepath = os.path.join(self.signal_dir, f"{subject_id}_{s}.txt")
            signal = pd.read_csv(filepath, sep='\t', header=None)
            signal = signal.values.squeeze()
            if len(signal) > 1:
                signal = signal[:-1]
            segments.append(signal)
        return np.array(segments, dtype=object)

    def _get_metadata_dict(self, subject_id: int) -> Dict:
        row = self.metadata[self.metadata['subject_ID'] == subject_id].iloc[0]
        return row.to_dict()

    def __len__(self) -> int:
        return len(self.subject_ids)

    def __iter__(self) -> 'PPGBPLoader':
        self._current_idx = 0
        if self.shuffle:
            self.rng.shuffle(self._indices)
        return self

    def __next__(self) -> Union[Tuple[Dict, np.ndarray], Tuple[List[Dict], List[np.ndarray]]]:
        if self._current_idx >= len(self.subject_ids):
            raise StopIteration
        if self.batch_size is None:
            idx = self._indices[self._current_idx]
            subject_id = self.subject_ids[idx]
            metadata = self._get_metadata_dict(subject_id)
            signal = self._load_signal(subject_id)
            self._current_idx += 1
            return metadata, signal
        else:
            end_idx = min(self._current_idx + self.batch_size, len(self.subject_ids))
            batch_indices = self._indices[self._current_idx:end_idx]
            metadata_list = []
            signal_list = []
            for idx in batch_indices:
                subject_id = self.subject_ids[idx]
                metadata_list.append(self._get_metadata_dict(subject_id))
                signal_list.append(self._load_signal(subject_id))
            self._current_idx = end_idx
            return metadata_list, signal_list

    def __getitem__(self, idx: int) -> Tuple[Dict, np.ndarray]:
        if idx < 0 or idx >= len(self.subject_ids):
            raise IndexError(f"Index {idx} out of range")
        subject_id = self.subject_ids[idx]
        metadata = self._get_metadata_dict(subject_id)
        signal = self._load_signal(subject_id)
        return metadata, signal

    def get_by_subject_id(self, subject_id: int) -> Tuple[Dict, np.ndarray]:
        if subject_id not in self.subject_ids:
            raise ValueError(f"Subject ID {subject_id} not found in this split.")
        metadata = self._get_metadata_dict(subject_id)
        signal = self._load_signal(subject_id)
        return metadata, signal

    def reset(self):
        self._current_idx = 0
        if self.shuffle:
            self.rng.shuffle(self._indices)

    def iter_batches(self, batch_size: int) -> Iterator[Tuple[List[Dict], List[np.ndarray]]]:
        indices = np.arange(len(self.subject_ids))
        if self.shuffle:
            self.rng.shuffle(indices)
        for start_idx in range(0, len(self.subject_ids), batch_size):
            end_idx = min(start_idx + batch_size, len(self.subject_ids))
            batch_indices = indices[start_idx:end_idx]
            metadata_list = []
            signal_list = []
            for idx in batch_indices:
                subject_id = self.subject_ids[idx]
                metadata_list.append(self._get_metadata_dict(subject_id))
                signal_list.append(self._load_signal(subject_id))
            yield metadata_list, signal_list


def get_loaders(
    base_dir: str,
    batch_size: Optional[int] = None,
    shuffle_train: bool = True,
    seed: int = 42
) -> Tuple[PPGBPLoader, PPGBPLoader]:
    """Convenience function to get train and test loaders (80/20 split)."""
    train_loader = PPGBPLoader(
        base_dir=base_dir, split="train",
        batch_size=batch_size, shuffle=shuffle_train, seed=seed
    )
    test_loader = PPGBPLoader(
        base_dir=base_dir, split="test",
        batch_size=batch_size, shuffle=False, seed=seed
    )
    return train_loader, test_loader


if __name__ == "__main__":
    print("PPGBPLoader ready to use.")


@@ -1,580 +0,0 @@
"""
Chronos-2 Time Series Embedding Extraction and Visualization Pipeline

This module provides functionality to:
    1. Extract embeddings from multivariate time series using Chronos-2 foundation model
    2. Visualize embeddings using dimensionality reduction (t-SNE)

Chronos-2 Overview:
    Chronos-2 is a time series foundation model developed by Amazon that converts
    time series forecasting into a language modeling task. It tokenizes time series
    values and generates probabilistic forecasts using a transformer architecture.

Key Features:
    - Zero-shot forecasting: Works on unseen time series without fine-tuning
    - Probabilistic predictions: Outputs quantile forecasts (e.g., 10%, 50%, 90%)
    - Multivariate support: Can process multiple channels simultaneously

Embedding Strategy:
    We use Chronos-2's internal encoder hidden states as embeddings, which is the
    recommended approach for representation learning. The encoder captures rich
    temporal patterns through self-attention mechanisms.

    "encoder" : Uses encoder hidden states directly
        - More informative representation of input characteristics
        - Captures learned temporal patterns from pre-training

Usage:
    # Extract embeddings
    python gen_plot.py extract --data_root /path/to/data --out_dir ./embeddings

    # Visualize with t-SNE
    python gen_plot.py plot --emb_dir ./embeddings --out_dir ./plots

Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
import sys
from glob import glob
from typing import Dict, Any, List, Tuple, Optional
import numpy as np
import matplotlib.pyplot as plt
import torch
from datasets import load_from_disk, Dataset
from chronos import BaseChronosPipeline, Chronos2Pipeline
from fire import Fire
from sklearn.manifold import TSNE
# Add parent directories to path for importing metadata loader
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../.."))
from preprocess.sleepedf_metadata import SleepEDFMetadata, DEFAULT_METADATA_PATH
# pipeline: Chronos2Pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map="cuda")
# =============================================================================
# Constants
# =============================================================================
# EEG channel names used in Sleep-EDF dataset
# Fpz-Cz: Frontal-Central electrode pair
# Pz-Oz: Parietal-Occipital electrode pair
EEG_CHANNEL_1 = "EEG Fpz-Cz"
EEG_CHANNEL_2 = "EEG Pz-Oz"
# =============================================================================
# Embedding Extractor Class
# =============================================================================
class Chronos_2_Embedder:
    """
    Extracts fixed-dimensional embeddings from multivariate time series using Chronos-2.

    Uses the internal encoder hidden states from Chronos-2's transformer.
    This directly accesses the model's learned features for representation learning.

    Architecture:
        Input Time Series → Patching → Encoder (6 layers) → Hidden States → Pooling → Embedding

    Processing Pipeline:
        1. Instance normalization (z-score per series)
        2. Patching (splits series into fixed-size patches)
        3. Patch embedding (linear projection to d_model dimensions)
        4. Transformer encoder (6 layers of self-attention)
        5. Output: hidden states of shape (n_variates, num_patches + 2, d_model)
           - +2 for [REG] token and masked output patch token

    Pooling Strategies:
        - mean: Average across all patch tokens (excluding special tokens)
        - cls: Use the [REG] token embedding (similar to BERT's [CLS])

    Attributes:
        pipeline: Chronos-2 model pipeline for inference
        pooling_strategy: How to pool encoder hidden states ("mean" or "cls")
    """

    def __init__(
        self,
        model_name: str = "amazon/chronos-2",
        device_map: str = None,
        pooling_strategy: str = "mean",
        variate_fusion: str = "concat",
    ):
        """
        Initialize the Chronos-2 embedder.

        Args:
            model_name: HuggingFace model name or local path
            device_map: Device placement ("cuda", "cpu", or None for auto)
            pooling_strategy: "mean" (average all patches) or "cls" (use [REG] token only)
        """
        self.pooling_strategy = pooling_strategy
        if variate_fusion not in ["concat", "mean"]:
            raise ValueError(
                f"Invalid variate_fusion: {variate_fusion}. Use 'concat' (1024) or 'mean' (512)."
            )
        self.variate_fusion = variate_fusion
        if device_map is None:
            device_map = "cuda" if torch.cuda.is_available() else "cpu"
        self.device_map = device_map
        self.device = torch.device(
            "cuda" if device_map.startswith("cuda") and torch.cuda.is_available() else "cpu"
        )
        # Load pre-trained Chronos-2 model
        print(f"[INFO] Loading Chronos-2 model: {model_name}")
        print(f"[INFO] Device: {device_map}")
        print(f"[INFO] Pooling strategy: {pooling_strategy}")
        print(f"[INFO] Variate fusion: {variate_fusion}")
        self.pipeline: Chronos2Pipeline = Chronos2Pipeline.from_pretrained(
            model_name,
            device_map=device_map
        )

    @staticmethod
    def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
        """
        Discover all user/session directories under data_root.

        Uses glob pattern matching for cleaner directory traversal.
        Expected structure: data_root/user_id/session_id/

        Returns:
            List of (user_id, session_id, session_path) tuples
        """
        discovered_paths = []
        # Use glob to find all session directories (2 levels deep)
        for session_path in sorted(glob(os.path.join(data_root, "*", "*"))):
            if not os.path.isdir(session_path):
                continue
            # Extract user_id and session_id from path
            session_id = os.path.basename(session_path)
            user_id = os.path.basename(os.path.dirname(session_path))
            discovered_paths.append((user_id, session_id, session_path))
        return discovered_paths

    @torch.no_grad()
    def compute_embedding(self, batch: Dict[str, Any]) -> np.ndarray:
        """
        Generate embedding vectors using Chronos-2's internal encoder hidden states.

        Processing Pipeline:
            1. Parse batch to extract 2-channel EEG time series
            2. Format for Chronos-2 input (B, V, L)
            3. Call pipeline.embed() to get encoder hidden states
            4. Pool hidden states to get fixed-size embedding

        Args:
            batch: HuggingFace dataset batch from slicing (dataset[start:end])
                Format: {"data": [{"EEG Fpz-Cz": [...], "EEG Pz-Oz": [...]}, ...]}

        Returns:
            Embedding array of shape (batch_size, embedding_dim)
            For Chronos-2 with d_model=512:
                - pooling (mean/cls) produces (B, 512) per variate
                - variate_fusion='concat': (B, 2*512) = (B, 1024) for 2 channels
                - variate_fusion='mean' : (B, 512) by averaging across variates
        """
        # =====================================================================
        # Step 1: Parse batch to multivariate time series array
        # =====================================================================
        samples = batch["data"]
        channel_1 = np.stack([np.asarray(s[EEG_CHANNEL_1], dtype=np.float32) for s in samples])
        channel_2 = np.stack([np.asarray(s[EEG_CHANNEL_2], dtype=np.float32) for s in samples])
        timeseries = np.stack([channel_1, channel_2], axis=-1)  # (B, 3000, 2)

        # =====================================================================
        # Step 2: Format for Chronos-2 input
        # =====================================================================
        # Chronos-2 expects (batch, n_variates, seq_length)
        x_input = np.transpose(timeseries, (0, 2, 1)).astype(np.float32)  # (B, 2, 3000)

        # =====================================================================
        # Step 3: Get encoder embeddings using pipeline.embed()
        # =====================================================================
        # embed() returns:
        # - embeddings: list of tensors, each (n_variates, num_patches + 2, d_model)
        # - loc_scale: list of tuples (loc, scale) for denormalization
        embeddings_list, loc_scale_list = self.pipeline.embed(x_input)

        # =====================================================================
        # Step 4: Pool hidden states to get fixed-size embedding
        # =====================================================================
        all_embeddings = []
        for emb in embeddings_list:
            # emb shape: (n_variates, num_patches + 2, d_model) = (2, N+2, 512)
            if self.pooling_strategy == "cls":
                # Use the [REG] token (first token) as the embedding
                # This is similar to BERT's [CLS] token approach
                pooled = emb[:, 0, :]  # (n_variates, d_model) = (2, 512)
            else:  # "mean" pooling (default)
                # Average across all patch tokens (excluding special tokens)
                # Skip first token ([REG]) and last token (masked output patch)
                pooled = emb[:, 1:-1, :].mean(dim=1)  # (n_variates, d_model) = (2, 512)
            if self.variate_fusion == "mean":
                # Fuse variates by averaging their pooled representations
                # (2, 512) -> (512,)
                pooled_flat = pooled.mean(dim=0)
            else:
                # Keep each variate's representation and concatenate
                # (2, 512) -> (1024,)
                pooled_flat = pooled.reshape(-1)
            all_embeddings.append(pooled_flat.cpu().numpy())
        return np.stack(all_embeddings, axis=0).astype(np.float32)

    def extract_embeddings(
        self,
        data_root: str,
        batch_size: int = 32,
        metadata_path: str = DEFAULT_METADATA_PATH,
    ) -> Dataset:
        """
        Extract embeddings from all sessions under the data root directory.

        Iterates through all user/session combinations, processes time series
        in batches, and aggregates results with metadata.

        Args:
            data_root: Root directory containing user/session subfolders
            batch_size: Number of samples to process together.
                Larger = faster but more memory.
                32 is a good balance for most GPUs.
            metadata_path: Path to SC-subjects.xls for gender/age info

        Returns:
            HuggingFace Dataset with columns:
                - user_id, session_id, idx, label (metadata)
                - gender, age (demographic metadata)
                - embedding (vector; dim depends on variate_fusion: 1024 for concat, 512 for mean)
        """
        session_paths = self.discover_session_paths(data_root)
        print(f"[INFO] Discovered {len(session_paths)} sessions")
        # Load metadata for gender/age information
        try:
            metadata = SleepEDFMetadata(metadata_path)
            has_metadata = True
        except FileNotFoundError:
            print(f"[WARNING] Metadata file not found: {metadata_path}")
            print(f"[WARNING] Gender/age information will not be available")
            metadata = None
            has_metadata = False

        all_embeddings = []
        all_user_ids = []
        all_session_ids = []
        all_idxs = []
        all_labels = []
        all_genders = []
        all_ages = []
        for user_id, session_id, session_path in session_paths:
            # Load HuggingFace dataset from disk
            dataset = load_from_disk(session_path)
            # shuffle dataset
            dataset = dataset.shuffle(seed=0)
            num_samples = len(dataset)
            print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")
            # Process in batches to manage memory
            for batch_start in range(0, num_samples, batch_size):
                batch_end = min(batch_start + batch_size, num_samples)
                # Slice dataset to get batch
                batch = dataset[batch_start:batch_end]
                # Compute embeddings
                embeddings = self.compute_embedding(batch)
                # Collect embeddings and metadata
                for i in range(embeddings.shape[0]):
                    user_id_str = str(batch["user_id"][i])
                    all_embeddings.append(embeddings[i].tolist())
                    all_user_ids.append(user_id_str)
                    all_session_ids.append(str(batch["session_id"][i]))
                    all_idxs.append(int(batch["idx"][i]))
                    all_labels.append(str(batch["label"][i]))
                    # Add demographic metadata
                    if has_metadata:
                        info = metadata.get_info(user_id_str)
                        if info:
                            all_genders.append(info['gender'])
                            all_ages.append(info['age'])
                        else:
                            all_genders.append('Unknown')
                            all_ages.append(-1)
                    else:
                        all_genders.append('Unknown')
                        all_ages.append(-1)

        # Create HuggingFace Dataset
        result_dataset = Dataset.from_dict({
            "user_id": all_user_ids,
            "session_id": all_session_ids,
            "idx": all_idxs,
            "label": all_labels,
            "gender": all_genders,
            "age": all_ages,
            "embedding": all_embeddings,
        })
        print(f"[INFO] Total samples: {len(result_dataset)}")
        if has_metadata:
            gender_counts = {}
            for g in all_genders:
                gender_counts[g] = gender_counts.get(g, 0) + 1
            print(f"[INFO] Gender distribution in extracted data: {gender_counts}")
        return result_dataset
def save_embeddings(
self,
dataset: Dataset,
output_dir: str
) -> None:
os.makedirs(output_dir, exist_ok=True)
dataset.save_to_disk(output_dir)
print(f"[DONE] Saved embeddings dataset: {output_dir}")
print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
@staticmethod
def load_embeddings(embedding_dir: str) -> Dataset:
dataset = load_from_disk(embedding_dir)
print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
return dataset
# =============================================================================
# Visualization Functions
# =============================================================================
def reduce_to_2d_tsne(
embeddings: np.ndarray,
perplexity: float = 30.0
) -> np.ndarray:
"""
Reduce high-dimensional embeddings to 2D using t-SNE.
t-SNE (t-distributed Stochastic Neighbor Embedding) is a nonlinear
dimensionality reduction technique that preserves local structure.
Points that are similar in high dimensions stay close in 2D.
Args:
embeddings: High-dimensional array of shape (num_samples, embedding_dim)
perplexity: t-SNE perplexity parameter (typically 5-50); scikit-learn requires
it to be smaller than num_samples.
Higher values consider more neighbors, creating smoother layouts.
Rough heuristic: perplexity ~ sqrt(num_samples)
Returns:
2D coordinates of shape (num_samples, 2)
"""
print(f"[INFO] Running t-SNE with perplexity={perplexity}...")
tsne = TSNE(
n_components=2,
random_state=0, # For reproducibility
perplexity=perplexity,
max_iter=1000, # Usually sufficient for convergence
init='random',
learning_rate='auto', # Let sklearn choose optimal learning rate
)
return tsne.fit_transform(embeddings)
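# Illustrative sketch (not part of the original module): one way to turn the
# sqrt(num_samples) heuristic from the docstring above into a concrete value.
# The helper name `suggest_perplexity` and the 5-50 clamp are assumptions made
# for this example; the upper bound of num_samples - 1 reflects scikit-learn's
# requirement that perplexity be smaller than the number of samples.

```python
import math


def suggest_perplexity(num_samples: int, low: float = 5.0, high: float = 50.0) -> float:
    """Hypothetical helper: sqrt(n) heuristic clamped to a usable range."""
    if num_samples < 2:
        raise ValueError("t-SNE needs at least 2 samples")
    # Start from the sqrt(n) heuristic and keep it inside [low, high].
    candidate = min(max(math.sqrt(num_samples), low), high)
    # Perplexity must stay strictly below the number of samples.
    return min(candidate, float(num_samples - 1))


if __name__ == "__main__":
    for n in (4, 100, 900, 10_000):
        print(n, suggest_perplexity(n))
```

# The result could be passed as the `perplexity` argument of reduce_to_2d_tsne
# after filtering, when the number of remaining samples is small.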
def create_scatter_plot(
coordinates: np.ndarray,
labels: np.ndarray,
title: str,
output_path: str
) -> None:
"""
Create and save a 2D scatter plot with categorical coloring.
Args:
coordinates: 2D array of shape (num_points, 2)
labels: Category labels for each point (string array)
title: Plot title
output_path: File path to save the plot (PDF vector format recommended)
"""
# Get unique labels for legend
unique_labels = sorted(set(labels))
# Select colormap based on number of categories
# tab10: 10 distinct colors, tab20: 20 distinct colors
colormap = plt.cm.tab10 if len(unique_labels) <= 10 else plt.cm.tab20
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Plot each category separately for proper legend
for idx, label in enumerate(unique_labels):
mask = labels == label
ax.scatter(
coordinates[mask, 0],
coordinates[mask, 1],
c=[colormap(idx % 20)],
s=15,
label=label,
alpha=0.7,
)
# Configure plot appearance
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
ax.legend(loc='best', markerscale=2)
# Save figure as vector PDF (scalable, ideal for publications)
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
# =============================================================================
# Command Line Interface
# =============================================================================
class CLI:
def extract(
self,
data_root: str,
out_dir: str,
model: str = "amazon/chronos-2",
batch_size: int = 32,
pooling: str = "mean",
variate_fusion: str = "concat",
) -> None:
"""
Extract embeddings from time series data.
Args:
data_root: Root directory containing user/session data folders
out_dir: Output directory for HuggingFace dataset
model: Chronos-2 model name or path (default: amazon/chronos-2)
batch_size: Batch size for inference (default: 32)
pooling: Pooling strategy - 'mean' or 'cls' (default: mean)
variate_fusion: How to combine channels - 'concat' (1024) or 'mean' (512)
"""
# Validate pooling argument
if pooling not in ["mean", "cls"]:
raise ValueError(f"Invalid pooling strategy: {pooling}. Use 'mean' or 'cls'.")
if variate_fusion not in ["concat", "mean"]:
raise ValueError(
f"Invalid variate_fusion: {variate_fusion}. Use 'concat' (1024) or 'mean' (512)."
)
# Initialize embedder
embedder = Chronos_2_Embedder(
model_name=model,
pooling_strategy=pooling,
variate_fusion=variate_fusion,
)
# Extract and save embeddings
dataset = embedder.extract_embeddings(data_root, batch_size)
embedder.save_embeddings(dataset, out_dir)
def plot(
self,
emb_dir: str,
out_dir: str,
perplexity: float = 30.0,
users: str = None,
num_users: int = 0,
labels: str = None,
gender: str = None,
metadata_path: str = DEFAULT_METADATA_PATH,
) -> None:
"""
Visualize embeddings with t-SNE.
Args:
emb_dir: Directory containing the HuggingFace embeddings dataset
out_dir: Output directory for visualization plots (PDF)
perplexity: t-SNE perplexity parameter (default: 30.0)
users: Comma-separated user IDs to include (e.g., '00,01,02')
num_users: Include only first N users, 0 = all (default: 0)
labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
gender: Filter by gender ('M', 'F', or None for all)
metadata_path: Path to metadata file for gender lookup (currently unused; gender is read from the dataset's 'gender' column)
"""
os.makedirs(out_dir, exist_ok=True)
# Load saved embeddings dataset
dataset = Chronos_2_Embedder.load_embeddings(emb_dir)
# Apply user filtering
if users:
user_list = [u.strip() for u in users.split(",")]
dataset = dataset.filter(lambda x: x["user_id"] in user_list)
print(f"[INFO] Filtered to users: {user_list}")
elif num_users > 0:
all_users = sorted(set(dataset["user_id"]))
selected_users = all_users[:num_users]
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
print(f"[INFO] Selected first {num_users} users: {selected_users}")
# Filter by gender
if gender:
dataset = dataset.filter(lambda x: x.get("gender", "Unknown") == gender)
print(f"[INFO] Filtered to gender: {gender}")
# Filter by sleep stage labels
if labels:
label_list = [l.strip() for l in labels.split(",")]
dataset = dataset.filter(lambda x: x["label"] in label_list)
print(f"[INFO] Filtered to labels: {label_list}")
print(f"[INFO] Total samples: {len(dataset)}")
# Print gender distribution
if "gender" in dataset.column_names:
gender_counts = {}
for g in dataset["gender"]:
gender_counts[g] = gender_counts.get(g, 0) + 1
print(f"[INFO] Gender distribution: {gender_counts}")
# Extract embeddings as numpy array for t-SNE
embeddings = np.array(dataset["embedding"])
# Reduce to 2D with t-SNE
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
# Generate visualizations
create_scatter_plot(
coordinates_2d,
np.array(dataset["label"]),
"t-SNE Visualization (Colored by Sleep Stage)",
os.path.join(out_dir, "tsne_by_label.pdf")
)
create_scatter_plot(
coordinates_2d,
np.array(dataset["user_id"]),
"t-SNE Visualization (Colored by User ID)",
os.path.join(out_dir, "tsne_by_user.pdf")
)
if __name__ == "__main__":
Fire(CLI)


@@ -1,507 +0,0 @@
"""
User Classification from Time Series Embeddings
This module evaluates how well time series embeddings capture user-specific patterns
by training a Random Forest classifier to predict user identity from embeddings.
Motivation:
If embeddings contain user-distinguishing information, a classifier should be able
to predict which user a time series belongs to. High accuracy suggests that the
embeddings capture individual characteristics.
Experimental Design:
- Task: Multi-class classification
- Model: Random Forest with ensemble of decision trees
- Captures non-linear relationships in embedding space
- Provides feature importance scores for interpretability
- Robust to overfitting through bagging and random feature selection
- Split Strategy:
1. Session-based: Train on session 1, test on session 2
2. Random: Standard train/test split
Session-based split is more challenging and realistic because:
- Tests whether user patterns are stable across different recording sessions
- Avoids data leakage from same-session samples in train and test
Random Forest Advantages over Logistic Regression:
- Handles non-linear decision boundaries
- Feature importance reveals which embedding dimensions matter most
- No assumption about data distribution
- Naturally handles multi-class classification
Output:
- Classification metrics
- Confusion matrix visualization
Usage:
python simple_user_classifier.py \\
--embeddings ./embeddings \\
--out_dir ./results
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from typing import Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_from_disk, Dataset
from fire import Fire
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
accuracy_score,
f1_score,
classification_report,
confusion_matrix,
silhouette_score,
)
from sklearn.model_selection import train_test_split
# =============================================================================
# Data Loading
# =============================================================================
def load_embeddings_with_metadata(embedding_path: str) -> Dataset:
if not os.path.isdir(embedding_path):
raise FileNotFoundError(
f"Dataset directory not found: {embedding_path}. "
"Ensure this is a valid HuggingFace dataset directory."
)
# Load HuggingFace dataset from disk
dataset = load_from_disk(embedding_path)
return dataset
# =============================================================================
# Data Splitting
# =============================================================================
def split_by_session(
features: np.ndarray,
labels: np.ndarray,
session_ids: np.ndarray,
train_session: str = "1",
test_session: str = "2",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
"""
Split data by recording session for temporal generalization evaluation.
This split strategy tests whether learned patterns generalize across time.
Training on session 1 and testing on session 2 simulates real-world deployment
where models must work on future recordings.
Args:
features: Feature matrix of shape (num_samples, num_features)
labels: Label array of shape (num_samples,)
session_ids: Session identifier for each sample
train_session: Session ID to use for training (default: "1")
test_session: Session ID to use for testing (default: "2")
Returns:
Tuple of (X_train, X_test, y_train, y_test, split_description)
"""
# Create boolean masks for train and test sets
train_mask = session_ids == train_session
test_mask = session_ids == test_session
# Apply masks to create train/test splits
X_train = features[train_mask]
X_test = features[test_mask]
y_train = labels[train_mask]
y_test = labels[test_mask]
split_description = f"session({train_session}->train, {test_session}->test)"
return X_train, X_test, y_train, y_test, split_description
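# Illustrative sketch (not part of the original module): with a session-based
# split, a user who only appears in the test session can never be predicted
# correctly. The helper below, whose name `unseen_test_users` is an assumption
# for this example, lists such users so they can be reported or excluded.

```python
from typing import Sequence, Set


def unseen_test_users(
    user_ids: Sequence[str],
    session_ids: Sequence[str],
    train_session: str = "1",
    test_session: str = "2",
) -> Set[str]:
    """Hypothetical check: users in the test session that are absent from training."""
    train_users = {u for u, s in zip(user_ids, session_ids) if s == train_session}
    test_users = {u for u, s in zip(user_ids, session_ids) if s == test_session}
    return test_users - train_users


if __name__ == "__main__":
    users = ["00", "00", "01", "02"]
    sessions = ["1", "2", "1", "2"]
    # User "02" has no session-1 samples, so the classifier never sees it.
    print(unseen_test_users(users, sessions))
```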
def split_random(
features: np.ndarray,
labels: np.ndarray,
test_size: float = 0.2,
random_state: int = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
"""
Split data randomly with stratification for in-distribution evaluation.
Stratification keeps each class's share of samples approximately equal in
the train and test sets, so the test set reflects the overall class distribution.
Args:
features: Feature matrix of shape (num_samples, num_features)
labels: Label array of shape (num_samples,)
test_size: Fraction of data to use for testing (default: 0.2)
random_state: Random seed for reproducibility
Returns:
Tuple of (X_train, X_test, y_train, y_test, split_description)
"""
X_train, X_test, y_train, y_test = train_test_split(
features,
labels,
test_size=test_size,
random_state=random_state,
stratify=labels,
)
split_description = "random"
return X_train, X_test, y_train, y_test, split_description
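# Illustrative sketch (not part of the original module): what "stratified"
# means in practice, written with the standard library only. This is NOT
# scikit-learn's exact algorithm; the function name `stratified_indices` and
# the per-class rounding rule are assumptions made for this example.

```python
import random
from collections import defaultdict
from typing import List, Sequence, Tuple


def stratified_indices(
    labels: Sequence[str], test_size: float = 0.2, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: split indices so each class keeps its proportion."""
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    rng = random.Random(seed)
    train, test = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Take the same fraction from every class (at least one test sample).
        n_test = max(1, round(len(indices) * test_size))
        test.extend(indices[:n_test])
        train.extend(indices[n_test:])
    return sorted(train), sorted(test)


if __name__ == "__main__":
    demo_labels = ["a"] * 10 + ["b"] * 5
    train_idx, test_idx = stratified_indices(demo_labels, test_size=0.2)
    print(len(train_idx), len(test_idx))
```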
# =============================================================================
# Model Training and Evaluation
# =============================================================================
def create_classifier_pipeline(
n_estimators: int = 200,
max_depth: int = None,
min_samples_split: int = 2,
min_samples_leaf: int = 1,
random_state: int = 0,
) -> Pipeline:
"""
Create a scikit-learn pipeline for user classification using Random Forest.
Pipeline Architecture:
----------------------
1. StandardScaler: Z-score normalization of features
- Centers features (mean=0) and scales to unit variance (std=1)
- Random Forest splits are invariant to per-feature scaling, so this step
does not change predictions; it only matters if a scale-sensitive classifier is swapped in
2. RandomForestClassifier: Ensemble of decision trees
- Builds multiple decision trees on random subsets of data (bagging)
- Each tree uses random subset of features at each split
- Final prediction is majority vote across all trees
- Provides feature_importances_ for interpretability
Random Forest Hyperparameters:
------------------------------
- n_estimators: Number of trees in the forest (more trees give more stable predictions with diminishing returns, at higher compute cost)
- max_depth: Maximum tree depth (None = expand until pure leaves)
- min_samples_split: Minimum samples to split internal node
- min_samples_leaf: Minimum samples required at leaf node
Args:
n_estimators: Number of trees (default: 200)
max_depth: Maximum depth of trees (default: None, fully grown)
min_samples_split: Min samples for splitting (default: 2)
min_samples_leaf: Min samples at leaf (default: 1)
random_state: Random seed for reproducibility
Returns:
Configured sklearn Pipeline ready for .fit() and .predict()
"""
pipeline = Pipeline([
# Step 1: Feature normalization
("scaler", StandardScaler(
with_mean=True, # Subtract mean (center the data)
with_std=True, # Divide by standard deviation (scale to unit variance)
)),
# Step 2: Random Forest classification
("classifier", RandomForestClassifier(
n_estimators=n_estimators, # Number of trees in the forest
max_depth=max_depth, # Maximum depth (None = unlimited)
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
n_jobs=-1, # Use all CPU cores for parallel tree building
random_state=random_state, # For reproducibility
class_weight="balanced", # Handle class imbalance automatically
oob_score=True, # Enable out-of-bag error estimation
)),
])
return pipeline
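# Illustrative sketch (not part of the original module): the weights implied by
# class_weight="balanced" follow n_samples / (n_classes * class_count), so
# users with fewer samples get larger weights. The helper name
# `balanced_class_weights` is an assumption made for this example.

```python
from collections import Counter
from typing import Dict, Sequence


def balanced_class_weights(labels: Sequence[str]) -> Dict[str, float]:
    """Hypothetical helper: per-class weights under the 'balanced' heuristic."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    # Rare classes receive proportionally larger weights.
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}


if __name__ == "__main__":
    user_labels = ["u1"] * 6 + ["u2"] * 2
    print(balanced_class_weights(user_labels))
```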
def evaluate_classifier(
y_true: np.ndarray,
y_pred: np.ndarray,
) -> Tuple[float, float, str, np.ndarray, list]:
"""
Compute classification metrics and confusion matrix.
Metrics Computed:
- Accuracy: Overall fraction of correct predictions
- Macro F1: Unweighted mean of per-class F1 scores (every user counts equally)
- Per-class report: Precision, recall, F1 for each user
- Confusion matrix: Detailed breakdown of predictions vs ground truth
Args:
y_true: Ground truth labels
y_pred: Predicted labels
Returns:
Tuple of (accuracy, f1, classification_report, confusion_matrix, class_labels)
"""
# Compute scalar metrics
accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='macro') # Multiclass: use macro-averaged F1
# Generate detailed per-class report
report = classification_report(y_true, y_pred, digits=4)
# Compute confusion matrix: Get all unique classes from both true and predicted labels
class_labels = sorted(set(y_true) | set(y_pred))
cm = confusion_matrix(y_true, y_pred, labels=class_labels)
return accuracy, f1, report, cm, class_labels
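# Illustrative sketch (not part of the original module): macro F1 computed by
# hand, to make the metric above concrete. Per-class F1 is
# 2*TP / (2*TP + FP + FN), and the macro score is the unweighted mean over all
# classes seen in either array. The function names are assumptions made for
# this example; the module itself relies on sklearn.metrics.f1_score.

```python
from typing import Dict, Sequence


def per_class_f1(y_true: Sequence[str], y_pred: Sequence[str]) -> Dict[str, float]:
    """Hypothetical helper: F1 score for every class in y_true or y_pred."""
    scores = {}
    for cls in sorted(set(y_true) | set(y_pred)):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p != cls)
        denominator = 2 * tp + fp + fn
        # A class with no true or predicted samples scores 0.
        scores[cls] = (2 * tp / denominator) if denominator else 0.0
    return scores


def macro_f1(y_true: Sequence[str], y_pred: Sequence[str]) -> float:
    """Unweighted mean of the per-class F1 scores."""
    scores = per_class_f1(y_true, y_pred)
    return sum(scores.values()) / len(scores)


if __name__ == "__main__":
    truth = ["a", "a", "b", "b", "b"]
    predicted = ["a", "b", "b", "b", "a"]
    print(per_class_f1(truth, predicted), macro_f1(truth, predicted))
```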
# =============================================================================
# Visualization
# =============================================================================
def save_confusion_matrix_plot(
confusion_mat: np.ndarray,
class_labels: list,
output_path: str,
) -> None:
"""
Create and save a confusion matrix heatmap visualization.
The confusion matrix shows:
- Rows: True class labels
- Columns: Predicted class labels
- Cell values: Count of samples with that (true, predicted) combination
- Diagonal: Correct predictions
- Off-diagonal: Misclassifications
Args:
confusion_mat: Square matrix of shape (num_classes, num_classes)
class_labels: List of class names for axis labels
output_path: File path to save the plot (PDF format recommended)
"""
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Plot heatmap using imshow
im = ax.imshow(confusion_mat, aspect="auto", cmap="Blues")
# Add colorbar to show value scale
plt.colorbar(im, ax=ax)
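# Set labels and title
ax.set_title("Confusion Matrix (User Classification)")
ax.set_xlabel("Predicted User")
ax.set_ylabel("True User")
# Label ticks with the class names so rows and columns map to users
tick_positions = np.arange(len(class_labels))
ax.set_xticks(tick_positions)
ax.set_xticklabels(class_labels, rotation=90, fontsize=8)
ax.set_yticks(tick_positions)
ax.set_yticklabels(class_labels, fontsize=8)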
plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.close()
print(f"[DONE] Saved confusion matrix: {output_path}")
def save_metrics_report(
output_path: str,
split_description: str,
train_size: int,
test_size: int,
accuracy: float,
macro_f1: float,
classification_report_str: str,
oob_score: float = None,
silhouette: float = None,
) -> None:
with open(output_path, "w", encoding="utf-8") as f:
# Write summary statistics
f.write(f"split_used : {split_description}\n")
f.write(f"train_size : {train_size}\n")
f.write(f"test_size : {test_size}\n")
# Silhouette score (embedding quality metric)
if silhouette is not None:
f.write(f"silhouette : {silhouette:.4f}\n")
f.write(f"accuracy : {accuracy:.4f}\n")
f.write(f"macro_f1 : {macro_f1:.4f}\n")
# Add OOB score if available (Random Forest specific)
if oob_score is not None:
f.write(f"oob_score : {oob_score:.4f}\n")
f.write("\n")
# Write detailed per-class report
f.write(classification_report_str)
print(f"[DONE] Saved metrics report: {output_path}")
def save_feature_importance_plot(
feature_importances: np.ndarray,
output_path: str,
top_k: int = 50,
) -> None:
# Get indices of top-k most important features
top_indices = np.argsort(feature_importances)[::-1][:top_k]
top_importances = feature_importances[top_indices]
# Create figure
fig, ax = plt.subplots(figsize=(12, 8))
# Create horizontal bar chart
y_positions = np.arange(len(top_indices))
ax.barh(y_positions, top_importances, color="steelblue", alpha=0.8)
# Set labels
ax.set_yticks(y_positions)
ax.set_yticklabels([f"emb_{i}" for i in top_indices], fontsize=8)
ax.invert_yaxis() # Highest importance at top
ax.set_xlabel("Feature Importance (Mean Decrease in Impurity)")
ax.set_ylabel("Embedding Dimension")
ax.set_title(f"Top {top_k} Most Important Embedding Dimensions")
# Save as vector PDF
plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.close()
print(f"[DONE] Saved feature importance plot: {output_path}")
def save_feature_importance_csv(
feature_importances: np.ndarray,
output_path: str,
) -> None:
importance_df = pd.DataFrame({
"feature": [f"emb_{i}" for i in range(len(feature_importances))],
"importance": feature_importances,
})
importance_df = importance_df.sort_values("importance", ascending=False)
importance_df = importance_df.reset_index(drop=True)
importance_df["rank"] = range(1, len(importance_df) + 1)
importance_df = importance_df[["rank", "feature", "importance"]]
importance_df.to_csv(output_path, index=False)
print(f"[DONE] Saved feature importance CSV: {output_path}")
# =============================================================================
# Command Line Interface
# =============================================================================
def main(
embeddings: str,
out_dir: str,
split_mode: str = "session",
test_size: float = 0.2,
n_estimators: int = 200,
max_depth: int = None,
labels: str = None,
) -> None:
# Validate split_mode argument
if split_mode not in ["session", "random"]:
raise ValueError(f"Invalid split_mode: {split_mode}. Use 'session' or 'random'.")
# Create output directory
os.makedirs(out_dir, exist_ok=True)
print(f"[INFO] Loading embeddings from: {embeddings}")
dataset = load_embeddings_with_metadata(embeddings)
# Filter by sleep stage labels if specified
if labels:
label_list = [l.strip() for l in labels.split(",")]
dataset = dataset.filter(lambda x: str(x["label"]) in label_list)
print(f"[INFO] Filtered to labels: {label_list}")
X = np.array(dataset["embedding"], dtype=np.float32)
y = np.array([str(uid) for uid in dataset["user_id"]])
session_ids = np.array([str(sid) for sid in dataset["session_id"]])
print(f"[INFO] Loaded {len(y)} samples with {X.shape[1]}-dimensional embeddings")
print(f"[INFO] Number of unique users: {len(np.unique(y))}")
if split_mode == "session":
has_session_1 = (session_ids == "1").sum() > 0
has_session_2 = (session_ids == "2").sum() > 0
if has_session_1 and has_session_2:
X_train, X_test, y_train, y_test, split_desc = split_by_session(
X, y, session_ids
)
else:
print("[WARN] Missing session data, falling back to random split")
X_train, X_test, y_train, y_test, split_desc = split_random(
X, y, test_size
)
split_desc = "random(fallback)"
else:
X_train, X_test, y_train, y_test, split_desc = split_random(
X, y, test_size
)
print(f"[INFO] Split: {split_desc}")
print(f"[INFO] Train size: {len(y_train)}, Test size: {len(y_test)}")
silhouette_avg = silhouette_score(X, y)
print(f"[INFO] Silhouette Score (user clusters): {silhouette_avg:.4f}")
print("[INFO] Training Random Forest classifier...")
print(f"[INFO] Hyperparameters: n_estimators={n_estimators}, max_depth={max_depth}")
classifier = create_classifier_pipeline(
n_estimators=n_estimators,
max_depth=max_depth,
)
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)
rf_model = classifier.named_steps["classifier"]
oob_score = rf_model.oob_score_ if hasattr(rf_model, "oob_score_") else None
feature_importances = rf_model.feature_importances_
print("[INFO] Evaluating classifier performance...")
accuracy, macro_f1, report, cm, classes = evaluate_classifier(y_test, y_pred)
metrics_path = os.path.join(out_dir, "user_cls_metrics.txt")
save_metrics_report(
output_path=metrics_path,
split_description=split_desc,
train_size=len(y_train),
test_size=len(y_test),
accuracy=accuracy,
macro_f1=macro_f1,
classification_report_str=report,
oob_score=oob_score,
silhouette=silhouette_avg,
)
confusion_path = os.path.join(out_dir, "user_cls_confusion.pdf")
save_confusion_matrix_plot(cm, classes, confusion_path)
importance_plot_path = os.path.join(out_dir, "feature_importance.pdf")
save_feature_importance_plot(feature_importances, importance_plot_path)
importance_csv_path = os.path.join(out_dir, "feature_importance.csv")
save_feature_importance_csv(feature_importances, importance_csv_path)
print("\n" + "=" * 50)
print("RANDOM FOREST CLASSIFICATION RESULTS")
print("=" * 50)
print(f"Split Strategy : {split_desc}")
print(f"Silhouette : {silhouette_avg:.4f}")
print(f"Accuracy : {accuracy:.4f}")
print(f"Macro F1 : {macro_f1:.4f}")
if oob_score is not None:
print(f"OOB Score : {oob_score:.4f}")
print("Top 5 Important Features:")
top5_idx = np.argsort(feature_importances)[::-1][:5]
for rank, idx in enumerate(top5_idx, 1):
print(f" {rank}. emb_{idx}: {feature_importances[idx]:.4f}")
print("=" * 50)
if __name__ == "__main__":
Fire(main)


@@ -1,982 +0,0 @@
"""
Chronos-2 PPG-BP Embedding Extraction and Visualization Pipeline
This module provides functionality to:
1. Extract embeddings from univariate PPG signals (PPG-BP dataset) using Chronos-2
2. Visualize embeddings using t-SNE colored by various metadata categories
Chronos-2 Overview:
Chronos-2 is a time series foundation model developed by Amazon. It splits the
input series into patches, encodes them with a transformer encoder, and
produces probabilistic (quantile) forecasts.
Embedding Strategy:
We use Chronos-2's internal encoder hidden states as embeddings. For univariate
PPG signals, each signal is treated as a single-variate time series, producing
a 512-dimensional embedding after pooling.
Processing Pipeline:
1. Load PPG signals via PPGBPLoader (3 segments per subject)
2. Each segment is embedded independently through Chronos-2 encoder
3. Encoder hidden states are pooled to produce a fixed-size embedding
4. Metadata from the dataset xlsx is attached for visualization
Usage:
# Extract embeddings
python gen_plot.py extract --base_dir /path/to/ppgbp/DataFile --out_dir ./embeddings
# Visualize with t-SNE
python gen_plot.py plot --emb_dir ./embeddings --out_dir ./plots
# LazyPredict: predict hypertension (or other target) from Chronos embeddings
python gen_plot.py lazy_eval --emb_dir ./embeddings --target hypertension --out_csv ./lazy_hypertension.csv
"""
import os
import sys
from typing import Dict, Any, List, Tuple, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from datasets import Dataset, load_from_disk
from chronos import Chronos2Pipeline
from fire import Fire
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder
try:
from lazypredict.Supervised import LazyClassifier, LazyRegressor
HAS_LAZYPREDICT = True
except ImportError:
HAS_LAZYPREDICT = False
# Add parent directories to path for importing PPGBPLoader
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from ppgbp_loader import PPGBPLoader
# Target column options for lazy_eval: classification vs regression
LAZY_EVAL_CLASSIFICATION_TARGETS = [
"hypertension",
"sex",
"age_category",
"height_category",
"weight_category",
"systolic_category",
"diastolic_category",
"heart_rate_category",
"bmi_category",
]
LAZY_EVAL_REGRESSION_TARGETS = [
"age",
"height",
"weight",
"systolic",
"diastolic",
"heart_rate",
"bmi",
]
LAZY_EVAL_ALL_TARGETS = LAZY_EVAL_CLASSIFICATION_TARGETS + LAZY_EVAL_REGRESSION_TARGETS
# =============================================================================
# Metadata Categorization Helpers
# =============================================================================
def _categorize_age(age: Optional[int]) -> Optional[str]:
if age is None:
return None
if age <= 35:
return "Young (21-35)"
elif age <= 55:
return "Middle (36-55)"
elif age <= 75:
return "Senior (56-75)"
else:
return "Old (76+)"
def _categorize_height(height: Optional[int]) -> Optional[str]:
if height is None:
return None
if height <= 157:
return "Short (145-157)"
elif height <= 170:
return "Below Avg (158-170)"
elif height <= 183:
return "Above Avg (171-183)"
else:
return "Tall (184+)"
def _categorize_weight(weight: Optional[int]) -> Optional[str]:
if weight is None:
return None
if weight <= 52:
return "Light (36-52)"
elif weight <= 69:
return "Below Avg (53-69)"
elif weight <= 86:
return "Above Avg (70-86)"
else:
return "Heavy (87+)"
def _categorize_systolic(sbp: Optional[int]) -> Optional[str]:
if sbp is None:
return None
if sbp < 90:
return "Low (<90)"
elif sbp < 120:
return "Normal (90-119)"
elif sbp < 140:
return "Elevated (120-139)"
else:
return "High (140+)"
def _categorize_diastolic(dbp: Optional[int]) -> Optional[str]:
if dbp is None:
return None
if dbp < 60:
return "Low (<60)"
elif dbp < 80:
return "Normal (60-79)"
elif dbp < 90:
return "Elevated (80-89)"
else:
return "High (90+)"
def _categorize_heart_rate(hr: Optional[int]) -> Optional[str]:
if hr is None:
return None
if hr <= 65:
return "Low (52-65)"
elif hr <= 79:
return "Normal (66-79)"
elif hr <= 93:
return "Elevated (80-93)"
else:
return "High (94+)"
def _categorize_bmi(bmi: Optional[float]) -> Optional[str]:
if bmi is None:
return None
if bmi < 18.5:
return "Underweight (<18.5)"
elif bmi < 25:
return "Normal (18.5-24.9)"
elif bmi < 30:
return "Overweight (25-29.9)"
else:
return "Obese (30+)"
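# Illustrative sketch (not part of the original module): the if/elif chains
# above could also be expressed as a table lookup with `bisect`. The names
# `BMI_EDGES`, `BMI_LABELS`, and `categorize` are assumptions made for this
# example. bisect_right reproduces the strict "<" comparisons used for BMI and
# blood pressure; the "<=" helpers (age, height, weight, heart rate) would
# need bisect_left instead.

```python
from bisect import bisect_right
from typing import Optional, Sequence

BMI_EDGES = (18.5, 25.0, 30.0)
BMI_LABELS = (
    "Underweight (<18.5)",
    "Normal (18.5-24.9)",
    "Overweight (25-29.9)",
    "Obese (30+)",
)


def categorize(
    value: Optional[float], edges: Sequence[float], labels: Sequence[str]
) -> Optional[str]:
    """Hypothetical helper: map a value to the label of the bin it falls in."""
    if value is None:
        return None
    # bisect_right counts the edges that are <= value, which is the bin index.
    return labels[bisect_right(edges, value)]


if __name__ == "__main__":
    for bmi in (17.0, 18.5, 24.9, 25.0, 31.2, None):
        print(bmi, categorize(bmi, BMI_EDGES, BMI_LABELS))
```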
def _safe_int(val, sentinel=0):
"""Convert a value to int, returning None if it equals the sentinel or is invalid."""
if val is None or (isinstance(val, float) and pd.isna(val)):
return None
try:
v = int(val)
return None if v == sentinel else v
except (ValueError, TypeError):
return None
def _safe_float(val, sentinel=0.0):
"""Convert a value to float, returning None if it equals the sentinel or is invalid."""
if val is None or (isinstance(val, float) and pd.isna(val)):
return None
try:
v = float(val)
return None if v == sentinel else v
except (ValueError, TypeError):
return None
def _extract_metadata_fields(metadata: Dict) -> Dict[str, Any]:
"""
Extract and sanitize metadata fields from a PPGBPLoader metadata dict.
PPGBPLoader renames some columns (sex, age, sysbp, diasbp, hr, bmi) but
leaves others with original names (Height(cm), Weight(kg), Hypertension).
fillna(0) is applied, so 0 may represent missing values for some fields.
Returns a dict with standardized keys.
"""
# Sex: "M"/"F" string, or 0 if missing (from fillna)
sex = metadata.get("sex", None)
if sex is not None and sex != 0 and str(sex).strip() in ("M", "F"):
sex = str(sex).strip()
else:
sex = None
age = _safe_int(metadata.get("age", None))
height = _safe_int(metadata.get("Height(cm)", None))
weight = _safe_int(metadata.get("Weight(kg)", None))
sysbp = _safe_int(metadata.get("sysbp", None))
diasbp = _safe_int(metadata.get("diasbp", None))
hr = _safe_int(metadata.get("hr", None))
bmi = _safe_float(metadata.get("bmi", None))
# Hypertension: may be string (e.g. "Stage 2 hypertension") or numeric; keep as string
hypertension = metadata.get("Hypertension", None)
if hypertension is not None and not (isinstance(hypertension, float) and pd.isna(hypertension)):
hypertension = str(hypertension).strip()
else:
hypertension = None
return {
"sex": sex,
"age": age,
"height": height,
"weight": weight,
"sysbp": sysbp,
"diasbp": diasbp,
"hr": hr,
"bmi": bmi,
"hypertension": hypertension,
"age_category": _categorize_age(age),
"height_category": _categorize_height(height),
"weight_category": _categorize_weight(weight),
"systolic_category": _categorize_systolic(sysbp),
"diastolic_category": _categorize_diastolic(diasbp),
"heart_rate_category": _categorize_heart_rate(hr),
"bmi_category": _categorize_bmi(bmi),
}
# =============================================================================
# Embedding Extractor Class
# =============================================================================
class Chronos_2_Embedder:
"""
Extracts fixed-dimensional embeddings from univariate PPG signals using Chronos-2.
Architecture:
Input PPG Signal -> Patching -> Encoder (6 layers) -> Hidden States -> Pooling -> Embedding
For univariate PPG:
Input shape: (1, 1, L) -- batch=1, variates=1, length=L
Output shape: (512,) -- d_model=512 after pooling
Pooling Strategies:
- mean: Average across all patch tokens (excluding special tokens)
- cls: Use the [REG] token embedding (similar to BERT's [CLS])
"""
def __init__(
self,
model_name: str = "amazon/chronos-2",
device_map: str = None,
pooling_strategy: str = "mean",
):
self.pooling_strategy = pooling_strategy
if device_map is None:
device_map = "cuda" if torch.cuda.is_available() else "cpu"
self.device_map = device_map
self.device = torch.device(
"cuda" if device_map.startswith("cuda") and torch.cuda.is_available() else "cpu"
)
print(f"[INFO] Loading Chronos-2 model: {model_name}")
print(f"[INFO] Device: {device_map}")
print(f"[INFO] Pooling strategy: {pooling_strategy}")
self.pipeline: Chronos2Pipeline = Chronos2Pipeline.from_pretrained(
model_name,
device_map=device_map,
)
@torch.no_grad()
def compute_embedding(self, signal: np.ndarray) -> np.ndarray:
"""
Compute embedding for a single univariate PPG signal.
Args:
signal: 1D numpy array of PPG signal values (variable length OK)
Returns:
Embedding vector of shape (d_model,) = (512,)
"""
# Reshape to (B=1, V=1, L) for Chronos-2 input
x_input = signal.reshape(1, 1, -1).astype(np.float32)
# Get encoder hidden states
# Returns: list of tensors (one per batch item), each (V, num_patches+2, d_model)
embeddings_list, _ = self.pipeline.embed(x_input)
emb = embeddings_list[0] # (1, num_patches+2, 512) for V=1
if self.pooling_strategy == "cls":
# Use the [REG] token (first token)
pooled = emb[0, 0, :] # (d_model,)
else: # "mean" pooling (default)
# Average across all patch tokens, skip [REG] (first) and masked patch (last)
pooled = emb[0, 1:-1, :].mean(dim=0) # (d_model,)
return pooled.cpu().numpy().astype(np.float32)
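# Illustrative sketch (not part of the original module): the two pooling
# strategies on plain Python lists, assuming the token layout described in the
# comments above (first token = [REG], last token = masked patch). The function
# name `pool_tokens` is an assumption made for this example; the real method
# operates on torch tensors returned by the pipeline.

```python
from typing import List, Sequence


def pool_tokens(tokens: Sequence[Sequence[float]], strategy: str = "mean") -> List[float]:
    """Hypothetical helper: pool a (num_tokens, d_model) sequence into one vector."""
    if strategy == "cls":
        # Use the first token as the summary vector.
        return list(tokens[0])
    if strategy != "mean":
        raise ValueError(f"Unknown pooling strategy: {strategy}")
    # Drop the first and last tokens, then average the remaining patch tokens.
    patch_tokens = tokens[1:-1]
    if not patch_tokens:
        raise ValueError("Need at least one patch token for mean pooling")
    dim = len(patch_tokens[0])
    return [sum(tok[d] for tok in patch_tokens) / len(patch_tokens) for d in range(dim)]


if __name__ == "__main__":
    demo = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0], [7.0, 7.0]]
    print(pool_tokens(demo, "mean"), pool_tokens(demo, "cls"))
```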
def extract_embeddings(
self,
base_dir: str,
split: str = "all",
) -> Dataset:
"""
Extract embeddings from all subjects/segments in the PPG-BP dataset.
Each subject has 3 PPG segments. Each segment is embedded independently
via Chronos-2, resulting in 3 data points per subject sharing the same metadata.
Args:
base_dir: Root directory of PPG-BP dataset (containing 0_subject/ and .xlsx)
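split: Data split ('train', 'val', 'test', 'all'). Currently ignored: the loader is always created with split="all" and each row is tagged 'train' or 'test'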
Returns:
HuggingFace Dataset with columns:
- subject_id, segment_id (identifiers)
- sex, age, height, weight, systolic, diastolic, heart_rate, bmi, hypertension
- age_category, height_category, weight_category, systolic_category, etc.
- embedding (512-dim vector)
"""
loader = PPGBPLoader(base_dir=base_dir, split="all")
n_subjects = len(loader)
train_id_set = set(int(x) for x in loader.train_ids)
test_id_set = set(int(x) for x in loader.test_ids)
print(f"[INFO] Loaded {n_subjects} subjects (train={len(train_id_set)}, test={len(test_id_set)})")
# Accumulator lists
all_embeddings = []
all_subject_ids = []
all_segment_ids = []
all_splits = []
all_sexes = []
all_ages = []
all_heights = []
all_weights = []
all_sysbps = []
all_diasbps = []
all_hrs = []
all_bmis = []
all_hypertensions = []
all_age_categories = []
all_height_categories = []
all_weight_categories = []
all_systolic_categories = []
all_diastolic_categories = []
all_hr_categories = []
all_bmi_categories = []
total_segments = 0
for subj_idx in range(n_subjects):
metadata, signals = loader[subj_idx]
subject_id_int = int(metadata.get("subject_ID", 0))
subject_id = str(subject_id_int)
split_label = "train" if subject_id_int in train_id_set else "test"
# Extract and categorize metadata once per subject
fields = _extract_metadata_fields(metadata)
# Process each PPG segment
for seg_idx, signal in enumerate(signals):
emb = self.compute_embedding(signal)
all_embeddings.append(emb.tolist())
all_subject_ids.append(subject_id)
all_segment_ids.append(seg_idx + 1)
all_splits.append(split_label)
all_sexes.append(fields["sex"])
all_ages.append(fields["age"])
all_heights.append(fields["height"])
all_weights.append(fields["weight"])
all_sysbps.append(fields["sysbp"])
all_diasbps.append(fields["diasbp"])
all_hrs.append(fields["hr"])
all_bmis.append(fields["bmi"])
all_hypertensions.append(fields["hypertension"])
all_age_categories.append(fields["age_category"])
all_height_categories.append(fields["height_category"])
all_weight_categories.append(fields["weight_category"])
all_systolic_categories.append(fields["systolic_category"])
all_diastolic_categories.append(fields["diastolic_category"])
all_hr_categories.append(fields["heart_rate_category"])
all_bmi_categories.append(fields["bmi_category"])
total_segments += 1
if (subj_idx + 1) % 20 == 0 or subj_idx == n_subjects - 1:
print(
f"[INFO] Processed {subj_idx + 1}/{n_subjects} subjects "
f"({total_segments} segments)"
)
# Create HuggingFace Dataset
result_dataset = Dataset.from_dict({
"subject_id": all_subject_ids,
"segment_id": all_segment_ids,
"split": all_splits,
"sex": all_sexes,
"age": all_ages,
"age_category": all_age_categories,
"height": all_heights,
"height_category": all_height_categories,
"weight": all_weights,
"weight_category": all_weight_categories,
"systolic": all_sysbps,
"systolic_category": all_systolic_categories,
"diastolic": all_diasbps,
"diastolic_category": all_diastolic_categories,
"heart_rate": all_hrs,
"heart_rate_category": all_hr_categories,
"bmi": all_bmis,
"bmi_category": all_bmi_categories,
"hypertension": all_hypertensions,
"embedding": all_embeddings,
})
print(f"[INFO] Total segments: {len(result_dataset)}")
if len(result_dataset) > 0:
print(f"[INFO] Embedding dim: {len(result_dataset[0]['embedding'])}")
return result_dataset
def save_embeddings(self, dataset: Dataset, output_dir: str) -> None:
"""Save embeddings dataset to disk in HuggingFace format."""
os.makedirs(output_dir, exist_ok=True)
dataset.save_to_disk(output_dir)
print(f"[DONE] Saved embeddings dataset: {output_dir}")
if len(dataset) > 0:
print(
f"[DONE] Total samples: {len(dataset)}, "
f"Embedding dim: {len(dataset[0]['embedding'])}"
)
@staticmethod
def load_embeddings(embedding_dir: str) -> Dataset:
"""Load saved embeddings dataset from disk."""
dataset = load_from_disk(embedding_dir)
print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
return dataset
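The two pooling modes in `compute_embedding` can be illustrated without loading the model. This is a minimal sketch on plain Python lists, assuming the token layout described in the code comments above (first position is the [REG] token, last position is the masked patch). `pool_tokens` is a hypothetical helper name used only for illustration; it is not part of the pipeline.

```python
from typing import List, Sequence


def pool_tokens(tokens: Sequence[Sequence[float]], strategy: str = "mean") -> List[float]:
    """Pool a (num_tokens, d_model) sequence into a single d_model vector.

    Assumed layout (taken from the comments in compute_embedding, not verified
    against the model): tokens[0] is [REG], tokens[-1] is the masked patch.
    """
    if strategy == "cls":
        # "cls" pooling: take the first token as-is.
        return list(tokens[0])
    # "mean" pooling: average only the patch tokens between first and last.
    inner = tokens[1:-1]
    if not inner:
        raise ValueError("need at least one patch token between the first and last positions")
    dim = len(inner[0])
    return [sum(tok[d] for tok in inner) / len(inner) for d in range(dim)]


# Toy input: 4 tokens of dimension 2.
toy_tokens = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0], [7.0, 7.0]]
mean_vec = pool_tokens(toy_tokens, "mean")
cls_vec = pool_tokens(toy_tokens, "cls")
```

Mean pooling here ignores the first and last positions entirely, so only the two middle tokens contribute to `mean_vec`.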
# =============================================================================
# Visualization Functions
# =============================================================================
def reduce_to_2d_tsne(
embeddings: np.ndarray,
perplexity: float = 30.0,
) -> np.ndarray:
"""
Reduce high-dimensional embeddings to 2D using t-SNE.
Args:
embeddings: Array of shape (num_samples, embedding_dim)
perplexity: t-SNE perplexity (typically 5-50; must be smaller than num_samples). Rule of thumb: ~ sqrt(N)
Returns:
2D coordinates of shape (num_samples, 2)
"""
print(f"[INFO] Running t-SNE with perplexity={perplexity}...")
tsne = TSNE(
n_components=2,
random_state=0,
perplexity=perplexity,
max_iter=1000,
init="random",
learning_rate="auto",
)
return tsne.fit_transform(embeddings)
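Recent scikit-learn versions reject a perplexity that is not smaller than the number of samples, so a small split (for example, after subject filtering) may fail with the default of 30.0. The sketch below shows one way to clamp the value before calling `reduce_to_2d_tsne`. `clamp_perplexity` is a hypothetical helper, and capping at `n_samples - 1` is an assumption chosen for simplicity, not a tuned setting.

```python
def clamp_perplexity(n_samples: int, requested: float = 30.0) -> float:
    """Return a perplexity strictly below n_samples (hypothetical helper)."""
    if n_samples < 2:
        raise ValueError("t-SNE needs at least 2 samples")
    # Keep the requested value when it is valid, otherwise cap it.
    return float(min(requested, n_samples - 1))


small_split = clamp_perplexity(9)     # 9 segments, i.e. 3 subjects
large_split = clamp_perplexity(500)   # default is already valid here
```

A caller could then pass `clamp_perplexity(len(embeddings), perplexity)` in place of the raw `perplexity` argument.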
def create_scatter_plot_continuous(
coordinates: np.ndarray,
values: np.ndarray,
title: str,
output_path: str,
label: str = "Value",
cmap: str = "viridis",
) -> None:
"""
Create a 2D scatter plot colored by a continuous variable.
Args:
coordinates: 2D array of shape (num_points, 2)
values: Values for each point (may contain None)
title: Plot title
output_path: File path to save the plot (PDF)
label: Colorbar label
cmap: Matplotlib colormap name
"""
values_float = np.array([float(v) if v is not None else np.nan for v in values])
valid_mask = ~np.isnan(values_float)
valid_coords = coordinates[valid_mask]
valid_values = values_float[valid_mask]
if len(valid_values) == 0:
print(f"[WARN] No valid {label} data found. Skipping plot: {output_path}")
return
fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(
valid_coords[:, 0],
valid_coords[:, 1],
c=valid_values,
cmap=cmap,
alpha=0.7,
s=50,
)
cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label(label)
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.close()
print(f"[DONE] Saved plot: {output_path}")
print(f"[INFO] {label} range: {valid_values.min():.1f} - {valid_values.max():.1f}")
def create_scatter_plot_categorical(
coordinates: np.ndarray,
categories: np.ndarray,
title: str,
output_path: str,
label: str = "Category",
) -> None:
"""
Create a 2D scatter plot colored by a categorical variable.
Args:
coordinates: 2D array of shape (num_points, 2)
categories: Category values for each point (may contain None)
title: Plot title
output_path: File path to save the plot (PDF)
label: Legend title
"""
valid_mask = np.array([c is not None and str(c) != "None" for c in categories])
valid_coords = coordinates[valid_mask]
valid_cats = np.array(categories)[valid_mask]
if len(valid_cats) == 0:
print(f"[WARN] No valid {label} data found. Skipping plot: {output_path}")
return
unique_cats = sorted(set(str(c) for c in valid_cats))
colormap = plt.cm.tab10 if len(unique_cats) <= 10 else plt.cm.tab20
colors = {cat: colormap(i % 20) for i, cat in enumerate(unique_cats)}
fig, ax = plt.subplots(figsize=(10, 8))
for cat in unique_cats:
mask = np.array([str(c) == cat for c in valid_cats])
ax.scatter(
valid_coords[mask, 0],
valid_coords[mask, 1],
c=[colors[cat]],
label=cat,
alpha=0.7,
s=50,
)
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
ax.legend(title=label, loc="best", markerscale=2)
plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.close()
print(f"[DONE] Saved plot: {output_path}")
print(f"[INFO] {label} categories: {unique_cats}")
# =============================================================================
# LazyPredict multi-seed helper
# =============================================================================
def _run_lazy_one_split(
X: np.ndarray,
y_raw: List[Any],
subject_ids_valid: List[int],
train_id_set: set,
test_id_set: set,
is_classification: bool,
verbose: int = 0,
ignore_warnings: bool = True,
) -> pd.DataFrame:
"""Run LazyClassifier or LazyRegressor for one train/test split; return metrics DataFrame."""
train_mask = np.array([int(s) in train_id_set for s in subject_ids_valid])
test_mask = np.array([int(s) in test_id_set for s in subject_ids_valid])
if train_mask.sum() == 0 or test_mask.sum() == 0:
raise ValueError("Empty train or test set for this split.")
X_train, X_test = X[train_mask], X[test_mask]
y_train_raw = [y_raw[i] for i in range(len(y_raw)) if train_mask[i]]
y_test_raw = [y_raw[i] for i in range(len(y_raw)) if test_mask[i]]
if is_classification:
le = LabelEncoder()
y_train = le.fit_transform([str(v) for v in y_train_raw])
y_test = le.transform([str(v) for v in y_test_raw])
clf = LazyClassifier(verbose=verbose, ignore_warnings=ignore_warnings, custom_metric=None)
models, _ = clf.fit(X_train, X_test, y_train, y_test)
else:
y_train = np.array([float(v) for v in y_train_raw], dtype=np.float64)
y_test = np.array([float(v) for v in y_test_raw], dtype=np.float64)
reg = LazyRegressor(verbose=verbose, ignore_warnings=ignore_warnings, custom_metric=None)
models, _ = reg.fit(X_train, X_test, y_train, y_test)
# Ensure model name is index for aggregation across seeds
if "Model" in models.columns:
models = models.set_index("Model")
return models
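The helper above builds its masks from subject IDs rather than row indices, so all segments of one subject land on the same side of the split. A minimal sketch of that idea in plain Python follows; `subject_split_masks` and the toy IDs are hypothetical and only illustrate the masking step.

```python
from typing import Iterable, List, Set, Tuple


def subject_split_masks(
    subject_ids: Iterable,
    train_ids: Set[int],
    test_ids: Set[int],
) -> Tuple[List[bool], List[bool]]:
    """Build per-row train/test masks from per-row subject IDs."""
    overlap = train_ids & test_ids
    if overlap:
        raise ValueError(f"Subjects in both train and test: {sorted(overlap)}")
    ids = [int(s) for s in subject_ids]  # IDs may be stored as strings
    train_mask = [s in train_ids for s in ids]
    test_mask = [s in test_ids for s in ids]
    return train_mask, test_mask


# Three segments per subject, IDs stored as strings as in the embeddings dataset.
segment_subjects = ["2", "2", "2", "6", "6", "6", "8", "8", "8"]
train_mask, test_mask = subject_split_masks(segment_subjects, {2, 8}, {6})
```

Because the masks are derived per subject, no row can be in both sets, which avoids leaking a subject's segments from train into test.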
# =============================================================================
# Command Line Interface
# =============================================================================
class CLI:
"""
CLI for Chronos-2 PPG-BP embedding extraction and t-SNE visualization.
Commands:
extract - Compute Chronos-2 embeddings from PPG signals
plot - Generate t-SNE scatter plots colored by metadata
lazy_eval - LazyPredict: X=embedding, y=target (default hypertension); single split
lazy_eval_multi_seed - LazyPredict over 10 seeds for all targets; save {target}.csv (mean +/- std)
"""
def extract(
self,
base_dir: str,
out_dir: str,
model: str = "amazon/chronos-2",
pooling: str = "mean",
split: str = "all",
) -> None:
"""
Extract Chronos-2 embeddings from PPG-BP signals.
Args:
base_dir: PPG-BP dataset root (contains 0_subject/ folder and .xlsx)
out_dir: Output directory for the HuggingFace embeddings dataset
model: Chronos-2 model name (default: amazon/chronos-2)
pooling: Pooling strategy - 'mean' or 'cls' (default: mean)
split: Data split - 'train', 'val', 'test', 'all' (default: all)
"""
if pooling not in ("mean", "cls"):
raise ValueError(f"Invalid pooling: {pooling}. Use 'mean' or 'cls'.")
embedder = Chronos_2_Embedder(
model_name=model,
pooling_strategy=pooling,
)
dataset = embedder.extract_embeddings(base_dir, split)  # forward the CLI argument instead of a hard-coded "all"
embedder.save_embeddings(dataset, out_dir)
def plot(
self,
emb_dir: str,
out_dir: str,
perplexity: float = 30.0,
subjects: str = None,
num_subjects: int = 0,
) -> None:
"""
Visualize embeddings with t-SNE, colored by each metadata column.
Generates plots for (1) train, (2) test, (3) train+test, each with all categories.
Args:
emb_dir: Directory containing the saved HuggingFace embeddings dataset
out_dir: Output directory for PDF plots (subdirs: train/, test/, all/)
perplexity: t-SNE perplexity parameter (default: 30.0)
subjects: Comma-separated subject IDs to include (e.g., '2,6,8')
num_subjects: Include only first N subjects, 0 = all (default: 0)
"""
# Load saved embeddings (must have "split" column)
full_dataset = Chronos_2_Embedder.load_embeddings(emb_dir)
if "split" not in full_dataset.column_names:
raise ValueError(
"Embeddings dataset missing 'split' column. Re-run extract with updated PPGBPLoader."
)
# Optional subject filtering
if subjects:
subject_list = [s.strip() for s in subjects.split(",")]
full_dataset = full_dataset.filter(lambda x: x["subject_id"] in subject_list)
print(f"[INFO] Filtered to subjects: {subject_list}")
elif num_subjects > 0:
all_subjects = sorted(set(full_dataset["subject_id"]), key=int)  # IDs are numeric strings; sort numerically, not lexicographically
selected = all_subjects[:num_subjects]
full_dataset = full_dataset.filter(lambda x: x["subject_id"] in selected)
print(f"[INFO] Selected first {num_subjects} subjects: {selected}")
continuous_vars = [
("age", "Age (years)", "viridis"),
("height", "Height (cm)", "plasma"),
("weight", "Weight (kg)", "cividis"),
("systolic", "Systolic BP (mmHg)", "Reds"),
("diastolic", "Diastolic BP (mmHg)", "Blues"),
("heart_rate", "Heart Rate (bpm)", "Oranges"),
("bmi", "BMI (kg/m²)", "Greens"),
]
categorical_vars = [
("subject_id", "Subject ID"),
("sex", "Sex"),
("hypertension", "Hypertension"),
("age_category", "Age Category"),
("height_category", "Height Category"),
("weight_category", "Weight Category"),
("systolic_category", "Systolic BP Category"),
("diastolic_category", "Diastolic BP Category"),
("heart_rate_category", "Heart Rate Category"),
("bmi_category", "BMI Category"),
]
for split_name in ["train", "test", "all"]:
if split_name == "all":
dataset = full_dataset
else:
dataset = full_dataset.filter(lambda x: x["split"] == split_name)
n = len(dataset)
if n == 0:
print(f"[WARN] No samples for split '{split_name}', skipping.")
continue
split_out_dir = os.path.join(out_dir, split_name)
os.makedirs(split_out_dir, exist_ok=True)
print(f"[INFO] t-SNE for split '{split_name}': {n} samples -> {split_out_dir}")
embeddings = np.array(dataset["embedding"])
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
for var_name, var_label, cmap in continuous_vars:
if var_name in dataset.column_names:
values = dataset[var_name]
create_scatter_plot_continuous(
coordinates_2d,
values,
f"Chronos-2 t-SNE [{split_name}] (Colored by {var_label})",
os.path.join(split_out_dir, f"tsne_by_{var_name}.pdf"),
label=var_label,
cmap=cmap,
)
for var_name, var_label in categorical_vars:
if var_name in dataset.column_names:
categories = dataset[var_name]
create_scatter_plot_categorical(
coordinates_2d,
categories,
f"Chronos-2 t-SNE [{split_name}] (Colored by {var_label})",
os.path.join(split_out_dir, f"tsne_by_{var_name}.pdf"),
label=var_label,
)
print(f"\n[DONE] All plots saved under: {out_dir} (train/, test/, all/)")
def lazy_eval(
self,
emb_dir: str,
target: str = "hypertension",
out_csv: str = None,
verbose: int = 0,
ignore_warnings: bool = True,
) -> None:
"""
Run LazyPredict (many classifiers or regressors) on Chronos embeddings (X) vs target (y).
Uses the same train/test split as PPGBPLoader (split column in the embeddings dataset).
Args:
emb_dir: Directory containing the saved HuggingFace embeddings dataset
target: Target column for y. Default 'hypertension'.
Classification: hypertension, sex, age_category, height_category,
weight_category, systolic_category, diastolic_category,
heart_rate_category, bmi_category
Regression: age, height, weight, systolic, diastolic, heart_rate, bmi
out_csv: If set, save the models comparison DataFrame to this path (e.g. lazy_hypertension.csv)
verbose: LazyPredict verbose (0 or 1)
ignore_warnings: LazyPredict ignore_warnings
"""
if not HAS_LAZYPREDICT:
raise ImportError("lazypredict is required. Install with: pip install lazypredict")
dataset = Chronos_2_Embedder.load_embeddings(emb_dir)
if "split" not in dataset.column_names:
raise ValueError(
"Embeddings dataset missing 'split' column. Re-run extract with updated PPGBPLoader."
)
if target not in dataset.column_names:
raise ValueError(
f"Target '{target}' not in dataset. Available: {dataset.column_names}"
)
X = np.array(dataset["embedding"], dtype=np.float64)
y_raw = list(dataset[target])
# Drop rows where target is missing
valid = [
i for i, v in enumerate(y_raw)
if v is not None and not (isinstance(v, float) and pd.isna(v)) and str(v).strip() != ""
]
if len(valid) == 0:
raise ValueError(f"No valid non-missing values for target '{target}'.")
X = X[valid]
y_raw = [y_raw[i] for i in valid]
split_column = dataset["split"]  # read the column once instead of once per row
splits = [split_column[i] for i in valid]
train_mask = np.array([s == "train" for s in splits])
test_mask = np.array([s == "test" for s in splits])
X_train, X_test = X[train_mask], X[test_mask]
y_train_raw = [y_raw[i] for i in range(len(y_raw)) if splits[i] == "train"]
y_test_raw = [y_raw[i] for i in range(len(y_raw)) if splits[i] == "test"]
# Regression only if the target is listed solely as a regression target; otherwise classification.
is_classification = target in LAZY_EVAL_CLASSIFICATION_TARGETS or target not in LAZY_EVAL_REGRESSION_TARGETS
if is_classification:
le = LabelEncoder()
y_train = le.fit_transform([str(v) for v in y_train_raw])
y_test = le.transform([str(v) for v in y_test_raw])
print(f"[INFO] LazyClassifier: X=embedding, y={target} (classes: {list(le.classes_)})")
print(f"[INFO] Train: {len(y_train)}, Test: {len(y_test)}")
clf = LazyClassifier(verbose=verbose, ignore_warnings=ignore_warnings, custom_metric=None)
models, predictions = clf.fit(X_train, X_test, y_train, y_test)
else:
y_train = np.array([float(v) for v in y_train_raw], dtype=np.float64)
y_test = np.array([float(v) for v in y_test_raw], dtype=np.float64)
print(f"[INFO] LazyRegressor: X=embedding, y={target}")
print(f"[INFO] Train: {len(y_train)}, Test: {len(y_test)}")
reg = LazyRegressor(verbose=verbose, ignore_warnings=ignore_warnings, custom_metric=None)
models, predictions = reg.fit(X_train, X_test, y_train, y_test)
print("\n[RESULTS] Models (sorted by score):")
print(models)
if out_csv:
d = os.path.dirname(out_csv)
if d:
os.makedirs(d, exist_ok=True)
models.to_csv(out_csv)
print(f"\n[DONE] Saved model comparison to {out_csv}")
return None
def lazy_eval_multi_seed(
self,
emb_dir: str,
base_dir: str,
out_dir: str = "./lazy_multi_seed",
n_seeds: int = 10,
verbose: int = 0,
ignore_warnings: bool = True,
targets: str = "all",
) -> None:
"""
Run LazyPredict with n_seeds different train/test splits (from PPGBPLoader seeds).
For each target (classification + regression), saves {target}.csv with mean and std
for all metrics across seeds.
This can run for a long time (10 seeds x 16 targets x many models). For a quick test
use e.g. --n_seeds 2 --targets hypertension
"""
if not HAS_LAZYPREDICT:
raise ImportError("lazypredict is required. Install with: pip install lazypredict")
dataset = Chronos_2_Embedder.load_embeddings(emb_dir)
if "subject_id" not in dataset.column_names:
raise ValueError("Embeddings dataset missing 'subject_id' column.")
target_list = (
[t.strip() for t in targets.split(",")] if targets != "all" else LAZY_EVAL_ALL_TARGETS
)
os.makedirs(out_dir, exist_ok=True)
n_targets = len(target_list)
print(f"[INFO] lazy_eval_multi_seed: {n_targets} targets, {n_seeds} seeds each (can take a long time).")
for ti, target in enumerate(target_list):
if target not in dataset.column_names:
print(f"[WARN] Skipping target '{target}' (not in dataset).")
continue
print(f"[INFO] Target '{target}' ({ti + 1}/{n_targets}) ...")
y_raw = list(dataset[target])
valid = [
i
for i, v in enumerate(y_raw)
if v is not None
and not (isinstance(v, float) and pd.isna(v))
and str(v).strip() != ""
]
if len(valid) == 0:
print(f"[WARN] Skipping target '{target}' (no valid values).")
continue
X = np.array(dataset["embedding"], dtype=np.float64)[valid]
y_valid = [y_raw[i] for i in valid]
subject_id_column = dataset["subject_id"]  # read the column once instead of once per row
subject_ids_valid = [subject_id_column[i] for i in valid]
# Regression only if the target is listed solely as a regression target; otherwise classification.
is_classification = target in LAZY_EVAL_CLASSIFICATION_TARGETS or target not in LAZY_EVAL_REGRESSION_TARGETS
run_dfs = []
for seed in range(n_seeds):
print(f"[INFO] seed {seed + 1}/{n_seeds} ...", flush=True)
loader = PPGBPLoader(base_dir=base_dir, split="all", seed=seed)
train_id_set = set(int(x) for x in loader.train_ids)
test_id_set = set(int(x) for x in loader.test_ids)
try:
df = _run_lazy_one_split(
X,
y_valid,
subject_ids_valid,
train_id_set,
test_id_set,
is_classification,
verbose=verbose,
ignore_warnings=ignore_warnings,
)
run_dfs.append(df)
except Exception as e:
print(f"[WARN] Seed {seed} for target '{target}' failed: {e}")
continue
if len(run_dfs) == 0:
print(f"[WARN] No successful runs for target '{target}', skipping.")
continue
combined = pd.concat(run_dfs, axis=0)
metric_cols = [
c for c in combined.columns
if c != "Model" and np.issubdtype(combined[c].dtype, np.number)
]
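# Aggregate numeric metric columns only; non-numeric columns can make mean()/std() raise in recent pandas.
agg_mean = combined[metric_cols].groupby(level=0).mean()
agg_std = combined[metric_cols].groupby(level=0).std().fillna(0)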
out = pd.DataFrame(index=agg_mean.index)
out.index.name = "Model"
for col in metric_cols:
out[col + "_mean"] = agg_mean[col].values
out[col + "_std"] = agg_std[col].values
m = agg_mean[col].round(4).astype(str)
s = agg_std[col].round(4).astype(str)
out[col + "_mean_pm_std"] = m + " +/- " + s
out_path = os.path.join(out_dir, target + ".csv")
out.to_csv(out_path)
print(f"[DONE] {target}: {len(run_dfs)} seeds -> {out_path}")
if __name__ == "__main__":
Fire(CLI)
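The per-target aggregation in `lazy_eval_multi_seed` amounts to grouping each model's metrics across seeds and reporting mean and sample standard deviation, with 0 when only one seed succeeded. The following is a dependency-free sketch of that step. `aggregate_runs`, the model names, and the metric values are all made up for illustration, and the fixed four-decimal formatting is an assumption that differs slightly from the `round(4).astype(str)` output of the original.

```python
from collections import defaultdict
from statistics import mean, stdev
from typing import Dict, List, Tuple

Run = Dict[str, Dict[str, float]]  # model name -> metric name -> value


def aggregate_runs(runs: List[Run]) -> Dict[str, Dict[str, Tuple[float, float, str]]]:
    """Collect each model's metric values across runs and summarize them."""
    collected = defaultdict(lambda: defaultdict(list))
    for run in runs:
        for model, metrics in run.items():
            for metric, value in metrics.items():
                collected[model][metric].append(float(value))

    summary: Dict[str, Dict[str, Tuple[float, float, str]]] = {}
    for model, metrics in collected.items():
        summary[model] = {}
        for metric, values in metrics.items():
            m = mean(values)
            # Sample std is undefined for a single value; report 0.0 instead.
            s = stdev(values) if len(values) > 1 else 0.0
            summary[model][metric] = (m, s, f"{m:.4f} +/- {s:.4f}")
    return summary


# Two seeds; ModelB failed in the second one (made-up numbers).
runs = [
    {"ModelA": {"RMSE": 1.0}, "ModelB": {"RMSE": 0.5}},
    {"ModelA": {"RMSE": 3.0}},
]
summary = aggregate_runs(runs)
```

Models missing from some seeds are still reported, using however many runs they appeared in, which mirrors grouping the concatenated per-seed frames by model name.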


@@ -1,55 +0,0 @@
#!/bin/bash
# =============================================================================
# Chronos-2 PPG-BP Embedding Extraction & t-SNE Visualization
# =============================================================================
set -euo pipefail
# ---- Configuration (edit these variables) ------------------------------------
BASE_DIR="/scratch/10608/aadharsh_aadhithya/repos/data/tsllm"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
MODEL="amazon/chronos-2"
POOLING="mean"
SPLIT="all"
PERPLEXITY=30.0
EMB_DIR="${BASE_DIR}/chronos2_embeddings"
PLOT_DIR="${BASE_DIR}/chronos2_plots"
# ------------------------------------------------------------------------------
echo "============================================="
echo "Chronos-2 PPG-BP Pipeline"
echo "============================================="
echo "Base dir: ${BASE_DIR}"
echo "Model: ${MODEL}"
echo "Pooling: ${POOLING}"
echo "Split: ${SPLIT}"
echo "Embeddings dir: ${EMB_DIR}"
echo "Plots dir: ${PLOT_DIR}"
echo "============================================="
# Step 1: Extract Chronos-2 embeddings from PPG signals
echo ""
echo "[Step 1/2] Extracting Chronos-2 embeddings..."
python "${SCRIPT_DIR}/gen_plot.py" extract \
--base_dir "${BASE_DIR}" \
--out_dir "${EMB_DIR}" \
--model "${MODEL}" \
--pooling "${POOLING}" \
--split "${SPLIT}"
# Step 2: Generate t-SNE plots colored by metadata
echo ""
echo "[Step 2/2] Generating t-SNE plots..."
python "${SCRIPT_DIR}/gen_plot.py" plot \
--emb_dir "${EMB_DIR}" \
--out_dir "${PLOT_DIR}" \
--perplexity "${PERPLEXITY}"
echo ""
echo "============================================="
echo "Done! Outputs:"
echo " Embeddings: ${EMB_DIR}"
echo " Plots: ${PLOT_DIR}"
echo "============================================="


@@ -1,504 +0,0 @@
"""
SBERT Time Series Embedding Extraction and Visualization Pipeline
This module provides functionality to:
1. Extract embeddings from time series features using SBERT (Sentence-BERT)
2. Visualize embeddings using dimensionality reduction (t-SNE)
SBERT Overview:
SBERT (Sentence-BERT) is a modification of the BERT model that uses siamese
and triplet network structures to derive semantically meaningful sentence
embeddings. It's designed for semantic similarity tasks and produces
fixed-size dense vector representations of text.
Key Features:
- Semantic understanding: Captures meaning rather than just word presence
- Fixed-size embeddings: Outputs consistent vector dimensions (384 for all-MiniLM-L6-v2)
- Efficient: Optimized for sentence-level tasks
Embedding Strategy:
We convert time series features into textual descriptions, then use SBERT
to generate embeddings. This approach treats feature vectors as "sentences"
where each feature-value pair is a "word" in the description.
Processing Pipeline:
1. Feature extraction from time series (statistical features)
2. Textualization: Convert features to natural language description
3. SBERT encoding: Generate 384-dimensional embeddings
4. Aggregation: Store embeddings with metadata (user_id, session_id, label)
Usage:
# Extract embeddings
python gen_plot.py extract --data_root /path/to/data --out_dir ./embeddings
# Visualize with t-SNE
python gen_plot.py plot --emb_dir ./embeddings --out_dir ./plots
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from glob import glob
from typing import Dict, Any, List, Tuple
import numpy as np
import matplotlib.pyplot as plt
import torch
from datasets import load_from_disk, Dataset
from fire import Fire
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
# =============================================================================
# Constants
# =============================================================================
# EEG channel names used in Sleep-EDF dataset
# Fpz-Cz: Frontal-Central electrode pair
# Pz-Oz: Parietal-Occipital electrode pair
EEG_CHANNEL_1 = "EEG Fpz-Cz"
EEG_CHANNEL_2 = "EEG Pz-Oz"
# =============================================================================
# Embedding Extractor Class
# =============================================================================
class SBERT:
"""
Extracts fixed-dimensional embeddings from time series features using SBERT.
Uses Sentence-BERT to convert textualized feature descriptions into dense
vector representations. The textualization process converts statistical
features into natural language, which SBERT then encodes semantically.
Architecture:
Time Series → Feature Extraction → Textualization → SBERT Encoder → Embedding
Processing Pipeline:
1. Extract statistical features from time series (mean, std, etc.)
2. Textualize: Convert features to natural language description
3. SBERT encoding: Generate 384-dimensional semantic embeddings
4. Output: Fixed-size embedding vector per sample
Model: all-MiniLM-L6-v2
- Lightweight BERT variant optimized for sentence embeddings
- 384-dimensional output embeddings
- Fast inference with good semantic understanding
Attributes:
model: SentenceTransformer instance for encoding textualized features
"""
def __init__(self):
"""
Initialize the SBERT embedder with pre-trained model.
Uses "all-MiniLM-L6-v2" which is a lightweight, fast model
optimized for sentence similarity tasks.
"""
self.model = SentenceTransformer("all-MiniLM-L6-v2")
@staticmethod
def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
"""
Discover all user/session directories under data_root.
Uses glob pattern matching for cleaner directory traversal.
Expected structure: data_root/user_id/session_id/
Args:
data_root: Root directory containing user/session subfolders
Returns:
List of (user_id, session_id, session_path) tuples
"""
discovered_paths = []
# Use glob to find all session directories (2 levels deep)
for session_path in sorted(glob(os.path.join(data_root, "*", "*"))):
if not os.path.isdir(session_path):
continue
# Extract user_id and session_id from path
session_id = os.path.basename(session_path)
user_id = os.path.basename(os.path.dirname(session_path))
discovered_paths.append((user_id, session_id, session_path))
return discovered_paths
def textualize_sample(self, sample: Dict[str, Any]) -> str:
"""
Convert a feature dictionary into a natural language description.
This textualization step is crucial for SBERT, which expects text input.
The description provides context about the sleep stage and lists all
extracted features in a structured format.
Args:
sample: Dictionary mapping feature names to their values
Example: {"mean": 0.5, "std": 0.2, "max": 1.0, ...}
Returns:
Natural language string describing the features
"""
sentence = (
"While sleeping (one of the stages: W, N1, N2, N3, REM), "
"I have sensor data features measured from two channels: EEG Fpz-Cz and EEG Pz-Oz.\n"
"The features are as follows: \n"
)
for k, v in sample.items():
sentence += f" - {k}: {self.format_feature(v)}\n"
sentence = sentence.strip()  # str.strip() returns a new string; keep the result
return sentence
def format_feature(self, value: Any) -> str:
"""
Format a feature value for textual representation.
Floats are rounded to 2 decimal places for readability,
other types are converted to strings.
Args:
value: Feature value (float, int, or other type)
Returns:
Formatted string representation
"""
if isinstance(value, float):
return f"{value:.2f}"
return str(value)
def compute_embedding(self, batch: Dict[str, Any]) -> np.ndarray:
"""
Generate embedding vectors from feature batches using SBERT.
Processing Pipeline:
1. Extract feature dictionaries from batch
2. Textualize each sample's features into natural language
3. Encode textual descriptions using SBERT
4. Return fixed-size embedding vectors
Args:
batch: HuggingFace dataset batch from slicing (dataset[start:end])
Format: {"features": [{"mean": 0.5, "std": 0.2, ...}, ...]}
Returns:
Embedding array of shape (batch_size, embedding_dim)
For all-MiniLM-L6-v2: (batch_size, 384)
"""
# Extract feature dictionaries from batch
samples = batch["features"]
# Convert each sample's features to text
text_samples = []
for sample in samples:
text_samples.append(self.textualize_sample(sample))
# Encode text descriptions using SBERT
# Returns numpy array of shape (batch_size, 384)
embeddings = self.model.encode(text_samples)
return embeddings
def extract_embeddings(
self,
data_root: str,
batch_size: int = 32,
label: str = "N1"
) -> Dataset:
"""
Extract embeddings from all sessions under the data root directory.
Iterates through all user/session combinations, processes features
in batches, and aggregates results with metadata. Filters by sleep
stage label to focus on specific sleep stages.
Args:
data_root: Root directory containing user/session data folders
batch_size: Number of samples to process together.
Larger = faster but more memory.
32 is a good balance for most systems.
label: Sleep stage label to filter (e.g., "W", "N1", "N2", "N3", "REM")
Returns:
HuggingFace Dataset with columns:
- user_id, session_id, idx, label (metadata)
- embedding (384-dim vector from all-MiniLM-L6-v2)
"""
session_paths = self.discover_session_paths(data_root)
print(f"[INFO] Discovered {len(session_paths)} sessions")
all_embeddings = []
all_user_ids = []
all_session_ids = []
all_idxs = []
all_labels = []
for user_id, session_id, session_path in session_paths:
# Load HuggingFace dataset from disk
dataset = load_from_disk(session_path)
# Shuffle with a fixed seed so the sample order is reproducible
dataset = dataset.shuffle(seed=0)
# Filter by sleep stage label
dataset = dataset.filter(lambda x: x["label"] == label)
num_samples = len(dataset)
print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")
# Process in batches to manage memory
for batch_start in range(0, num_samples, batch_size):
batch_end = min(batch_start + batch_size, num_samples)
# Slice dataset to get batch
batch = dataset[batch_start:batch_end]
# Compute embeddings
embeddings = self.compute_embedding(batch)
# Collect embeddings and metadata
for i in range(embeddings.shape[0]):
all_embeddings.append(embeddings[i].tolist())
all_user_ids.append(str(batch["user_id"][i]))
all_session_ids.append(str(batch["session_id"][i]))
all_idxs.append(int(batch["idx"][i]))
all_labels.append(str(batch["label"][i]))
# Create HuggingFace Dataset
result_dataset = Dataset.from_dict({
"user_id": all_user_ids,
"session_id": all_session_ids,
"idx": all_idxs,
"label": all_labels,
"embedding": all_embeddings,
})
print(f"[INFO] Total samples: {len(result_dataset)}")
return result_dataset
def save_embeddings(
self,
dataset: Dataset,
output_dir: str
) -> None:
"""
Save embeddings dataset to disk in HuggingFace format.
Args:
dataset: HuggingFace Dataset containing embeddings and metadata
output_dir: Directory path to save the dataset
"""
os.makedirs(output_dir, exist_ok=True)
dataset.save_to_disk(output_dir)
print(f"[DONE] Saved embeddings dataset: {output_dir}")
if len(dataset) > 0:  # guard: dataset[0] raises IndexError on an empty dataset
    print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
@staticmethod
def load_embeddings(embedding_dir: str) -> Dataset:
"""
Load saved embeddings dataset from disk.
Args:
embedding_dir: Directory path containing the saved HuggingFace dataset
Returns:
HuggingFace Dataset with embeddings and metadata
"""
dataset = load_from_disk(embedding_dir)
print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
return dataset
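The textualization step in the class above turns a feature dictionary into a bulleted description before SBERT encoding. Below is a minimal standalone sketch of that step with no model involved. The `textualize` function, the shortened header, and the feature names are illustrative assumptions, not the exact prompt used by the pipeline.

```python
from typing import Any, Dict


def format_feature(value: Any) -> str:
    """Round floats to two decimals; stringify everything else."""
    return f"{value:.2f}" if isinstance(value, float) else str(value)


def textualize(features: Dict[str, Any], header: str = "The features are as follows:") -> str:
    """Render a feature dict as one line per feature under a header line."""
    lines = [header]
    for name, value in features.items():
        lines.append(f" - {name}: {format_feature(value)}")
    # Joining with newlines leaves no trailing newline to strip.
    return "\n".join(lines)


text = textualize({"mean": 0.5, "std": 0.2049, "n_peaks": 3})
```

Building the lines in a list and joining them avoids the trailing newline that string concatenation in a loop leaves behind.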
# =============================================================================
# Visualization Functions
# =============================================================================
def reduce_to_2d_tsne(
embeddings: np.ndarray,
perplexity: float = 30.0
) -> np.ndarray:
"""
Reduce high-dimensional embeddings to 2D using t-SNE.
t-SNE (t-distributed Stochastic Neighbor Embedding) is a nonlinear
dimensionality reduction technique that preserves local structure.
Points that are similar in high dimensions stay close in 2D.
Args:
embeddings: High-dimensional array of shape (num_samples, embedding_dim)
perplexity: t-SNE perplexity parameter (typically 5-50); must be smaller than num_samples.
Higher values consider more neighbors, creating smoother layouts.
Rule of thumb: perplexity ~ sqrt(num_samples)
Returns:
2D coordinates of shape (num_samples, 2)
"""
print(f"[INFO] Running t-SNE with perplexity={perplexity}...")
tsne = TSNE(
n_components=2,
random_state=0, # For reproducibility
perplexity=perplexity,
max_iter=1000, # Usually sufficient for convergence
init='random',
learning_rate='auto', # Let sklearn choose optimal learning rate
)
return tsne.fit_transform(embeddings)
def create_scatter_plot(
coordinates: np.ndarray,
labels: np.ndarray,
title: str,
output_path: str
) -> None:
"""
Create and save a 2D scatter plot with categorical coloring.
Args:
coordinates: 2D array of shape (num_points, 2)
labels: Category labels for each point (string array)
title: Plot title
output_path: File path to save the plot (PDF vector format recommended)
"""
# Get unique labels for legend
unique_labels = sorted(set(labels))
# Select colormap based on number of categories
# tab10: 10 distinct colors, tab20: 20 distinct colors
colormap = plt.cm.tab10 if len(unique_labels) <= 10 else plt.cm.tab20
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Plot each category separately for proper legend
for idx, label in enumerate(unique_labels):
mask = labels == label
ax.scatter(
coordinates[mask, 0],
coordinates[mask, 1],
c=[colormap(idx % 20)],
s=15,
label=label,
alpha=0.7,
)
# Configure plot appearance
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
ax.legend(loc='best', markerscale=2)
# Save figure as vector PDF (scalable, ideal for publications)
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
# =============================================================================
# Command Line Interface
# =============================================================================
class CLI:
"""
Command-line interface for SBERT embedding extraction and visualization.
Provides two main commands:
- extract: Generate embeddings from time series features
- plot: Visualize embeddings with t-SNE
"""
def extract(
self,
data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF",
out_dir: str = "./embeddings/REM",
batch_size: int = 32,
label: str = "REM"
) -> None:
"""
Extract embeddings from time series data.
Args:
data_root: Root directory containing user/session data folders
out_dir: Output directory for HuggingFace dataset
batch_size: Batch size for inference (default: 32)
label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM')
"""
embedder = SBERT()
dataset = embedder.extract_embeddings(data_root, batch_size, label)
embedder.save_embeddings(dataset, out_dir)
def plot(
self,
emb_dir: str = "./embeddings/W",
out_dir: str = "./plots/W",
perplexity: float = 10.0,
users: str = None,
num_users: int = 0,
labels: str = None,
) -> None:
"""
Visualize embeddings with t-SNE.
Args:
emb_dir: Directory containing the HuggingFace embeddings dataset
out_dir: Output directory for visualization plots (PDF)
perplexity: t-SNE perplexity parameter (default: 10.0)
users: Comma-separated user IDs to include (e.g., '00,01,02')
num_users: Include only first N users, 0 = all (default: 0)
labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
"""
os.makedirs(out_dir, exist_ok=True)
# Load saved embeddings dataset
dataset = SBERT.load_embeddings(emb_dir)
# Apply user filtering
if users:
user_list = [u.strip() for u in users.split(",")]
dataset = dataset.filter(lambda x: x["user_id"] in user_list)
print(f"[INFO] Filtered to users: {user_list}")
elif num_users > 0:
all_users = sorted(set(dataset["user_id"]))
selected_users = all_users[:num_users]
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
print(f"[INFO] Selected first {num_users} users: {selected_users}")
# Filter by sleep stage labels
if labels:
label_list = [l.strip() for l in labels.split(",")]
dataset = dataset.filter(lambda x: x["label"] in label_list)
print(f"[INFO] Filtered to labels: {label_list}")
print(f"[INFO] Total samples: {len(dataset)}")
# Extract embeddings as numpy array for t-SNE
embeddings = np.array(dataset["embedding"])
# Reduce to 2D with t-SNE
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
# Generate visualizations
create_scatter_plot(
coordinates_2d,
np.array(dataset["label"]),
"t-SNE Visualization (Colored by Sleep Stage)",
os.path.join(out_dir, "tsne_by_label.pdf")
)
create_scatter_plot(
coordinates_2d,
np.array(dataset["user_id"]),
"t-SNE Visualization (Colored by User ID)",
os.path.join(out_dir, "tsne_by_user.pdf")
)
if __name__ == "__main__":
Fire(CLI)
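
Note on filtering before t-SNE: the `plot` command can shrink the dataset through `--users`, `--num_users`, or `--labels`, and recent scikit-learn releases reject a perplexity that is not smaller than the number of samples. The sketch below shows one possible guard. `clamp_perplexity` is a hypothetical helper introduced here for illustration; it is not part of the file above.

```python
def clamp_perplexity(requested: float, num_samples: int) -> float:
    """Return a perplexity value that stays below num_samples.

    Hypothetical helper (an assumption, not from the original module).
    """
    if num_samples < 2:
        raise ValueError("t-SNE needs at least 2 samples")
    # Keep the value strictly below the sample count by a margin of one.
    upper_bound = float(num_samples - 1)
    return min(float(requested), upper_bound)
```

One way to use it would be `reduce_to_2d_tsne(embeddings, clamp_perplexity(perplexity, len(embeddings)))`, so that a heavily filtered dataset still produces a plot instead of an error.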


@@ -1,529 +0,0 @@
"""
SBERT Embedding Extraction and Visualization with Metadata (Age and Sex)
This module provides functionality to:
1. Extract SBERT embeddings from time series features
2. Visualize embeddings colored by subject metadata (age and sex) instead of user IDs
Features:
1. Extract embeddings from time series features using SBERT
2. Load embeddings from HuggingFace dataset
3. Load subject metadata from XLS file
4. Map user IDs to age and sex information
5. Visualize embeddings with t-SNE colored by age (continuous) and sex (categorical)
Usage:
# Extract embeddings for all labels
python gen_plot_metadata.py extract --data_root /path/to/data --out_dir ./embeddings/all_labels
# Extract embeddings for a single label
python gen_plot_metadata.py extract --data_root /path/to/data --out_dir ./embeddings/W --label W
# Visualize by age from single label directory
python gen_plot_metadata.py plot --emb_dir ./embeddings/W --out_dir ./plots/W --color_by age
# Visualize by age from all label directories (W, REM, N1, N2, N3)
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by age
# Visualize by sex from all labels
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by sex
# Create both age and sex plots from all labels
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by both
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from glob import glob
from typing import Dict, Any, List, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_from_disk, Dataset, concatenate_datasets
from fire import Fire
from sklearn.manifold import TSNE
from gen_plot import SBERT, reduce_to_2d_tsne
# =============================================================================
# Constants
# =============================================================================
SUBJECT_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/SC-subjects.xls"
# =============================================================================
# Metadata Loading
# =============================================================================
def load_subject_metadata(subject_path: str = SUBJECT_PATH) -> Dict[str, Dict[str, Any]]:
"""
Load subject metadata from XLS file.
The XLS file contains subject information including:
- subject: Subject ID (e.g., "SC4001", "SC4002")
- age: Age of the subject
- sex (F=1): Sex (1 = Female, 0 = Male)
Args:
subject_path: Path to the SC-subjects.xls file
Returns:
Dictionary mapping subject IDs to metadata dictionaries
Format: {"SC4001": {"age": 25, "sex": 1}, ...}
"""
df = pd.read_excel(subject_path)
subject_info = {}
for index, row in df.iterrows():
subject_id = str(row["subject"]).strip()
subject_info[subject_id] = {
"age": int(row["age"]) if pd.notna(row["age"]) else None,
"sex": int(row["sex (F=1)"]) if pd.notna(row["sex (F=1)"]) else None,
}
return subject_info
def map_user_ids_to_metadata(
user_ids: np.ndarray,
subject_metadata: Dict[str, Dict[str, Any]]
) -> tuple:
"""
Map user IDs from dataset to age and sex metadata.
Each user ID is looked up directly as a key of subject_metadata, so callers
must normalize IDs to the key format first (plot() converts zero-padded
strings such as "07" to "7" before calling this function).
Args:
user_ids: Array of user IDs from the dataset
subject_metadata: Dictionary of subject metadata loaded from XLS file
Returns:
Tuple of (ages, sexes) as numpy arrays
Missing values are set to None
"""
ages = []
sexes = []
for user_id in user_ids:
# Unknown user IDs yield None instead of raising KeyError
metadata = subject_metadata.get(user_id, {})
ages.append(metadata.get("age"))
sexes.append(metadata.get("sex"))
return np.array(ages), np.array(sexes)
# =============================================================================
# Visualization Functions
# =============================================================================
def create_scatter_plot_by_age(
coordinates: np.ndarray,
ages: np.ndarray,
title: str,
output_path: str
) -> None:
"""
Create and save a 2D scatter plot colored by age (continuous colormap).
Uses a continuous colormap (viridis) to show age distribution.
Ages are mapped to colors on a gradient scale.
Args:
coordinates: 2D array of shape (num_points, 2)
ages: Age values for each point (array, may contain None values)
title: Plot title
output_path: File path to save the plot (PDF vector format recommended)
"""
# Filter out points with missing age data
# Convert None values to NaN for proper numpy handling
ages_float = np.array([float(a) if a is not None else np.nan for a in ages])
valid_mask = ~np.isnan(ages_float)
valid_coords = coordinates[valid_mask]
valid_ages = ages_float[valid_mask]
if len(valid_ages) == 0:
print(f"[WARN] No valid age data found. Skipping plot: {output_path}")
return
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Create scatter plot with continuous colormap
scatter = ax.scatter(
valid_coords[:, 0],
valid_coords[:, 1],
c=valid_ages,
cmap='viridis', # Continuous colormap for granular age visualization
s=15,
alpha=0.7,
edgecolors='none',
)
# Add colorbar
cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label('Age (years)', rotation=270, labelpad=20)
# Configure plot appearance
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
# Save figure as vector PDF
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
print(f"[INFO] Age range: {valid_ages.min():.0f} - {valid_ages.max():.0f} years")
print(f"[INFO] Points with valid age: {len(valid_ages)}/{len(ages)}")
def create_scatter_plot_by_sex(
coordinates: np.ndarray,
sexes: np.ndarray,
title: str,
output_path: str
) -> None:
"""
Create and save a 2D scatter plot colored by sex (categorical).
Uses discrete colors for different sex categories.
Sex encoding: 1 = Female, 0 = Male
Args:
coordinates: 2D array of shape (num_points, 2)
sexes: Sex values for each point (array, may contain None values)
title: Plot title
output_path: File path to save the plot (PDF vector format recommended)
"""
# Filter out points with missing sex data
# Convert None values to NaN for proper numpy handling
sexes_float = np.array([float(s) if s is not None else np.nan for s in sexes])
valid_mask = ~np.isnan(sexes_float)
valid_coords = coordinates[valid_mask]
valid_sexes = sexes_float[valid_mask].astype(int)
if len(valid_sexes) == 0:
print(f"[WARN] No valid sex data found. Skipping plot: {output_path}")
return
# Map sex codes to labels
sex_labels = {0: "Male", 1: "Female"}
unique_sexes = sorted(set(valid_sexes))
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Plot each category separately for proper legend
colors = ['steelblue', 'coral'] # Blue for Male, Coral for Female
for idx, sex_code in enumerate(unique_sexes):
mask = valid_sexes == sex_code
ax.scatter(
valid_coords[mask, 0],
valid_coords[mask, 1],
c=colors[sex_code % len(colors)],
s=15,
label=sex_labels.get(sex_code, f"Sex {sex_code}"),
alpha=0.7,
)
# Configure plot appearance
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
ax.legend(loc='best', markerscale=2)
# Save figure as vector PDF
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
sex_counts = {sex_labels.get(s, f"Sex {s}"): (valid_sexes == s).sum() for s in unique_sexes}
print(f"[INFO] Sex distribution: {sex_counts}")
print(f"[INFO] Points with valid sex: {len(valid_sexes)}/{len(sexes)}")
# =============================================================================
# Embedding Extraction Utilities
# =============================================================================
def extract_embeddings_for_all_labels(
data_root: str,
out_dir: str,
batch_size: int = 32
) -> Dataset:
"""
Extract embeddings for all sleep stage labels and combine them.
Extracts embeddings for each label (W, REM, N1, N2, N3) separately,
then concatenates them into a single dataset.
Args:
data_root: Root directory containing user/session data folders
out_dir: Output directory for the combined HuggingFace dataset
batch_size: Batch size for inference (default: 32)
Returns:
Combined HuggingFace Dataset with all labels
"""
embedder = SBERT()
all_labels = ["W", "REM", "N1", "N2", "N3"]
datasets = []
for label in all_labels:
print(f"\n[INFO] Extracting embeddings for label: {label}")
dataset = embedder.extract_embeddings(data_root, batch_size, label)
print(f"[INFO] Extracted {len(dataset)} samples for label {label}")
datasets.append(dataset)
# Concatenate all datasets
if len(datasets) == 1:
combined_dataset = datasets[0]
else:
combined_dataset = concatenate_datasets(datasets)
print(f"\n[INFO] Combined dataset: {len(combined_dataset)} total samples")
# Print label distribution
if "label" in combined_dataset.column_names:
label_counts = {}
for label in combined_dataset["label"]:
label_counts[label] = label_counts.get(label, 0) + 1
print(f"[INFO] Label distribution: {label_counts}")
# Save combined dataset
embedder.save_embeddings(combined_dataset, out_dir)
return combined_dataset
# =============================================================================
# Data Loading Utilities
# =============================================================================
def load_embeddings_from_all_labels(embeddings_root: str) -> Dataset:
"""
Load embeddings from all label subdirectories and concatenate them.
Discovers all subdirectories in embeddings_root (e.g., W, REM, N1, N2, N3)
and loads embeddings from each, then concatenates them into a single dataset.
Args:
embeddings_root: Root directory containing label subdirectories
(e.g., "./embeddings" containing W/, REM/, N1/, etc.)
Returns:
Concatenated HuggingFace Dataset with all labels combined
"""
# Discover all label subdirectories
label_dirs = []
for item in os.listdir(embeddings_root):
item_path = os.path.join(embeddings_root, item)
if os.path.isdir(item_path):
# Check if it's a valid HuggingFace dataset directory
if os.path.exists(os.path.join(item_path, "dataset_info.json")):
label_dirs.append((item, item_path))
if len(label_dirs) == 0:
raise ValueError(
f"No valid HuggingFace dataset directories found in: {embeddings_root}"
)
label_dirs.sort() # Sort for consistent ordering
print(f"[INFO] Discovered {len(label_dirs)} label directories: {[ld[0] for ld in label_dirs]}")
# Load datasets from each label directory
datasets = []
for label_name, label_path in label_dirs:
print(f"[INFO] Loading embeddings from: {label_path}")
dataset = load_from_disk(label_path)
print(f"[INFO] Label: {label_name}, Samples: {len(dataset)}")
datasets.append(dataset)
# Concatenate all datasets
if len(datasets) == 1:
combined_dataset = datasets[0]
else:
combined_dataset = concatenate_datasets(datasets)
print(f"[INFO] Combined dataset: {len(combined_dataset)} total samples")
# Print label distribution
if "label" in combined_dataset.column_names:
label_counts = {}
for label in combined_dataset["label"]:
label_counts[label] = label_counts.get(label, 0) + 1
print(f"[INFO] Label distribution: {label_counts}")
return combined_dataset
# =============================================================================
# Command Line Interface
# =============================================================================
class CLI:
"""
Command-line interface for SBERT embedding extraction and visualization with metadata.
Provides:
- extract: Generate embeddings from time series features
- plot: Visualize embeddings colored by age or sex instead of user ID
"""
def extract(
self,
data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_full",
out_dir: str = "./embeddings/all_labels",
batch_size: int = 32,
label: Optional[str] = None
) -> None:
"""
Extract embeddings from time series features using SBERT.
Can extract for a single label or all labels at once.
Args:
data_root: Root directory containing user/session data folders
out_dir: Output directory for HuggingFace dataset
batch_size: Batch size for inference (default: 32)
label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM').
If None or 'all', extracts for all labels (default: None for all labels)
"""
if label is None or label == "all":
# Extract for all labels (the helper creates its own SBERT instance)
print("[INFO] Extracting embeddings for all labels")
extract_embeddings_for_all_labels(data_root, out_dir, batch_size)
else:
# Extract for single label
print(f"[INFO] Extracting embeddings for label: {label}")
embedder = SBERT()
dataset = embedder.extract_embeddings(data_root, batch_size, label)
embedder.save_embeddings(dataset, out_dir)
def plot(
self,
emb_dir: str = None,
embeddings_root: str = "./embeddings",
out_dir: str = "./plots/all_labels",
subject_path: str = SUBJECT_PATH,
perplexity: float = 10.0,
color_by: str = "age",
users: str = None,
num_users: int = 0,
labels: str = None,
) -> None:
"""
Visualize embeddings with t-SNE, colored by age or sex.
Can load from either a single label directory or all label directories.
Args:
emb_dir: Single directory containing the HuggingFace embeddings dataset
(e.g., "./embeddings/W"). If provided, only this directory is used.
embeddings_root: Root directory containing label subdirectories
(e.g., "./embeddings" containing W/, REM/, N1/, etc.).
Used only if emb_dir is not provided.
out_dir: Output directory for visualization plots (PDF)
subject_path: Path to SC-subjects.xls file with metadata
perplexity: t-SNE perplexity parameter (default: 10.0)
color_by: What to color by - 'age', 'sex', or 'both' (default: 'age')
users: Comma-separated user IDs to include (e.g., '00,01,02')
num_users: Include only first N users, 0 = all (default: 0)
labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
This filters the already-loaded data, not which directories to load.
"""
# Validate color_by argument
if color_by not in ["age", "sex", "both", "all"]:
raise ValueError(f"Invalid color_by: {color_by}. Use 'age', 'sex', or 'both' ('all' is accepted as an alias for 'both').")
os.makedirs(out_dir, exist_ok=True)
# Load embeddings: either from single directory or all label directories
if emb_dir is not None:
# Load from single directory
print(f"[INFO] Loading embeddings from single directory: {emb_dir}")
dataset = SBERT.load_embeddings(emb_dir)
else:
# Load from all label directories
print(f"[INFO] Loading embeddings from all label directories in: {embeddings_root}")
dataset = load_embeddings_from_all_labels(embeddings_root)
# Apply user filtering
if users:
user_list = [u.strip() for u in users.split(",")]
dataset = dataset.filter(lambda x: x["user_id"] in user_list)
print(f"[INFO] Filtered to users: {user_list}")
elif num_users > 0:
all_users = sorted(set(dataset["user_id"]))
selected_users = all_users[:num_users]
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
print(f"[INFO] Selected first {num_users} users: {selected_users}")
# Filter by sleep stage labels
if labels:
label_list = [l.strip() for l in labels.split(",")]
dataset = dataset.filter(lambda x: x["label"] in label_list)
print(f"[INFO] Filtered to labels: {label_list}")
print(f"[INFO] Total samples: {len(dataset)}")
# Extract embeddings as numpy array for t-SNE
embeddings = np.array(dataset["embedding"])
# Load subject metadata
print(f"[INFO] Loading subject metadata from: {subject_path}")
subject_metadata = load_subject_metadata(subject_path)
print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")
# Map user IDs to metadata
user_ids = np.array([str(uid) for uid in dataset["user_id"]])
user_ids_ = [str(int(uid)) for uid in user_ids]
ages, sexes = map_user_ids_to_metadata(user_ids_, subject_metadata)
# Reduce to 2D with t-SNE
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
# Generate visualizations based on color_by parameter
if color_by == "age":
create_scatter_plot_by_age(
coordinates_2d,
ages,
"t-SNE Visualization (Colored by Age)",
os.path.join(out_dir, "tsne_by_age.pdf")
)
elif color_by == "sex":
create_scatter_plot_by_sex(
coordinates_2d,
sexes,
"t-SNE Visualization (Colored by Sex)",
os.path.join(out_dir, "tsne_by_sex.pdf")
)
elif color_by == "both" or color_by == "all":
# Create both plots
create_scatter_plot_by_age(
coordinates_2d,
ages,
"t-SNE Visualization (Colored by Age)",
os.path.join(out_dir, "tsne_by_age.pdf")
)
create_scatter_plot_by_sex(
coordinates_2d,
sexes,
"t-SNE Visualization (Colored by Sex)",
os.path.join(out_dir, "tsne_by_sex.pdf")
)
if __name__ == "__main__":
Fire(CLI)
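
Note on ID matching: `plot` normalizes zero-padded user IDs with `str(int(uid))` before looking them up in the subject metadata. A minimal sketch of that normalization combined with a tolerant lookup follows. The names `normalize_user_id` and `lookup_age_and_sex` are hypothetical, and the metadata values in the usage example are made up for illustration.

```python
from typing import Any, Dict, List, Optional, Tuple


def normalize_user_id(raw_id: Any) -> str:
    """Turn a zero-padded ID such as '07' into '7' (hypothetical helper)."""
    return str(int(str(raw_id).strip()))


def lookup_age_and_sex(
    user_ids: List[Any],
    subject_metadata: Dict[str, Dict[str, Any]],
) -> Tuple[List[Optional[int]], List[Optional[int]]]:
    """Return parallel age and sex lists; unknown users map to None."""
    ages: List[Optional[int]] = []
    sexes: List[Optional[int]] = []
    for raw_id in user_ids:
        # An empty dict stands in for users missing from the metadata.
        info = subject_metadata.get(normalize_user_id(raw_id), {})
        ages.append(info.get("age"))
        sexes.append(info.get("sex"))
    return ages, sexes


# Usage with made-up metadata (values are illustrative only)
example_metadata = {"0": {"age": 33, "sex": 1}, "7": {"age": 51, "sex": 0}}
example_ages, example_sexes = lookup_age_and_sex(["00", "07", "99"], example_metadata)
```

Returning `None` for unknown users matches what the plotting functions already expect, since both of them drop points whose age or sex is `None`.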


@@ -1,527 +0,0 @@
"""
User Classification from SBERT Embeddings
This module evaluates how well SBERT embeddings capture user-specific patterns
by training a Random Forest classifier to predict user identity from embeddings.
Motivation:
If embeddings contain user-distinguishing information, a classifier should be able
to predict which user a time series belongs to. High accuracy suggests that the
embeddings capture individual characteristics.
Experimental Design:
- Task: Multi-class classification
- Model: Random Forest with ensemble of decision trees
- Captures non-linear relationships in embedding space
- Provides feature importance scores for interpretability
- Robust to overfitting through bagging and random feature selection
- Split Strategy:
1. Session-based: Train on session 1, test on session 2
2. Random: Standard train/test split
Session-based split is more challenging and realistic because:
- Tests whether user patterns are stable across different recording sessions
- Avoids data leakage from same-session samples in train and test
Random Forest Advantages over Logistic Regression:
- Handles non-linear decision boundaries
- Feature importance reveals which embedding dimensions matter most
- No assumption about data distribution
- Naturally handles multi-class classification
Output:
- Classification metrics
- Confusion matrix visualization
Usage:
python simple_user_classifier.py \\
--embeddings ./embeddings \\
--out_dir ./results
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from typing import Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_from_disk, Dataset
from fire import Fire
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
accuracy_score,
f1_score,
classification_report,
confusion_matrix,
silhouette_score,
)
from sklearn.model_selection import train_test_split
# =============================================================================
# Data Loading
# =============================================================================
def load_embeddings_with_metadata(embedding_path: str) -> Dataset:
if not os.path.isdir(embedding_path):
raise FileNotFoundError(
f"Dataset directory not found: {embedding_path}. "
"Ensure this is a valid HuggingFace dataset directory."
)
# Load HuggingFace dataset from disk
dataset = load_from_disk(embedding_path)
return dataset
# =============================================================================
# Data Splitting
# =============================================================================
def split_by_session(
features: np.ndarray,
labels: np.ndarray,
session_ids: np.ndarray,
train_session: str = "1",
test_session: str = "2",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
"""
Split data by recording session for temporal generalization evaluation.
This split strategy tests whether learned patterns generalize across time.
Training on session 1 and testing on session 2 simulates real-world deployment
where models must work on future recordings.
Args:
features: Feature matrix of shape (num_samples, num_features)
labels: Label array of shape (num_samples,)
session_ids: Session identifier for each sample
train_session: Session ID to use for training (default: "1")
test_session: Session ID to use for testing (default: "2")
Returns:
Tuple of (X_train, X_test, y_train, y_test, split_description)
"""
# Create boolean masks for train and test sets
train_mask = session_ids == train_session
test_mask = session_ids == test_session
# Apply masks to create train/test splits
X_train = features[train_mask]
X_test = features[test_mask]
y_train = labels[train_mask]
y_test = labels[test_mask]
split_description = f"session({train_session}->train, {test_session}->test)"
return X_train, X_test, y_train, y_test, split_description
def split_random(
features: np.ndarray,
labels: np.ndarray,
test_size: float = 0.2,
random_state: int = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
"""
Split data randomly with stratification for in-distribution evaluation.
Stratification ensures each class has proportional representation in
both train and test sets, preventing class imbalance issues.
Args:
features: Feature matrix of shape (num_samples, num_features)
labels: Label array of shape (num_samples,)
test_size: Fraction of data to use for testing (default: 0.2)
random_state: Random seed for reproducibility
Returns:
Tuple of (X_train, X_test, y_train, y_test, split_description)
"""
X_train, X_test, y_train, y_test = train_test_split(
features,
labels,
test_size=test_size,
random_state=random_state,
stratify=labels,
)
split_description = "random"
return X_train, X_test, y_train, y_test, split_description
# =============================================================================
# Model Training and Evaluation
# =============================================================================
def create_classifier_pipeline(
n_estimators: int = 200,
max_depth: int = None,
min_samples_split: int = 2,
min_samples_leaf: int = 1,
random_state: int = 0,
) -> Pipeline:
"""
Create a scikit-learn pipeline for user classification using Random Forest.
Pipeline Architecture:
----------------------
1. StandardScaler: Z-score normalization of features
- Centers features (mean=0) and scales to unit variance (std=1)
- Random Forest splits are invariant to monotonic feature scaling, so this
step is optional and should not materially change predictions or importances
2. RandomForestClassifier: Ensemble of decision trees
- Builds multiple decision trees on random subsets of data (bagging)
- Each tree uses random subset of features at each split
- Final prediction is majority vote across all trees
- Provides feature_importances_ for interpretability
Random Forest Hyperparameters:
------------------------------
- n_estimators: Number of trees in the forest (more trees reduce variance with diminishing returns, at higher compute cost)
- max_depth: Maximum tree depth (None = expand until pure leaves)
- min_samples_split: Minimum samples to split internal node
- min_samples_leaf: Minimum samples required at leaf node
Args:
n_estimators: Number of trees (default: 200)
max_depth: Maximum depth of trees (default: None, fully grown)
min_samples_split: Min samples for splitting (default: 2)
min_samples_leaf: Min samples at leaf (default: 1)
random_state: Random seed for reproducibility
Returns:
Configured sklearn Pipeline ready for .fit() and .predict()
"""
pipeline = Pipeline([
# Step 1: Feature normalization
("scaler", StandardScaler(
with_mean=True, # Subtract mean (center the data)
with_std=True, # Divide by standard deviation (scale to unit variance)
)),
# Step 2: Random Forest classification
("classifier", RandomForestClassifier(
n_estimators=n_estimators, # Number of trees in the forest
max_depth=max_depth, # Maximum depth (None = unlimited)
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
n_jobs=-1, # Use all CPU cores for parallel tree building
random_state=random_state, # For reproducibility
class_weight="balanced", # Handle class imbalance automatically
oob_score=True, # Enable out-of-bag error estimation
)),
])
return pipeline
def evaluate_classifier(
y_true: np.ndarray,
y_pred: np.ndarray,
) -> Tuple[float, float, str, np.ndarray, list]:
"""
Compute classification metrics and confusion matrix.
Metrics Computed:
- Accuracy: Overall fraction of correct predictions
- Macro F1: Average F1 across all classes
- Per-class report: Precision, recall, F1 for each user
- Confusion matrix: Detailed breakdown of predictions vs ground truth
Args:
y_true: Ground truth labels
y_pred: Predicted labels
Returns:
Tuple of (accuracy, f1, classification_report, confusion_matrix, class_labels)
"""
# Compute scalar metrics
accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='macro') # Multiclass: use macro-averaged F1
# Generate detailed per-class report
report = classification_report(y_true, y_pred, digits=4)
# Compute confusion matrix: Get all unique classes from both true and predicted labels
class_labels = sorted(set(y_true) | set(y_pred))
cm = confusion_matrix(y_true, y_pred, labels=class_labels)
return accuracy, f1, report, cm, class_labels
# =============================================================================
# Visualization
# =============================================================================
def save_confusion_matrix_plot(
confusion_mat: np.ndarray,
class_labels: list,
output_path: str,
) -> None:
"""
Create and save a confusion matrix heatmap visualization.
The confusion matrix shows:
- Rows: True class labels
- Columns: Predicted class labels
- Cell values: Count of samples with that (true, predicted) combination
- Diagonal: Correct predictions
- Off-diagonal: Misclassifications
Args:
confusion_mat: Square matrix of shape (num_classes, num_classes)
class_labels: List of class names for axis labels
output_path: File path to save the plot (PDF format recommended)
"""
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Plot heatmap using imshow
im = ax.imshow(confusion_mat, aspect="auto", cmap="Blues")
# Add colorbar to show value scale
plt.colorbar(im, ax=ax)
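# Set labels and title
ax.set_title("Confusion Matrix (User Classification)")
ax.set_xlabel("Predicted User")
ax.set_ylabel("True User")
# Label ticks with class names
ax.set_xticks(np.arange(len(class_labels)))
ax.set_yticks(np.arange(len(class_labels)))
ax.set_xticklabels(class_labels, rotation=90, fontsize=6)
ax.set_yticklabels(class_labels, fontsize=6)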
plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.close()
print(f"[DONE] Saved confusion matrix: {output_path}")
def save_metrics_report(
output_path: str,
split_description: str,
train_size: int,
test_size: int,
accuracy: float,
macro_f1: float,
classification_report_str: str,
oob_score: float = None,
silhouette: float = None,
) -> None:
"""
Save classification metrics to a text file.
Writes summary statistics (split strategy, sizes, scores) and detailed
per-class classification report to a text file.
Args:
output_path: File path to save the metrics report
split_description: Description of train/test split strategy
train_size: Number of training samples
test_size: Number of test samples
accuracy: Overall classification accuracy
macro_f1: Macro-averaged F1 score
classification_report_str: Detailed per-class report string
oob_score: Out-of-bag score from Random Forest (optional)
silhouette: Silhouette score for embedding quality (optional)
"""
with open(output_path, "w", encoding="utf-8") as f:
# Write summary statistics
f.write(f"split_used : {split_description}\n")
f.write(f"train_size : {train_size}\n")
f.write(f"test_size : {test_size}\n")
# Silhouette score (embedding quality metric)
if silhouette is not None:
f.write(f"silhouette : {silhouette:.4f}\n")
f.write(f"accuracy : {accuracy:.4f}\n")
f.write(f"macro_f1 : {macro_f1:.4f}\n")
# Add OOB score if available (Random Forest specific)
if oob_score is not None:
f.write(f"oob_score : {oob_score:.4f}\n")
f.write("\n")
# Write detailed per-class report
f.write(classification_report_str)
print(f"[DONE] Saved metrics report: {output_path}")
def save_feature_importance_plot(
feature_importances: np.ndarray,
output_path: str,
top_k: int = 50,
) -> None:
"""
Create and save a horizontal bar chart of top-k most important features.
Visualizes which embedding dimensions contribute most to user classification.
Higher importance indicates that dimension better distinguishes between users.
Args:
feature_importances: Array of feature importance scores from Random Forest
output_path: File path to save the plot (PDF format recommended)
top_k: Number of top features to display (default: 50)
"""
# Get indices of top-k most important features
top_indices = np.argsort(feature_importances)[::-1][:top_k]
top_importances = feature_importances[top_indices]
# Create figure
fig, ax = plt.subplots(figsize=(12, 8))
# Create horizontal bar chart
y_positions = np.arange(len(top_indices))
ax.barh(y_positions, top_importances, color="steelblue", alpha=0.8)
# Set labels
ax.set_yticks(y_positions)
ax.set_yticklabels([f"emb_{i}" for i in top_indices], fontsize=8)
ax.invert_yaxis() # Highest importance at top
ax.set_xlabel("Feature Importance (Mean Decrease in Impurity)")
ax.set_ylabel("Embedding Dimension")
ax.set_title(f"Top {top_k} Most Important Embedding Dimensions")
# Save as vector PDF
plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.close()
print(f"[DONE] Saved feature importance plot: {output_path}")
def save_feature_importance_csv(
feature_importances: np.ndarray,
output_path: str,
) -> None:
"""
Save feature importance scores to CSV file with ranking.
Creates a CSV with columns: rank, feature, importance
Features are sorted by importance in descending order.
Args:
feature_importances: Array of feature importance scores from Random Forest
output_path: File path to save the CSV file
"""
importance_df = pd.DataFrame({
"feature": [f"emb_{i}" for i in range(len(feature_importances))],
"importance": feature_importances,
})
importance_df = importance_df.sort_values("importance", ascending=False)
importance_df = importance_df.reset_index(drop=True)
importance_df["rank"] = range(1, len(importance_df) + 1)
importance_df = importance_df[["rank", "feature", "importance"]]
importance_df.to_csv(output_path, index=False)
print(f"[DONE] Saved feature importance CSV: {output_path}")
# =============================================================================
# Command Line Interface
# =============================================================================
def main(
embeddings: str = "./embeddings/REM",
out_dir: str = "./results/REM",
test_size: float = 0.2,
n_estimators: int = 200,
max_depth: int = None,
) -> None:
# Create output directory
os.makedirs(out_dir, exist_ok=True)
print(f"[INFO] Loading embeddings from: {embeddings}")
dataset = load_embeddings_with_metadata(embeddings)
X = np.array(dataset["embedding"], dtype=np.float32)
y = np.array([str(uid) for uid in dataset["user_id"]])
session_ids = np.array([str(sid) for sid in dataset["session_id"]])
print(f"[INFO] Loaded {len(y)} samples with {X.shape[1]}-dimensional embeddings")
print(f"[INFO] Number of unique users: {len(np.unique(y))}")
has_session_1 = (session_ids == "1").sum() > 0
has_session_2 = (session_ids == "2").sum() > 0
if has_session_1 and has_session_2:
X_train, X_test, y_train, y_test, split_desc = split_by_session(
X, y, session_ids
)
else:
print("[WARN] Missing session data, falling back to random split")
X_train, X_test, y_train, y_test, split_desc = split_random(
X, y, test_size
)
split_desc = "random(fallback)"
print(f"[INFO] Split: {split_desc}")
print(f"[INFO] Train size: {len(y_train)}, Test size: {len(y_test)}")
silhouette_avg = silhouette_score(X, y)
print(f"[INFO] Silhouette Score (user clusters): {silhouette_avg:.4f}")
print("[INFO] Training Random Forest classifier...")
print(f"[INFO] Hyperparameters: n_estimators={n_estimators}, max_depth={max_depth}")
classifier = create_classifier_pipeline(
n_estimators=n_estimators,
max_depth=max_depth,
)
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)
rf_model = classifier.named_steps["classifier"]
    oob_score = getattr(rf_model, "oob_score_", None)
feature_importances = rf_model.feature_importances_
print("[INFO] Evaluating classifier performance...")
accuracy, macro_f1, report, cm, classes = evaluate_classifier(y_test, y_pred)
metrics_path = os.path.join(out_dir, "user_cls_metrics.txt")
save_metrics_report(
output_path=metrics_path,
split_description=split_desc,
train_size=len(y_train),
test_size=len(y_test),
accuracy=accuracy,
macro_f1=macro_f1,
classification_report_str=report,
oob_score=oob_score,
silhouette=silhouette_avg,
)
confusion_path = os.path.join(out_dir, "user_cls_confusion.pdf")
save_confusion_matrix_plot(cm, classes, confusion_path)
importance_plot_path = os.path.join(out_dir, "feature_importance.pdf")
save_feature_importance_plot(feature_importances, importance_plot_path)
importance_csv_path = os.path.join(out_dir, "feature_importance.csv")
save_feature_importance_csv(feature_importances, importance_csv_path)
print("\n" + "=" * 50)
print("RANDOM FOREST CLASSIFICATION RESULTS")
print("=" * 50)
print(f"Split Strategy : {split_desc}")
print(f"Silhouette : {silhouette_avg:.4f}")
print(f"Accuracy : {accuracy:.4f}")
print(f"Macro F1 : {macro_f1:.4f}")
if oob_score is not None:
print(f"OOB Score : {oob_score:.4f}")
    print("Top 5 Important Features:")
top5_idx = np.argsort(feature_importances)[::-1][:5]
for rank, idx in enumerate(top5_idx, 1):
print(f" {rank}. emb_{idx}: {feature_importances[idx]:.4f}")
print("=" * 50)
if __name__ == "__main__":
Fire(main)
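
The split logic in `main` above (split by session when both sessions are present, otherwise fall back to a random split) can be sketched in plain Python. This is only an illustration under stated assumptions: `split_by_session` is not visible in this view, so the choice of session "1" for training and session "2" for testing, the helper name `split_session_or_random`, and the dict-based sample layout are all hypothetical.

```python
import random
from typing import Dict, List, Tuple


def split_session_or_random(
    samples: List[Dict],
    test_size: float = 0.2,
    seed: int = 0,
) -> Tuple[List[Dict], List[Dict], str]:
    """Split samples by session if possible, else randomly (illustrative sketch)."""
    # Assumption: session "1" is used for training and session "2" for testing.
    train = [s for s in samples if s["session_id"] == "1"]
    test = [s for s in samples if s["session_id"] == "2"]
    if train and test:
        return train, test, "session(1=train, 2=test)"

    # Fallback: one of the two sessions is missing, so shuffle and cut.
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)
    n_test = max(1, round(len(shuffled) * test_size))
    return shuffled[n_test:], shuffled[:n_test], "random(fallback)"


if __name__ == "__main__":
    data = [{"session_id": "1", "x": i} for i in range(6)]
    data += [{"session_id": "2", "x": i} for i in range(4)]
    train, test, desc = split_session_or_random(data)
    print(desc, len(train), len(test))
```

A session-based split keeps samples from the same recording out of both sets, which tends to give a more honest estimate of how well users can be told apart than a random split does.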

@@ -1,756 +0,0 @@
"""
SBERT Metadata-Based Embedding Extraction and Visualization Pipeline
This module provides functionality to:
1. Extract embeddings from user metadata (age and sex) using SBERT (Sentence-BERT)
2. Visualize embeddings using dimensionality reduction (t-SNE) colored by age
SBERT Overview:
SBERT (Sentence-BERT) is a modification of the BERT model that uses siamese
and triplet network structures to derive semantically meaningful sentence
embeddings. It's designed for semantic similarity tasks and produces
fixed-size dense vector representations of text.
Key Features:
- Semantic understanding: Captures meaning rather than just word presence
- Fixed-size embeddings: Outputs consistent vector dimensions (384 for all-MiniLM-L6-v2)
- Efficient: Optimized for sentence-level tasks
Embedding Strategy:
Instead of using time series features, we create textual descriptions based on
user metadata (age and sex). This approach allows us to capture user-level
characteristics in the embedding space.
Processing Pipeline:
1. Load user metadata from XLS file (age, sex)
2. Textualization: Convert metadata to natural language description
3. SBERT encoding: Generate 384-dimensional embeddings
4. Visualization: t-SNE with continuous age coloring
Usage:
# Extract embeddings from metadata for all labels
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/all_labels
# Extract embeddings for a single label
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/REM --label REM
# Visualize with t-SNE from all label directories (colored by age)
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by age
# Visualize by sex from all labels
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by sex
# Create both age and sex plots from all labels
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by both
# Visualize from a single label directory
python gen_plot.py plot --emb_dir ./embeddings/REM --out_dir ./plots/REM --color_by age
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from glob import glob
from typing import Dict, Any, List, Tuple, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_from_disk, Dataset, concatenate_datasets
from fire import Fire
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
# =============================================================================
# Constants
# =============================================================================
SUBJECT_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/SC-subjects.xls"
# =============================================================================
# Metadata Loading
# =============================================================================
def load_subject_metadata(subject_path: str = SUBJECT_PATH) -> Dict[str, Dict[str, Any]]:
    """
    Load subject metadata from XLS file.

    The XLS file contains subject information including:
    - subject: Subject ID. extract_embeddings looks IDs up as str(int(user_id)),
      so this column is expected to hold integer IDs (e.g., 0, 1)
    - age: Age of the subject
    - sex (F=1): Sex (1 = Female, 0 = Male)

    Args:
        subject_path: Path to the SC-subjects.xls file

    Returns:
        Dictionary mapping subject IDs (as strings) to metadata dictionaries
        Format: {"0": {"age": 25, "sex": 1}, ...}
    """
    df = pd.read_excel(subject_path)
    subject_info = {}
    for _, row in df.iterrows():
        subject_id = str(row["subject"]).strip()
        subject_info[subject_id] = {
            "age": int(row["age"]) if pd.notna(row["age"]) else None,
            "sex": int(row["sex (F=1)"]) if pd.notna(row["sex (F=1)"]) else None,
        }
    return subject_info
# =============================================================================
# Embedding Extractor Class
# =============================================================================
class SBERT_Metadata:
"""
Extracts fixed-dimensional embeddings from user metadata using SBERT.
Uses Sentence-BERT to convert textualized metadata descriptions into dense
vector representations. The textualization process converts age and sex
information into natural language, which SBERT then encodes semantically.
Architecture:
User Metadata → Textualization → SBERT Encoder → Embedding
Processing Pipeline:
1. Load user metadata (age, sex) from XLS file
2. Textualize: Convert metadata to natural language description
3. SBERT encoding: Generate 384-dimensional semantic embeddings
4. Output: Fixed-size embedding vector per user
Model: all-MiniLM-L6-v2
- Lightweight BERT variant optimized for sentence embeddings
- 384-dimensional output embeddings
- Fast inference with good semantic understanding
Attributes:
model: SentenceTransformer instance for encoding textualized metadata
"""
def __init__(self):
"""
Initialize the SBERT embedder with pre-trained model.
Uses "all-MiniLM-L6-v2" which is a lightweight, fast model
optimized for sentence similarity tasks.
"""
self.model = SentenceTransformer("all-MiniLM-L6-v2")
@staticmethod
def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
"""
Discover all user/session directories under data_root.
Uses glob pattern matching for cleaner directory traversal.
Expected structure: data_root/user_id/session_id/
Args:
data_root: Root directory containing user/session subfolders
Returns:
List of (user_id, session_id, session_path) tuples
"""
discovered_paths = []
# Use glob to find all session directories (2 levels deep)
for session_path in sorted(glob(os.path.join(data_root, "*", "*"))):
if not os.path.isdir(session_path):
continue
# Extract user_id and session_id from path
session_id = os.path.basename(session_path)
user_id = os.path.basename(os.path.dirname(session_path))
discovered_paths.append((user_id, session_id, session_path))
return discovered_paths
def textualize_metadata(self, age: Optional[int], sex: Optional[int]) -> str:
"""
Convert user metadata (age and sex) into a natural language description.
This textualization step is crucial for SBERT, which expects text input.
The description provides user demographic information in a structured format.
Args:
age: Age of the user (integer, may be None)
sex: Sex of the user (0 = Male, 1 = Female, may be None)
Returns:
Natural language string describing the user metadata
"""
# Map sex code to text
if sex is not None:
sex_text = "Female" if sex == 1 else "Male"
else:
sex_text = "Unknown"
# Create sentence from metadata
if age is not None:
sentence = f"This is the information of the user, age: {age}, sex: {sex_text}."
else:
sentence = f"This is the information of the user, age: unknown, sex: {sex_text}."
return sentence
def compute_embedding_from_metadata(
self,
ages: List[Optional[int]],
sexes: List[Optional[int]]
) -> np.ndarray:
"""
Generate embedding vectors from metadata using SBERT.
Processing Pipeline:
1. Textualize each user's metadata into natural language
2. Encode textual descriptions using SBERT
3. Return fixed-size embedding vectors
Args:
ages: List of age values (may contain None)
sexes: List of sex values (0 = Male, 1 = Female, may contain None)
Returns:
Embedding array of shape (batch_size, embedding_dim)
For all-MiniLM-L6-v2: (batch_size, 384)
"""
# Convert metadata to text sentences
text_samples = []
for age, sex in zip(ages, sexes):
text_samples.append(self.textualize_metadata(age, sex))
# Encode text descriptions using SBERT
# Returns numpy array of shape (batch_size, 384)
embeddings = self.model.encode(text_samples)
return embeddings
def extract_embeddings(
self,
data_root: str,
subject_path: str = SUBJECT_PATH,
batch_size: int = 32,
label: Optional[str] = None
) -> Dataset:
"""
Extract embeddings from user metadata for all sessions.
Iterates through all user/session combinations, loads metadata for each user,
generates embeddings from metadata sentences, and aggregates results.
Can process a single label or all labels.
Args:
data_root: Root directory containing user/session data folders
subject_path: Path to SC-subjects.xls file with metadata
batch_size: Number of samples to process together (for batching embeddings)
Larger = faster but more memory.
32 is a good balance for most systems.
label: Sleep stage label to filter (e.g., "W", "N1", "N2", "N3", "REM").
If None or "all", processes all labels.
Returns:
HuggingFace Dataset with columns:
- user_id, session_id, idx, label (metadata from original data)
- embedding (384-dim vector from all-MiniLM-L6-v2 based on metadata)
"""
# Load subject metadata
print(f"[INFO] Loading subject metadata from: {subject_path}")
subject_metadata = load_subject_metadata(subject_path)
print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")
session_paths = self.discover_session_paths(data_root)
print(f"[INFO] Discovered {len(session_paths)} sessions")
all_embeddings = []
all_user_ids = []
all_session_ids = []
all_idxs = []
all_labels = []
all_ages = []
all_sexes = []
# Collect metadata for all samples first
for user_id, session_id, session_path in session_paths:
# Load HuggingFace dataset from disk
dataset = load_from_disk(session_path)
# Shuffle dataset for randomness
dataset = dataset.shuffle(seed=0)
# Filter by sleep stage label if specified
if label is not None and label != "all":
dataset = dataset.filter(lambda x: x["label"] == label)
num_samples = len(dataset)
if num_samples == 0:
continue
print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")
# Get metadata for this user
# Convert user_id to string format that matches metadata keys
user_id_str = str(int(user_id))
try:
age = subject_metadata[user_id_str]["age"]
sex = subject_metadata[user_id_str]["sex"]
except KeyError:
age = None
sex = None
print(f"[WARN] No metadata found for user_id: {user_id_str}")
            # Collect all samples for this user/session, reading each column once
            all_user_ids.extend(str(v) for v in dataset["user_id"])
            all_session_ids.extend(str(v) for v in dataset["session_id"])
            all_idxs.extend(int(v) for v in dataset["idx"])
            all_labels.extend(str(v) for v in dataset["label"])
            all_ages.extend([age] * num_samples)
            all_sexes.extend([sex] * num_samples)
# Generate embeddings from metadata in batches
print(f"[INFO] Generating embeddings from metadata for {len(all_ages)} samples...")
for batch_start in range(0, len(all_ages), batch_size):
batch_end = min(batch_start + batch_size, len(all_ages))
batch_ages = all_ages[batch_start:batch_end]
batch_sexes = all_sexes[batch_start:batch_end]
# Compute embeddings from metadata
embeddings = self.compute_embedding_from_metadata(batch_ages, batch_sexes)
# Collect embeddings
for i in range(embeddings.shape[0]):
all_embeddings.append(embeddings[i].tolist())
# Create HuggingFace Dataset
result_dataset = Dataset.from_dict({
"user_id": all_user_ids,
"session_id": all_session_ids,
"idx": all_idxs,
"label": all_labels,
"age": all_ages,
"sex": all_sexes,
"embedding": all_embeddings,
})
print(f"[INFO] Total samples: {len(result_dataset)}")
# Print label distribution
if "label" in result_dataset.column_names:
label_counts = {}
for lbl in result_dataset["label"]:
label_counts[lbl] = label_counts.get(lbl, 0) + 1
print(f"[INFO] Label distribution: {label_counts}")
return result_dataset
def save_embeddings(
self,
dataset: Dataset,
output_dir: str
) -> None:
"""
Save embeddings dataset to disk in HuggingFace format.
Args:
dataset: HuggingFace Dataset containing embeddings and metadata
output_dir: Directory path to save the dataset
"""
os.makedirs(output_dir, exist_ok=True)
dataset.save_to_disk(output_dir)
print(f"[DONE] Saved embeddings dataset: {output_dir}")
print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
@staticmethod
def load_embeddings(embedding_dir: str) -> Dataset:
"""
Load saved embeddings dataset from disk.
Args:
embedding_dir: Directory path containing the saved HuggingFace dataset
Returns:
HuggingFace Dataset with embeddings and metadata
"""
dataset = load_from_disk(embedding_dir)
print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
return dataset
# =============================================================================
# Data Loading Utilities
# =============================================================================
def load_embeddings_from_all_labels(embeddings_root: str) -> Dataset:
"""
Load embeddings from all label subdirectories and concatenate them.
Discovers all subdirectories in embeddings_root (e.g., W, REM, N1, N2, N3)
and loads embeddings from each, then concatenates them into a single dataset.
Args:
embeddings_root: Root directory containing label subdirectories
(e.g., "./embeddings" containing W/, REM/, N1/, etc.)
Returns:
Concatenated HuggingFace Dataset with all labels combined
"""
# Discover all label subdirectories
label_dirs = []
for item in os.listdir(embeddings_root):
item_path = os.path.join(embeddings_root, item)
if os.path.isdir(item_path):
# Check if it's a valid HuggingFace dataset directory
if os.path.exists(os.path.join(item_path, "dataset_info.json")):
label_dirs.append((item, item_path))
if len(label_dirs) == 0:
raise ValueError(
f"No valid HuggingFace dataset directories found in: {embeddings_root}"
)
label_dirs.sort() # Sort for consistent ordering
print(f"[INFO] Discovered {len(label_dirs)} label directories: {[ld[0] for ld in label_dirs]}")
# Load datasets from each label directory
datasets = []
for label_name, label_path in label_dirs:
print(f"[INFO] Loading embeddings from: {label_path}")
dataset = load_from_disk(label_path)
print(f"[INFO] Label: {label_name}, Samples: {len(dataset)}")
datasets.append(dataset)
# Concatenate all datasets
if len(datasets) == 1:
combined_dataset = datasets[0]
else:
combined_dataset = concatenate_datasets(datasets)
print(f"[INFO] Combined dataset: {len(combined_dataset)} total samples")
# Print label distribution
if "label" in combined_dataset.column_names:
label_counts = {}
for label in combined_dataset["label"]:
label_counts[label] = label_counts.get(label, 0) + 1
print(f"[INFO] Label distribution: {label_counts}")
return combined_dataset
# =============================================================================
# Visualization Functions
# =============================================================================
def reduce_to_2d_tsne(
embeddings: np.ndarray,
perplexity: float = 30.0
) -> np.ndarray:
"""
Reduce high-dimensional embeddings to 2D using t-SNE.
t-SNE (t-distributed Stochastic Neighbor Embedding) is a nonlinear
dimensionality reduction technique that preserves local structure.
Points that are similar in high dimensions stay close in 2D.
Args:
embeddings: High-dimensional array of shape (num_samples, embedding_dim)
perplexity: t-SNE perplexity parameter (typically 5-50).
Higher values consider more neighbors, creating smoother layouts.
Rule of thumb: perplexity ~ sqrt(num_samples)
Returns:
2D coordinates of shape (num_samples, 2)
"""
print(f"[INFO] Running t-SNE with perplexity={perplexity}...")
tsne = TSNE(
n_components=2,
random_state=0, # For reproducibility
perplexity=perplexity,
max_iter=1000, # Usually sufficient for convergence
init='random',
learning_rate='auto', # Let sklearn choose optimal learning rate
)
return tsne.fit_transform(embeddings)
def create_scatter_plot_by_age(
coordinates: np.ndarray,
ages: np.ndarray,
title: str,
output_path: str
) -> None:
"""
Create and save a 2D scatter plot colored by age (continuous colormap).
Uses a continuous colormap (viridis) to show age distribution.
Ages are mapped to colors on a gradient scale.
Args:
coordinates: 2D array of shape (num_points, 2)
ages: Age values for each point (array, may contain None values)
title: Plot title
output_path: File path to save the plot (PDF vector format recommended)
"""
# Filter out points with missing age data
# Convert None values to NaN for proper numpy handling
ages_float = np.array([float(a) if a is not None else np.nan for a in ages])
valid_mask = ~np.isnan(ages_float)
valid_coords = coordinates[valid_mask]
valid_ages = ages_float[valid_mask]
if len(valid_ages) == 0:
print(f"[WARN] No valid age data found. Skipping plot: {output_path}")
return
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Create scatter plot with continuous colormap
scatter = ax.scatter(
valid_coords[:, 0],
valid_coords[:, 1],
c=valid_ages,
cmap='viridis', # Continuous colormap for granular age visualization
s=15,
alpha=0.7,
edgecolors='none',
)
# Add colorbar
cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label('Age (years)', rotation=270, labelpad=20)
# Configure plot appearance
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
# Save figure as vector PDF
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
print(f"[INFO] Age range: {valid_ages.min():.0f} - {valid_ages.max():.0f} years")
print(f"[INFO] Points with valid age: {len(valid_ages)}/{len(ages)}")
def create_scatter_plot_by_sex(
coordinates: np.ndarray,
sexes: np.ndarray,
title: str,
output_path: str
) -> None:
"""
Create and save a 2D scatter plot colored by sex (categorical).
Uses discrete colors for different sex categories.
Sex encoding: 1 = Female, 0 = Male
Args:
coordinates: 2D array of shape (num_points, 2)
sexes: Sex values for each point (array, may contain None values)
title: Plot title
output_path: File path to save the plot (PDF vector format recommended)
"""
# Filter out points with missing sex data
# Convert None values to NaN for proper numpy handling
sexes_float = np.array([float(s) if s is not None else np.nan for s in sexes])
valid_mask = ~np.isnan(sexes_float)
valid_coords = coordinates[valid_mask]
valid_sexes = sexes_float[valid_mask].astype(int)
if len(valid_sexes) == 0:
print(f"[WARN] No valid sex data found. Skipping plot: {output_path}")
return
# Map sex codes to labels
sex_labels = {0: "Male", 1: "Female"}
unique_sexes = sorted(set(valid_sexes))
# Create figure
fig, ax = plt.subplots(figsize=(10, 8))
# Plot each category separately for proper legend
colors = ['steelblue', 'coral'] # Blue for Male, Coral for Female
for idx, sex_code in enumerate(unique_sexes):
mask = valid_sexes == sex_code
ax.scatter(
valid_coords[mask, 0],
valid_coords[mask, 1],
c=colors[sex_code % len(colors)],
s=15,
label=sex_labels.get(sex_code, f"Sex {sex_code}"),
alpha=0.7,
)
# Configure plot appearance
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
ax.legend(loc='best', markerscale=2)
# Save figure as vector PDF
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
    sex_counts = {sex_labels.get(s, f"Sex {s}"): int((valid_sexes == s).sum()) for s in unique_sexes}
print(f"[INFO] Sex distribution: {sex_counts}")
print(f"[INFO] Points with valid sex: {len(valid_sexes)}/{len(sexes)}")
# =============================================================================
# Command Line Interface
# =============================================================================
class CLI:
"""
Command-line interface for SBERT metadata-based embedding extraction and visualization.
Provides two main commands:
- extract: Generate embeddings from user metadata (age, sex)
- plot: Visualize embeddings with t-SNE colored by age or sex
"""
def extract(
self,
data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_full",
subject_path: str = SUBJECT_PATH,
out_dir: str = "./embeddings/all_labels",
batch_size: int = 32,
label: str = None
) -> None:
"""
Extract embeddings from user metadata.
Args:
data_root: Root directory containing user/session data folders
subject_path: Path to SC-subjects.xls file with metadata
out_dir: Output directory for HuggingFace dataset
batch_size: Batch size for inference (default: 32)
label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM').
If None or 'all', processes all labels (default: None for all labels)
"""
embedder = SBERT_Metadata()
dataset = embedder.extract_embeddings(data_root, subject_path, batch_size, label)
embedder.save_embeddings(dataset, out_dir)
def plot(
self,
emb_dir: str = None,
embeddings_root: str = "./embeddings",
out_dir: str = "./plots/all_labels",
perplexity: float = 10.0,
color_by: str = "age",
users: str = None,
num_users: int = 0,
labels: str = None,
) -> None:
"""
Visualize embeddings with t-SNE, colored by age or sex.
Can load from either a single label directory or all label directories.
Args:
emb_dir: Single directory containing the HuggingFace embeddings dataset
(e.g., "./embeddings/REM"). If provided, only this directory is used.
embeddings_root: Root directory containing label subdirectories
(e.g., "./embeddings" containing W/, REM/, N1/, etc.).
Used only if emb_dir is not provided.
out_dir: Output directory for visualization plots (PDF)
perplexity: t-SNE perplexity parameter (default: 10.0)
color_by: What to color by - 'age', 'sex', or 'both' (default: 'age')
users: Comma-separated user IDs to include (e.g., '00,01,02')
num_users: Include only first N users, 0 = all (default: 0)
labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
This filters the already-loaded data, not which directories to load.
"""
        # Validate color_by argument ('all' is accepted as an alias for 'both')
        if color_by not in ["age", "sex", "both", "all"]:
            raise ValueError(f"Invalid color_by: {color_by}. Use 'age', 'sex', 'both', or 'all'.")
os.makedirs(out_dir, exist_ok=True)
# Load embeddings: either from single directory or all label directories
if emb_dir is not None:
# Load from single directory
print(f"[INFO] Loading embeddings from single directory: {emb_dir}")
dataset = SBERT_Metadata.load_embeddings(emb_dir)
else:
# Load from all label directories
print(f"[INFO] Loading embeddings from all label directories in: {embeddings_root}")
dataset = load_embeddings_from_all_labels(embeddings_root)
# Apply user filtering
if users:
user_list = [u.strip() for u in users.split(",")]
dataset = dataset.filter(lambda x: x["user_id"] in user_list)
print(f"[INFO] Filtered to users: {user_list}")
elif num_users > 0:
all_users = sorted(set(dataset["user_id"]))
selected_users = all_users[:num_users]
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
print(f"[INFO] Selected first {num_users} users: {selected_users}")
# Filter by sleep stage labels
if labels:
label_list = [l.strip() for l in labels.split(",")]
dataset = dataset.filter(lambda x: x["label"] in label_list)
print(f"[INFO] Filtered to labels: {label_list}")
print(f"[INFO] Total samples: {len(dataset)}")
# Extract embeddings as numpy array for t-SNE
embeddings = np.array(dataset["embedding"])
# Extract ages and sexes for coloring
ages = np.array(dataset["age"])
sexes = np.array(dataset["sex"])
# Reduce to 2D with t-SNE
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
# Generate visualizations based on color_by parameter
if color_by == "age":
create_scatter_plot_by_age(
coordinates_2d,
ages,
"t-SNE Visualization (Colored by Age)",
os.path.join(out_dir, "tsne_by_age.pdf")
)
elif color_by == "sex":
create_scatter_plot_by_sex(
coordinates_2d,
sexes,
"t-SNE Visualization (Colored by Sex)",
os.path.join(out_dir, "tsne_by_sex.pdf")
)
elif color_by == "both" or color_by == "all":
# Create both plots
create_scatter_plot_by_age(
coordinates_2d,
ages,
"t-SNE Visualization (Colored by Age)",
os.path.join(out_dir, "tsne_by_age.pdf")
)
create_scatter_plot_by_sex(
coordinates_2d,
sexes,
"t-SNE Visualization (Colored by Sex)",
os.path.join(out_dir, "tsne_by_sex.pdf")
)
if __name__ == "__main__":
Fire(CLI)
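
Since each embedding in the module above is computed from the age/sex sentence alone, every sample of the same user produces the same sentence. The sketch below mirrors `textualize_metadata` as a standalone function and shows one possible way to collapse repeated sentences before encoding. The helper `unique_sentences` is hypothetical and is not part of the module; the SBERT call itself is left out so the example runs without extra dependencies.

```python
from typing import Dict, List, Optional, Sequence, Tuple


def textualize_metadata(age: Optional[int], sex: Optional[int]) -> str:
    """Standalone mirror of SBERT_Metadata.textualize_metadata."""
    if sex is not None:
        sex_text = "Female" if sex == 1 else "Male"
    else:
        sex_text = "Unknown"
    age_text = str(age) if age is not None else "unknown"
    return f"This is the information of the user, age: {age_text}, sex: {sex_text}."


def unique_sentences(
    ages: Sequence[Optional[int]],
    sexes: Sequence[Optional[int]],
) -> Tuple[List[str], List[int]]:
    """Return distinct sentences (first-seen order) and, per sample, its index."""
    sentence_to_pos: Dict[str, int] = {}
    positions: List[int] = []
    for age, sex in zip(ages, sexes):
        sentence = textualize_metadata(age, sex)
        if sentence not in sentence_to_pos:
            sentence_to_pos[sentence] = len(sentence_to_pos)
        positions.append(sentence_to_pos[sentence])
    return list(sentence_to_pos), positions


if __name__ == "__main__":
    sentences, positions = unique_sentences([33, 33, 26, None], [1, 1, 0, None])
    for sentence in sentences:
        print(sentence)
    print(positions)
```

With this layout, only the distinct sentences would need to be encoded, and `positions` maps each sample back to its embedding row. It also makes explicit that a t-SNE plot of these embeddings contains at most one distinct point per (age, sex) pair.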

File diff suppressed because it is too large

58
baselines/common.py Normal file
@@ -0,0 +1,58 @@
"""Shared setup for baseline experiments."""

import os
import sys

import yaml

PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PROJECT_ROOT)

from core.data_loader import DataLoader  # noqa: E402
from core.model import load_models  # noqa: E402
from core.recruiter import Recruiter  # noqa: E402
from core.agent import Agent  # noqa: E402
from core.logger import Logger  # noqa: E402
from core.prompt import gen_system_message, gen_task_message  # noqa: E402
from core.json_utils import safe_parse_json  # noqa: E402
from core.scores import self_certainty  # noqa: E402
from core.vote import borda_vote  # noqa: E402


def load_config(config_path: str) -> dict:
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.load(f, Loader=yaml.SafeLoader)


def setup(config: dict, temperature: float = None):
    """
    Initialize all components from config.

    Args:
        config: Parsed YAML config dict.
        temperature: Override model temperature (e.g. 0.7 for SC baseline).
            If None, uses config["temperature"] or 0.0.

    Returns:
        (logger, dataloader, recruiter, agent)
    """
    logger = Logger(config.get("log_path"))
    logger.log_config(config)

    dataloader = DataLoader(config.get("data_path"), config.get("target_user"))

    recruiter = Recruiter(
        source_dataset=dataloader.get_source_dataset(),
        source_users=dataloader.get_source_users(),
        num_shot=config.get("num_shot"),
        classes=dataloader.get_classes(),
        logger=logger,
    )

    temp = temperature if temperature is not None else config.get("temperature", 0.0)
    model_pool = load_models(config.get("model_paths"), temperature=temp)
    agent = Agent(model_pool=model_pool, logger=logger)

    system_message = gen_system_message(metadata=dataloader.get_task_metadata())
    agent.set_system_message(system_message)

    return logger, dataloader, recruiter, agent
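
The temperature override in `setup` relies on an explicit `is not None` check. The minimal sketch below isolates that rule and contrasts it with a shorter `or`-based variant that would silently discard an explicit override of 0.0. Both helper names are hypothetical and exist only for this illustration.

```python
from typing import Optional


def resolve_temperature(config: dict, override: Optional[float] = None) -> float:
    """Same precedence as setup(): explicit override, then config, then 0.0."""
    if override is not None:
        return override
    return config.get("temperature", 0.0)


def resolve_temperature_naive(config: dict, override: Optional[float] = None) -> float:
    """Shorter variant kept only for contrast: 0.0 is falsy, so it is ignored."""
    return override or config.get("temperature", 0.0)


if __name__ == "__main__":
    cfg = {"temperature": 0.3}
    print(resolve_temperature(cfg))             # falls back to the config value
    print(resolve_temperature(cfg, 0.7))        # explicit override wins
    print(resolve_temperature(cfg, 0.0))        # greedy decoding is respected
    print(resolve_temperature_naive(cfg, 0.0))  # override lost, config value used
```

This matters for baselines that need deterministic decoding while sharing a config file whose `temperature` is set for sampling.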

@@ -0,0 +1,72 @@
"""
Baseline 4: Random Examples + Borda Voting

For each sample, randomly recruit queue_size fresh example sets.
Each example set produces one inference. Answers aggregated via borda voting.
"""

import asyncio
import time

from fire import Fire

from common import (
    load_config, setup,
    gen_task_message, safe_parse_json, self_certainty, borda_vote,
)


async def run(config_path: str):
    config = load_config(config_path)
    queue_size = config.get("queue_size")

    logger, dataloader, recruiter, agent = setup(config)
    logger.log("[Baseline] Random examples + borda voting")

    async def process(sample_idx, example_idx, example_set, sample):
        try:
            task_message = gen_task_message(sample, example_set)
            response, logprobs = await agent.solve(
                task_message, sample_idx, example_idx
            )
            score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
            parsed = safe_parse_json(response)
            answer = parsed.get("ANSWER") if parsed else None
            logger.log(
                f"[Main] Done {sample_idx} - {example_idx}: "
                f"answer={answer}, score={score:.4f}"
            )
            return {"answer": answer, "score": score}
        except Exception as e:
            logger.log(
                f"[Main] Error {sample_idx} - {example_idx}: {e}",
                filename="errors.txt",
            )
            return {"answer": None, "score": float("-inf")}

    start_time = time.time()

    for idx, sample in enumerate(dataloader):
        logger.log(f"[Main] Sample {idx} / {len(dataloader)}")

        # Recruit a fresh batch of example sets for every sample.
        example_sets = recruiter.recruit(queue_size)
        tasks = [
            process(idx, i, es, sample) for i, es in enumerate(example_sets)
        ]
        results = await asyncio.gather(*tasks)

        winner, tally = borda_vote(results, config.get("borda_p", 1.0))
        ground_truth = sample["label"]

        if winner is not None:
            tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
            logger.log(f"[Vote] votes={{ {tally_str} }}")
            logger.log_result(idx, winner, ground_truth)
        else:
            logger.log(f"[Vote] Sample {idx}: no valid answer, skipping")

    logger.report(elapsed_seconds=time.time() - start_time)


def main(config_path: str):
    asyncio.run(run(config_path))


if __name__ == "__main__":
    Fire(main)
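
The implementation of `core.vote.borda_vote` is not part of this diff, so the block below is only a sketch of one plausible score-ranked Borda count. It assumes each result is a dict with `answer` and `score` keys (as returned by `process` above), that results are ranked by score, and that the rank at position `r` out of `n` valid results earns `(n - r) ** p` points. The name `borda_vote_sketch` and the exact point formula are assumptions, not the repository's code.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Tuple


def borda_vote_sketch(
    results: List[dict], p: float = 1.0
) -> Tuple[Optional[str], Dict[str, float]]:
    """Rank valid results by score and sum rank-based points per answer."""
    # Failed inferences carry answer=None and are excluded from the vote.
    valid = [r for r in results if r.get("answer") is not None]
    if not valid:
        return None, {}

    ranked = sorted(valid, key=lambda r: r["score"], reverse=True)
    n = len(ranked)

    tally: Dict[str, float] = defaultdict(float)
    for rank, result in enumerate(ranked):
        # Assumed point rule: best rank gets n**p, worst gets 1**p.
        tally[result["answer"]] += (n - rank) ** p

    winner = max(tally, key=tally.get)
    return winner, dict(tally)


results = [
    {"answer": "W", "score": 0.9},
    {"answer": "N1", "score": 0.8},
    {"answer": "N1", "score": 0.7},
    {"answer": "N1", "score": 0.6},
    {"answer": None, "score": float("-inf")},
]
print(borda_vote_sketch(results, p=1.0))
print(borda_vote_sketch(results, p=2.0))
```

Under this assumed rule, a larger `p` gives more weight to the highest-scored inference, so a single confident answer can outvote several lower-scored ones.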


@@ -0,0 +1,69 @@
"""
Baseline 6: Random Dynamic Self-Consistency

For each sample, randomly recruit ONE fresh example set,
then run that same prompt queue_size times with temperature=0.7.
Answers are aggregated via borda voting.
"""

import asyncio
import time

from fire import Fire

from common import (
    load_config, setup,
    gen_task_message, safe_parse_json, self_certainty, borda_vote,
)


async def run(config_path: str):
    config = load_config(config_path)
    queue_size = config.get("queue_size")

    logger, dataloader, recruiter, agent = setup(config, temperature=0.7)
    logger.log("[Baseline] Random dynamic example + self-consistency (temp=0.7)")

    async def run_once(sample_idx, run_idx, task_message):
        try:
            response, logprobs = await agent.solve(task_message, sample_idx, run_idx)
            score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
            parsed = safe_parse_json(response)
            answer = parsed.get("ANSWER") if parsed else None
            logger.log(
                f"[SC] {sample_idx} - run {run_idx}: "
                f"answer={answer}, score={score:.4f}"
            )
            return {"answer": answer, "score": score}
        except Exception as e:
            logger.log(
                f"[SC] Error {sample_idx} - run {run_idx}: {e}",
                filename="errors.txt",
            )
            return {"answer": None, "score": float("-inf")}

    start_time = time.time()

    for idx, sample in enumerate(dataloader):
        logger.log(f"[Main] Sample {idx} / {len(dataloader)}")

        # One fresh example set per sample; the same prompt is sampled repeatedly.
        example_set = recruiter.recruit(1)[0]
        task_message = gen_task_message(sample, example_set)

        tasks = [run_once(idx, i, task_message) for i in range(queue_size)]
        results = await asyncio.gather(*tasks)

        winner, tally = borda_vote(results, config.get("borda_p", 1.0))
        ground_truth = sample["label"]

        if winner is not None:
            tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
            logger.log(f"[Vote] votes={{ {tally_str} }}")
            logger.log_result(idx, winner, ground_truth)
        else:
            logger.log(f"[Vote] Sample {idx}: no valid answer, skipping")

    logger.report(elapsed_seconds=time.time() - start_time)


def main(config_path: str):
    asyncio.run(run(config_path))


if __name__ == "__main__":
    Fire(main)
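
Each run above is scored with `self_certainty(logprobs, vocab_size=...)`, whose implementation in `core.scores` is not shown in this diff. As a rough illustration of what such a score could compute, the sketch below averages the divergence KL(U || p) between the uniform distribution and the model's next-token distribution over all generated positions. It assumes `logprobs` is a list with one entry per generated token, each entry being the list of top-k log-probabilities at that position, and it spreads the probability mass outside the top-k uniformly over the remaining vocabulary. Both the input format and the tail approximation are assumptions.

```python
import math
from typing import List


def self_certainty_sketch(token_logprobs: List[List[float]], vocab_size: int) -> float:
    """Mean KL(U || p) per position; higher means a more peaked distribution."""
    if not token_logprobs:
        return float("-inf")  # assumed sentinel for "no usable logprobs"

    log_v = math.log(vocab_size)
    total = 0.0
    for top_k in token_logprobs:
        # KL(U || p) = -(1/V) * sum_j log(V * p_j); start with the top-k terms.
        log_sum = sum(log_v + lp for lp in top_k)

        rest = vocab_size - len(top_k)
        if rest > 0:
            # Assumption: leftover mass is shared equally by unseen tokens.
            rest_mass = max(1.0 - sum(math.exp(lp) for lp in top_k), 1e-12)
            log_sum += rest * (log_v + math.log(rest_mass / rest))

        total += -log_sum / vocab_size
    return total / len(token_logprobs)


uniform = [[math.log(0.25)] * 4]
peaked = [[math.log(0.97), math.log(0.01), math.log(0.01), math.log(0.01)]]
print(self_certainty_sketch(uniform, vocab_size=4))
print(self_certainty_sketch(peaked, vocab_size=4))
```

A uniform distribution scores zero and more confident distributions score higher, which is the property the Borda ranking needs from the score.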


@@ -0,0 +1,51 @@
"""
Baseline 2: Random Dynamic Single Example

For each sample, randomly select a fresh example set.
Single inference per sample, no voting.
"""

import asyncio
import time

from fire import Fire

from common import (
    load_config, setup,
    gen_task_message, safe_parse_json, self_certainty,
)


async def run(config_path: str):
    config = load_config(config_path)

    logger, dataloader, recruiter, agent = setup(config)

    start_time = time.time()

    for idx, sample in enumerate(dataloader):
        logger.log(f"[Main] Sample {idx} / {len(dataloader)}")
        ground_truth = sample["label"]

        try:
            example_set = recruiter.recruit(1)[0]
            task_message = gen_task_message(sample, example_set)
            response, logprobs = await agent.solve(task_message, idx, 0)
            score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
            parsed = safe_parse_json(response)
            answer = parsed.get("ANSWER") if parsed else None
            logger.log(f"[Main] answer={answer}, score={score:.4f}")
        except Exception as e:
            logger.log(f"[Main] Error sample {idx}: {e}", filename="errors.txt")
            answer = None

        if answer is not None:
            logger.log_result(idx, answer, ground_truth)
        else:
            logger.log(f"[Main] Sample {idx}: no valid answer, skipping")

    logger.report(elapsed_seconds=time.time() - start_time)


def main(config_path: str):
    asyncio.run(run(config_path))


if __name__ == "__main__":
    Fire(main)
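
The baselines read the model's label from the `ANSWER` field of a JSON object embedded in the response. The new `core.json_utils.safe_parse_json` is not shown in this diff, so the block below is only a simplified sketch of that extraction step using the standard library. The helper name `extract_answer` and the sample response text are hypothetical.

```python
import json
import re
from typing import Optional


def extract_answer(response: str) -> Optional[str]:
    """Pull the ANSWER field out of the first {...} span in a model response."""
    # Models often wrap the JSON in extra prose, so search for the object
    # instead of parsing the whole string.
    match = re.search(r"\{.*\}", response, re.DOTALL)
    if not match:
        return None
    try:
        parsed = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    if not isinstance(parsed, dict):
        return None
    return parsed.get("ANSWER")


print(extract_answer('Sure. {"REASON": "slow waves", "ANSWER": "N3"}'))
print(extract_answer("no json here"))
```

Returning `None` on any parse failure matches how the scripts above treat a missing answer: the sample is logged and skipped rather than counted as wrong.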


@@ -0,0 +1,75 @@
"""
Baseline 5: Fixed Examples + Borda Voting

Recruit queue_size example sets once at the start.
For each sample, run all example sets through the LLM.
Answers aggregated via borda voting. No queue updates.
"""

import asyncio
import time

from fire import Fire

from common import (
    load_config,
    setup,
    gen_task_message,
    safe_parse_json,
    self_certainty,
    borda_vote,
)


async def run(config_path: str):
    config = load_config(config_path)
    queue_size = config.get("queue_size")

    logger, dataloader, recruiter, agent = setup(config)

    # Recruited once and reused unchanged for every sample.
    example_sets = recruiter.recruit(queue_size)
    logger.log(f"[Baseline] Fixed {queue_size} example set(s) + borda voting")

    async def process(sample_idx, example_idx, example_set, sample):
        try:
            task_message = gen_task_message(sample, example_set)
            response, logprobs = await agent.solve(
                task_message, sample_idx, example_idx
            )
            score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
            parsed = safe_parse_json(response)
            answer = parsed.get("ANSWER") if parsed else None
            logger.log(
                f"[Main] Done {sample_idx} - {example_idx}: "
                f"answer={answer}, score={score:.4f}"
            )
            return {"answer": answer, "score": score}
        except Exception as e:
            logger.log(
                f"[Main] Error {sample_idx} - {example_idx}: {e}",
                filename="errors.txt",
            )
            return {"answer": None, "score": float("-inf")}

    start_time = time.time()

    for idx, sample in enumerate(dataloader):
        logger.log(f"[Main] Sample {idx} / {len(dataloader)}")

        tasks = [process(idx, i, es, sample) for i, es in enumerate(example_sets)]
        results = await asyncio.gather(*tasks)

        winner, tally = borda_vote(results, config.get("borda_p", 1.0))
        ground_truth = sample["label"]

        if winner is not None:
            tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
            logger.log(f"[Vote] votes={{ {tally_str} }}")
            logger.log_result(idx, winner, ground_truth)
        else:
            logger.log(f"[Vote] Sample {idx}: no valid answer, skipping")

    logger.report(elapsed_seconds=time.time() - start_time)


def main(config_path: str):
    asyncio.run(run(config_path))


if __name__ == "__main__":
    Fire(main)
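
The voting baselines fan out one coroutine per example set with `asyncio.gather`, and each coroutine catches its own exceptions and returns a sentinel result instead of raising. The sketch below isolates that pattern with a stand-in backend; `guarded` and `fake_solve` are hypothetical names, and the simulated failure on index 1 is only there to show that one failed call does not cancel the others.

```python
import asyncio
from typing import List


async def fake_solve(i: int):
    """Stand-in for an LLM call; index 1 simulates a backend failure."""
    await asyncio.sleep(0)
    if i == 1:
        raise RuntimeError("backend unavailable")
    return "N2", 0.5 + i


async def guarded(i: int) -> dict:
    """Run one call and convert any failure into a sentinel result."""
    try:
        answer, score = await fake_solve(i)
        return {"answer": answer, "score": score}
    except Exception as exc:
        # -inf keeps a failed run at the bottom of any score-based ranking.
        return {"answer": None, "score": float("-inf"), "error": str(exc)}


async def demo() -> List[dict]:
    # gather returns results in the order the awaitables were passed in.
    return await asyncio.gather(*(guarded(i) for i in range(3)))


results = asyncio.run(demo())
for r in results:
    print(r)
```

Because every task returns a dict, the caller can pass the whole list to the vote without checking for exceptions first.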


@@ -0,0 +1,69 @@
"""
Baseline 3: Self-Consistency with Fixed Example

One randomly selected example set, used for ALL samples.
For each sample, the SAME prompt is run queue_size times with temperature=0.7
across different LLM instances. Answers are aggregated via borda voting.
"""

import asyncio
import time

from fire import Fire

from common import (
    load_config, setup,
    gen_task_message, safe_parse_json, self_certainty, borda_vote,
)


async def run(config_path: str):
    config = load_config(config_path)
    queue_size = config.get("queue_size")

    logger, dataloader, recruiter, agent = setup(config, temperature=0.7)

    # Selected once and reused unchanged for every sample.
    example_set = recruiter.recruit(1)[0]
    logger.log("[Baseline] Fixed example set + self-consistency (temp=0.7)")

    async def run_once(sample_idx, run_idx, task_message):
        try:
            response, logprobs = await agent.solve(task_message, sample_idx, run_idx)
            score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
            parsed = safe_parse_json(response)
            answer = parsed.get("ANSWER") if parsed else None
            logger.log(
                f"[SC] {sample_idx} - run {run_idx}: "
                f"answer={answer}, score={score:.4f}"
            )
            return {"answer": answer, "score": score}
        except Exception as e:
            logger.log(
                f"[SC] Error {sample_idx} - run {run_idx}: {e}",
                filename="errors.txt",
            )
            return {"answer": None, "score": float("-inf")}

    start_time = time.time()

    for idx, sample in enumerate(dataloader):
        logger.log(f"[Main] Sample {idx} / {len(dataloader)}")

        task_message = gen_task_message(sample, example_set)
        tasks = [run_once(idx, i, task_message) for i in range(queue_size)]
        results = await asyncio.gather(*tasks)

        winner, tally = borda_vote(results, config.get("borda_p", 1.0))
        ground_truth = sample["label"]

        if winner is not None:
            tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
            logger.log(f"[Vote] votes={{ {tally_str} }}")
            logger.log_result(idx, winner, ground_truth)
        else:
            logger.log(f"[Vote] Sample {idx}: no valid answer, skipping")

    logger.report(elapsed_seconds=time.time() - start_time)


def main(config_path: str):
    asyncio.run(run(config_path))


if __name__ == "__main__":
    Fire(main)


@@ -0,0 +1,53 @@
"""
Baseline 1: Fixed Single Example

One randomly selected example set, used for ALL samples.
Single inference per sample, no voting.
"""

import asyncio
import time

from fire import Fire

from common import (
    load_config, setup,
    gen_task_message, safe_parse_json, self_certainty,
)


async def run(config_path: str):
    config = load_config(config_path)

    logger, dataloader, recruiter, agent = setup(config)

    # Selected once and reused unchanged for every sample.
    example_set = recruiter.recruit(1)[0]
    logger.log("[Baseline] Fixed single example set selected")

    start_time = time.time()

    for idx, sample in enumerate(dataloader):
        logger.log(f"[Main] Sample {idx} / {len(dataloader)}")
        ground_truth = sample["label"]

        try:
            task_message = gen_task_message(sample, example_set)
            response, logprobs = await agent.solve(task_message, idx, 0)
            score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
            parsed = safe_parse_json(response)
            answer = parsed.get("ANSWER") if parsed else None
            logger.log(f"[Main] answer={answer}, score={score:.4f}")
        except Exception as e:
            logger.log(f"[Main] Error sample {idx}: {e}", filename="errors.txt")
            answer = None

        if answer is not None:
            logger.log_result(idx, answer, ground_truth)
        else:
            logger.log(f"[Main] Sample {idx}: no valid answer, skipping")

    logger.report(elapsed_seconds=time.time() - start_time)


def main(config_path: str):
    asyncio.run(run(config_path))


if __name__ == "__main__":
    Fire(main)


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '00'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/ours/sleepedf/00


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '01'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/ours/sleepedf/01


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '02'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/ours/sleepedf/02


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '00'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_borda/sleepedf/00


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '01'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_borda/sleepedf/01


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '02'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_borda/sleepedf/02


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '00'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_sc/sleepedf/00


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '01'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_sc/sleepedf/01


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '02'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_sc/sleepedf/02


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '00'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_single/sleepedf/00


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '01'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_single/sleepedf/01


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '02'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_single/sleepedf/02


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '00'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_borda/sleepedf/00


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '01'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_borda/sleepedf/01


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '02'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_borda/sleepedf/02


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '00'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_sc/sleepedf/00


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '01'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_sc/sleepedf/01


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '02'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_sc/sleepedf/02


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '00'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_single/sleepedf/00


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '01'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_single/sleepedf/01


@@ -0,0 +1,17 @@
queue_size: 5
num_shot: 1
update_size: 3
borda_p: 1.0
vocab_size: 200064
model_paths:
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: '02'
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_single/sleepedf/02
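
The per-user configs above differ only in `target_user` and in the method and user segments of `log_path`. As an illustration of that pattern, the sketch below renders the same text from a single function using only the standard library. The helper `render_config` and the `METHODS`, `USERS`, and `PORTS` constants are hypothetical; the values are copied from the configs in this diff, and nothing here is written to disk.

```python
PORTS = range(11437, 11445)  # the eight Ollama endpoints listed above
USERS = ["00", "01", "02"]
METHODS = [
    "ours",
    "random_dynamic_borda",
    "random_dynamic_sc",
    "random_dynamic_single",
    "random_fixed_borda",
    "random_fixed_sc",
    "random_fixed_single",
]

DATA_PATH = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new"
LOG_ROOT = "/mnt/sting/hjyoon/projects/llm_personalization/logs"


def render_config(method: str, user: str) -> str:
    """Render one config; only target_user and log_path vary."""
    lines = [
        "queue_size: 5",
        "num_shot: 1",
        "update_size: 3",
        "borda_p: 1.0",
        "vocab_size: 200064",
        "model_paths:",
    ]
    lines += [f"- ollama:url:chris.kaist.ac.kr:{port}/gpt-oss:20b" for port in PORTS]
    lines += [
        f"data_path: {DATA_PATH}",
        # Quoted so YAML keeps the user ID as a string.
        f"target_user: '{user}'",
        f"log_path: {LOG_ROOT}/{method}/sleepedf/{user}",
    ]
    return "\n".join(lines) + "\n"


configs = {(m, u): render_config(m, u) for m in METHODS for u in USERS}
print(configs[("random_fixed_single", "02")])
```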


@@ -1,32 +0,0 @@
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
num_seeds: 10
num_examples: 1
# Selection criteria: out_random | in_random | out_similar
selection_criteria: "out_random"
# Embedding path (not required for random criteria)
# embedding_path: "/home/ssum/tsllm_personalization_icl/embeddings_20users"
models:
- ollama:url:joy.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11444/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
# chronos_base : "/mnt/sting/ssum/sleepedf_chronos_base_result"
# out_random : "/mnt/sting/ssum/sleepedf_chronos_result_outrandom"
# in_random : "/mnt/sting/ssum/sleepedf_chronos_result"
log_path: "/mnt/sting/ssum/sleepedf_chronos_result"

16
config/test.yaml Normal file

@@ -0,0 +1,16 @@
log_path: ./temp/log
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
target_user: "00"
queue_size: 5
num_shot: 1
model_paths:
- ollama:url:rose.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11444/gpt-oss:20b
update_size: 1
vocab_size: 200064


@@ -1,208 +1,39 @@
import os
import re
import json
import tiktoken
from langchain_ollama import ChatOllama
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from core.logger import Logger
from core.model import AsyncModelPool
class Agent:
def __init__(
self,
name,
model_pool,
log_path,
model_pool: AsyncModelPool,
logger: Logger,
):
self.name = name
self.model_pool = model_pool
self.log_path = log_path
self.root_log_path = log_path
self.agent_log_path = os.path.join(log_path, name)
os.makedirs(self.agent_log_path, exist_ok=True)
self.logger = logger
self.system_message = None
self.long_term_memory = []
self.short_term_memory = []
self.volatile_memory = []
def set_system_message(self, system_message: str):
self.system_message = system_message
log = f"[SYSTEM]\n{system_message}\n\n"
log_filename = "llm_log/system_prompt.txt"
self.logger.log(log, log_filename, print_log=False)
self.total_input_tokens = 0
self.total_output_tokens = 0
self.total_tokens = 0
self.total_calls = 0
def log(self, message, local=True):
path = os.path.join(self.root_log_path, "log.txt")
with open(path, "a", encoding="utf-8") as f:
message_type = "UNKNOWN"
if isinstance(message, SystemMessage):
message_type = "SYSTEM"
if isinstance(message, HumanMessage):
message_type = "HUMAN"
if isinstance(message, AIMessage):
message_type = "AI"
content = message.content.strip()
name = self.name
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
if local:
local_path = os.path.join(self.agent_log_path, "log.txt")
with open(local_path, "a", encoding="utf-8") as f:
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
def log_tokens(self, messages, response):
input_tokens = 0
for msg in messages:
msg_tokens = self.count_tokens(msg.content)
input_tokens += msg_tokens
self.total_tokens += msg_tokens
self.total_input_tokens += msg_tokens
output_tokens = self.count_tokens(response.content)
self.total_tokens += output_tokens
self.total_output_tokens += output_tokens
self.total_calls += 1
path = os.path.join(self.agent_log_path, "tokens.txt")
with open(path, "a", encoding="utf-8") as f:
f.write(f"Input tokens: {input_tokens}\n")
f.write(f"Output tokens: {output_tokens}\n")
f.write(f"Total input tokens: {self.total_input_tokens}\n")
f.write(f"Total output tokens: {self.total_output_tokens}\n")
f.write(f"Total tokens: {self.total_tokens}\n")
f.write(f"Total calls: {self.total_calls}\n")
f.write("\n")
def count_tokens(self, text, model="gpt-3.5-turbo"):
enc = tiktoken.encoding_for_model(model)
return len(enc.encode(text))
def update_memory(self):
self.long_term_memory.extend(self.short_term_memory)
self.clean_short_term_memory()
self.clean_volatile_memory()
def clean_short_term_memory(self):
self.short_term_memory = []
def clean_volatile_memory(self):
self.volatile_memory = []
def clean_long_term_memory(self):
self.long_term_memory = []
def clean_json_text(self, text):
text = text.strip()
text = text.replace("", "'").replace("", "'")
text = text.replace("", "'").replace("", "'")
text = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", text)
text = re.sub(r",\s*}", "}", text)
text = re.sub(r",\s*]", "]", text)
text = "".join(ch for ch in text if ch.isprintable())
text = text.replace("][", ",")
return text
def safe_parse_json(self, text):
text = text.strip()
match = re.search(r"\{.*\}", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
elif not text.endswith("}"):
text += "}"
match = re.search(r"\{.*\}", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
print("[!] JSON parse failed")
return None
def safe_parse_json_list(self, text):
text = text.strip()
match = re.search(r"\[.*\]", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
elif not text.endswith("]"):
text += "]"
match = re.search(r"\[.*\]", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
print("[!] JSON parse failed")
return None
async def validate_response(self, response, fields, volatile=False):
if (
not response
or not isinstance(response, dict)
or not all(field in response for field in fields)
):
print("[!] The JSON failed to be parsed. Trying again.")
content = (
"Failed to parse the JSON from the previous response. Please try again."
)
response = await self.invoke(content, volatile=volatile)
response = self.safe_parse_json(response)
if (
not response
or not isinstance(response, dict)
or not all(field in response for field in fields)
):
print("[!] Retry failed.")
return None
return response
def get_last_response(self):
if len(self.long_term_memory) >= 2:
last_msg = self.long_term_memory[-1]
if isinstance(last_msg, AIMessage):
return self.safe_parse_json(last_msg.content)
return None
def set_system_message(self, content, local=True):
system_message = SystemMessage(content=content)
self.log(system_message, local)
self.long_term_memory.append(system_message)
async def invoke(self, content, volatile=False, local=True):
messages = self.long_term_memory.copy()
if volatile:
messages.extend(self.volatile_memory)
else:
messages.extend(self.short_term_memory)
messages.append(HumanMessage(content=content))
async def solve(self, content: str, sample_idx: int, example_idx: int):
messages = []
if self.system_message:
messages.append({"role": "system", "content": self.system_message})
user_message = {"role": "user", "content": content}
messages.append(user_message)
try:
response = await self.model_pool.invoke(messages)
self.log_tokens(messages, response)
if volatile:
self.volatile_memory.extend([HumanMessage(content=content), response])
else:
self.short_term_memory.extend([HumanMessage(content=content), response])
local_ = not volatile and local
self.log(HumanMessage(content=content), local=local_)
self.log(response, local=local_)
return response.content.strip()
except Exception as e: # pylint: disable=broad-exception-caught
print(f"[Error] Error occurred while invoking LLM: {e}")
response, logprobs = await self.model_pool.invoke(messages)
log_filename = f"llm_log/sample_{sample_idx}/example_{example_idx}.txt"
self.logger.log(
f"[USER]\n{content}\n\n[RESPONSE]\n{response}\n",
log_filename,
print_log=False,
)
return response, logprobs
except Exception as e:
self.logger.log(f"[Agent] invoke failed: {e}")
return None, None


@@ -1,200 +1,81 @@
import os
import json
import datasets
import numpy as np
from glob import glob
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from .embedding_index import EmbeddingIndex
from glob import glob
from typing import Optional, List
class DataLoader:
def __init__(
self,
data_path,
user_id,
selection_criteria="out_random",
num_examples=1,
embedding_index: Optional["EmbeddingIndex"] = None,
self,
data_path: str,
target_user: str,
shuffle: bool = False,
seed: int = 0,
):
self.is_valid = False
self.embedding_index = embedding_index
self.data_path = data_path
if not os.path.exists(os.path.join(data_path, "info.json")):
return
if not os.path.exists(os.path.join(data_path, f"{user_id}", "1")):
return
if not os.path.exists(os.path.join(data_path, f"{user_id}", "2")):
return
if selection_criteria in ["out_similar", "in_similar"] and embedding_index is None:
print(f"[WARNING] {selection_criteria} requires embedding_index, falling back to random")
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
self.seed = seed
self.task_metadata = self.load_task_metadata(data_path)
self.user_metadata = self.load_user_metadata(data_path)
self.target_dataset = self.load_target_dataset(data_path, target_user, shuffle)
self.metadata = json.load(open(os.path.join(data_path, "info.json"), "r", encoding="utf-8"))
self.test_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user_id}", "2"))
self.example_dataset = datasets.Dataset.from_list([])
users = glob(os.path.join(data_path, "*"))
users = [path.split("/")[-1] for path in users]
if "info.json" in users:
users.remove("info.json")
for user in users:
user_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user}", "1"))
self.example_dataset = datasets.concatenate_datasets([self.example_dataset, user_dataset])
all_users = glob(os.path.join(data_path, "*"))
all_users = [os.path.basename(p) for p in all_users if os.path.isdir(p)]
self.source_users = [u for u in all_users if u != target_user]
self.source_dataset = self.load_source_dataset(data_path, self.source_users)
self.classes = list(self.task_metadata["class"].keys())
self.test_dataset = self.test_dataset.shuffle(seed=0)
self.example_dataset = self.example_dataset.shuffle(seed=0)
def load_task_metadata(self, data_path: str):
task_metadata_path = os.path.join(data_path, "task_metadata.json")
with open(task_metadata_path, "r", encoding="utf-8") as f:
return json.load(f)
self.user_id = user_id
self.selection_criteria = selection_criteria
self.num_examples = num_examples
def load_user_metadata(self, data_path: str):
user_metadata_path = os.path.join(data_path, "user_metadata.json")
with open(user_metadata_path, "r", encoding="utf-8") as f:
return json.load(f)
self.classes = sorted(list(self.metadata["class"].keys()))
# Build lookup index for fast example retrieval: (user_id, idx) -> dataset_index
self._example_lookup = {}
for i, example in enumerate(self.example_dataset):
key = (str(example["user_id"]), int(example["idx"]))
self._example_lookup[key] = i
if selection_criteria in ["out_similar", "in_similar"]:
self.selected_examples = None
else:
self.selected_examples = self.sample_examples()
if self.selected_examples is None:
return
def load_target_dataset(
self, data_path: str, target_user: str, shuffle: bool = False
):
target_dataset_path = os.path.join(data_path, target_user)
target_dataset = datasets.load_from_disk(target_dataset_path)
if shuffle:
return target_dataset.shuffle(seed=self.seed)
return target_dataset
self.is_valid = True
def load_source_dataset(self, data_path: str, source_users: List[str]):
source_dataset = datasets.Dataset.from_list([])
for user in source_users:
user_dataset = datasets.load_from_disk(os.path.join(data_path, user))
source_dataset = datasets.concatenate_datasets(
[source_dataset, user_dataset]
)
source_dataset = source_dataset.shuffle(seed=self.seed)
return source_dataset
def __len__(self):
return len(self.test_dataset)
return len(self.target_dataset)
def __getitem__(self, idx):
sample = self.test_dataset[idx]
if self.selection_criteria in ["out_similar", "in_similar"]:
examples = self.sample_similar_examples(sample)
else:
examples = self.selected_examples
return sample, examples
def __getitem__(self, idx: int):
return self.target_dataset[idx]
def __iter__(self):
for sample in self.test_dataset:
if self.selection_criteria in ["out_similar", "in_similar"]:
examples = self.sample_similar_examples(sample)
else:
examples = self.selected_examples
yield sample, examples
def sample_similar_examples(self, sample):
"""
Sample examples based on embedding similarity to the given sample.
For each class, finds the most similar example from the example dataset
using Chronos-2 embeddings.
Args:
sample: The test sample to find similar examples for
Returns:
HuggingFace Dataset containing similar examples (one per class)
"""
if self.embedding_index is None:
return self.sample_examples()
query_embedding = self.embedding_index.get_embedding_by_key(
user_id=str(sample["user_id"]),
session_id=str(sample["session_id"]),
idx=int(sample["idx"]),
)
if query_embedding is None:
print(f"[WARNING] No embedding found for sample {sample['user_id']}/{sample['session_id']}/{sample['idx']}")
return self.sample_examples()
if self.selection_criteria == "out_similar":
exclude_user = str(self.user_id)
include_user = None
else:
exclude_user = None
include_user = str(self.user_id)
similar_per_class = self.embedding_index.find_similar_per_class(
query_embedding=query_embedding,
classes=self.classes,
k_per_class=self.num_examples,
exclude_user=exclude_user,
include_user=include_user,
filter_session="1",
)
example_list = []
for cls, similar_samples in similar_per_class.items():
for global_idx, similarity, metadata in similar_samples:
example = self._find_example_by_metadata(metadata)
if example is not None:
example_list.append(example)
if len(example_list) == 0:
print(f"[WARNING] No similar examples found, falling back to random")
return self.sample_examples()
return datasets.Dataset.from_list(example_list)
def _find_example_by_metadata(self, metadata):
"""Find an example in the example_dataset by its metadata using O(1) lookup."""
user_id = str(metadata["user_id"])
idx = int(metadata["idx"])
key = (user_id, idx)
if key not in self._example_lookup:
return None
dataset_idx = self._example_lookup[key]
example = self.example_dataset[dataset_idx]
return {
"user_id": example["user_id"],
"session_id": example["session_id"],
"idx": example["idx"],
"label": example["label"],
"features": example["features"],
"data": example.get("data", {}),
}
yield from self.target_dataset
def sample_examples(self):
example_dataset = datasets.Dataset.from_list([])
if self.selection_criteria == "out_random":
filtered_example_dataset = self.example_dataset.filter(lambda x, user_id=self.user_id: x["user_id"] != user_id)
for c in self.classes:
class_dataset = filtered_example_dataset.filter(lambda x, c_=c: x["label"] == c_)
if len(class_dataset) < self.num_examples:
return None
sampled_examples = class_dataset.select(np.random.choice(len(class_dataset), self.num_examples, replace=False))
example_dataset = datasets.concatenate_datasets([example_dataset, sampled_examples])
elif self.selection_criteria == "in_random":
filtered_example_dataset = self.example_dataset.filter(lambda x, user_id=self.user_id: x["user_id"] == user_id)
for c in self.classes:
class_dataset = filtered_example_dataset.filter(lambda x, c_=c: x["label"] == c_)
if len(class_dataset) < self.num_examples:
return None
sampled_examples = class_dataset.select(np.random.choice(len(class_dataset), self.num_examples, replace=False))
example_dataset = datasets.concatenate_datasets([example_dataset, sampled_examples])
return example_dataset
def get_source_dataset(self):
return self.source_dataset
def get_metadata(self):
return self.metadata
def get_task_metadata(self):
return self.task_metadata
def get_sensor_info(self):
return self.metadata["feature"]
def get_user_metadata(self, user: Optional[str] = None):
if user is None:
return self.user_metadata
return self.user_metadata[user]
def get_task_info(self):
task_info = f"**Task**:\n{self.metadata['task']}\n\n"
classes_info = [f" - {k}: {v}" for k, v in self.metadata["class"].items()]
classes_info = "\n".join(classes_info)
task_info += f"**Classes**:\n{classes_info}"
return task_info
def get_source_users(self):
return self.source_users
def get_classes_info(self):
classes_info = [k for k in self.metadata["class"].keys()]
return classes_info
def get_classes(self):
return self.classes


@@ -1,281 +0,0 @@
"""
Embedding Index for Similarity-based Example Selection
This module provides functionality to:
1. Load pre-computed Chronos-2 embeddings
2. Build an index for fast nearest neighbor search
3. Find similar examples based on embedding distance
The embedding index enables ICL (In-Context Learning) example selection
based on semantic similarity rather than random sampling.
Similarity Strategies:
- out_similar: Find similar examples from OTHER users (cross-user transfer)
- in_similar: Find similar examples from SAME user (personalization)
Usage:
index = EmbeddingIndex(embedding_path)
similar_indices = index.find_similar(query_embedding, k=5, exclude_user="01")
Author: NMSL Research Team
Date: 2026-01-16
"""
import os
import numpy as np
from typing import List, Tuple, Optional, Dict, Any
from datasets import load_from_disk, Dataset
class EmbeddingIndex:
"""
Index for fast similarity search over pre-computed embeddings.
Uses cosine similarity for finding nearest neighbors in embedding space.
Supports filtering by user_id and session_id for controlled experiments.
Attributes:
embeddings: numpy array of shape (num_samples, embedding_dim)
user_ids: list of user identifiers for each sample
session_ids: list of session identifiers for each sample
labels: list of sleep stage labels for each sample
indices: list of original indices in the dataset
"""
def __init__(self, embedding_path: str):
"""
Initialize the embedding index from a HuggingFace dataset.
Args:
embedding_path: Path to directory containing saved embeddings dataset
(output from gen_plot.py extract command)
"""
if not os.path.isdir(embedding_path):
raise FileNotFoundError(
f"Embedding directory not found: {embedding_path}. "
"Run 'python gen_plot.py extract' first to generate embeddings."
)
print(f"[EmbeddingIndex] Loading embeddings from: {embedding_path}")
self.dataset = load_from_disk(embedding_path)
# Extract arrays for fast access
self.embeddings = np.array(self.dataset["embedding"], dtype=np.float32)
self.user_ids = np.array([str(uid) for uid in self.dataset["user_id"]])
self.session_ids = np.array([str(sid) for sid in self.dataset["session_id"]])
self.labels = np.array([str(label) for label in self.dataset["label"]])
self.indices = np.array(self.dataset["idx"])
# Normalize embeddings for cosine similarity (dot product of unit vectors)
self.embeddings_normalized = self._normalize(self.embeddings)
# Build lookup indices for fast filtering
self._build_lookup_indices()
print(f"[EmbeddingIndex] Loaded {len(self.embeddings)} samples")
print(f"[EmbeddingIndex] Embedding dimension: {self.embeddings.shape[1]}")
print(f"[EmbeddingIndex] Unique users: {len(np.unique(self.user_ids))}")
print(f"[EmbeddingIndex] Unique sessions: {len(np.unique(self.session_ids))}")
def _normalize(self, vectors: np.ndarray) -> np.ndarray:
"""L2 normalize vectors for cosine similarity computation."""
norms = np.linalg.norm(vectors, axis=1, keepdims=True)
norms = np.where(norms == 0, 1, norms) # Avoid division by zero
return vectors / norms
def _build_lookup_indices(self):
"""Build dictionaries for fast filtering by user/session/label."""
self.user_to_indices: Dict[str, np.ndarray] = {}
self.session_to_indices: Dict[str, np.ndarray] = {}
self.label_to_indices: Dict[str, np.ndarray] = {}
for user_id in np.unique(self.user_ids):
self.user_to_indices[user_id] = np.where(self.user_ids == user_id)[0]
for session_id in np.unique(self.session_ids):
self.session_to_indices[session_id] = np.where(self.session_ids == session_id)[0]
for label in np.unique(self.labels):
self.label_to_indices[label] = np.where(self.labels == label)[0]
def get_embedding_by_key(
self,
user_id: str,
session_id: str,
idx: int
) -> Optional[np.ndarray]:
"""
Get embedding for a specific sample identified by (user_id, session_id, idx).
Args:
user_id: User identifier
session_id: Session identifier (1 or 2)
idx: Sample index within the session
Returns:
Embedding vector or None if not found
"""
mask = (
(self.user_ids == str(user_id)) &
(self.session_ids == str(session_id)) &
(self.indices == idx)
)
matches = np.where(mask)[0]
if len(matches) == 0:
return None
return self.embeddings[matches[0]]
def cosine_similarity(
self,
query: np.ndarray,
candidates: np.ndarray
) -> np.ndarray:
"""
Compute cosine similarity between query and candidate vectors.
Args:
query: Query vector of shape (embedding_dim,)
candidates: Candidate matrix of shape (num_candidates, embedding_dim)
Returns:
Similarity scores of shape (num_candidates,)
"""
query_normalized = query / (np.linalg.norm(query) + 1e-8)
return candidates @ query_normalized
def find_similar(
self,
query_embedding: np.ndarray,
k: int = 5,
exclude_user: Optional[str] = None,
include_user: Optional[str] = None,
filter_session: Optional[str] = None,
filter_label: Optional[str] = None,
) -> List[Tuple[int, float, Dict[str, Any]]]:
"""
Find k most similar samples to the query embedding.
Args:
query_embedding: Query vector of shape (embedding_dim,)
k: Number of nearest neighbors to return
exclude_user: User ID to exclude from search (for out_similar)
include_user: User ID to include only (for in_similar)
filter_session: Only search in this session (e.g., "1" for train set)
filter_label: Only return samples with this label
Returns:
List of (index, similarity, metadata) tuples sorted by similarity (descending)
metadata contains: user_id, session_id, idx, label
"""
# Build candidate mask based on filters
candidate_mask = np.ones(len(self.embeddings), dtype=bool)
if exclude_user is not None:
candidate_mask &= (self.user_ids != str(exclude_user))
if include_user is not None:
candidate_mask &= (self.user_ids == str(include_user))
if filter_session is not None:
candidate_mask &= (self.session_ids == str(filter_session))
if filter_label is not None:
candidate_mask &= (self.labels == str(filter_label))
candidate_indices = np.where(candidate_mask)[0]
if len(candidate_indices) == 0:
return []
# Compute similarities
candidate_embeddings = self.embeddings_normalized[candidate_indices]
similarities = self.cosine_similarity(query_embedding, candidate_embeddings)
# Get top-k
k = min(k, len(similarities))
top_k_local = np.argsort(similarities)[::-1][:k]
results = []
for local_idx in top_k_local:
global_idx = candidate_indices[local_idx]
sim = similarities[local_idx]
metadata = {
"user_id": self.user_ids[global_idx],
"session_id": self.session_ids[global_idx],
"idx": int(self.indices[global_idx]),
"label": self.labels[global_idx],
}
results.append((global_idx, float(sim), metadata))
return results
def find_similar_per_class(
self,
query_embedding: np.ndarray,
classes: List[str],
k_per_class: int = 1,
exclude_user: Optional[str] = None,
include_user: Optional[str] = None,
filter_session: Optional[str] = "1",
) -> Dict[str, List[Tuple[int, float, Dict[str, Any]]]]:
"""
Find k most similar samples for each class.
This ensures balanced class representation in ICL examples,
similar to the original random sampling approach.
Args:
query_embedding: Query vector
classes: List of class labels to search
k_per_class: Number of examples per class
exclude_user: User to exclude (out_similar mode)
include_user: User to include only (in_similar mode)
filter_session: Session to search in (default: "1" = train set)
Returns:
Dictionary mapping class label to list of similar samples
"""
results = {}
for cls in classes:
similar = self.find_similar(
query_embedding=query_embedding,
k=k_per_class,
exclude_user=exclude_user,
include_user=include_user,
filter_session=filter_session,
filter_label=cls,
)
results[cls] = similar
return results
def get_sample_metadata(self, global_idx: int) -> Dict[str, Any]:
"""Get metadata for a sample by its global index."""
return {
"user_id": self.user_ids[global_idx],
"session_id": self.session_ids[global_idx],
"idx": int(self.indices[global_idx]),
"label": self.labels[global_idx],
"embedding": self.embeddings[global_idx],
}
def create_embedding_index(embedding_path: str) -> Optional[EmbeddingIndex]:
"""
Factory function to create an EmbeddingIndex, returning None if path doesn't exist.
Args:
embedding_path: Path to embeddings directory
Returns:
EmbeddingIndex instance or None
"""
if not os.path.isdir(embedding_path):
print(f"[WARNING] Embedding path not found: {embedding_path}")
return None
try:
return EmbeddingIndex(embedding_path)
except Exception as e:
print(f"[ERROR] Failed to load embeddings: {e}")
return None

37
core/example_queue.py Normal file

@@ -0,0 +1,37 @@
from typing import List, Dict, Any
from core.logger import Logger
class ExampleQueue:
def __init__(self, queue_size: int, logger: Logger):
self.queue_size = queue_size
self.logger = logger
self.queue: List[Dict[str, Any]] = []
def __iter__(self):
yield from self.queue
def update(self, results: List[Dict[str, Any]], new_examples: List[Dict[str, Any]]):
n_replace = len(new_examples)
if n_replace == 0:
return
if results:
ranked = sorted(range(len(results)), key=lambda i: results[i]["score"])
drop_indices = sorted(ranked[:n_replace], reverse=True)
for idx in drop_indices:
self.logger.log(
f"[Queue] Dropping index {idx} (score={results[idx]['score']:.4f})"
)
self.queue.pop(idx)
slots = self.queue_size - len(self.queue)
if slots < len(new_examples):
self.logger.log(
f"[Queue] Capping: {len(new_examples)} new but only {slots} slot(s) free"
)
new_examples = new_examples[:slots]
self.queue.extend(new_examples)
self.logger.log(f"[Queue] Updated, queue size = {len(self.queue)}")
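
Review note: the replacement rule in `update` can be sketched in isolation as below. This is a minimal sketch, assuming `results[i]` scores `queue[i]` (the diff does not state this alignment explicitly). The `drop_lowest` helper name is hypothetical, it returns a new list instead of mutating, and logging is omitted.

```python
from typing import Any, Dict, List


def drop_lowest(
    queue: List[Dict[str, Any]],
    results: List[Dict[str, Any]],
    new_examples: List[Dict[str, Any]],
    queue_size: int,
) -> List[Dict[str, Any]]:
    """Return a copy of the queue with its lowest-scoring entries replaced.

    Assumption: results[i] is the score record for queue[i].
    """
    queue = list(queue)
    n_replace = len(new_examples)
    if n_replace == 0:
        return queue
    if results:
        # Rank positions by ascending score, then pop from the back so that
        # the remaining indices stay valid while entries are removed.
        ranked = sorted(range(len(results)), key=lambda i: results[i]["score"])
        for idx in sorted(ranked[:n_replace], reverse=True):
            queue.pop(idx)
    # Never grow past the configured capacity.
    slots = max(queue_size - len(queue), 0)
    queue.extend(new_examples[:slots])
    return queue


queue = [{"id": "a"}, {"id": "b"}, {"id": "c"}]
results = [{"score": 0.9}, {"score": 0.1}, {"score": 0.5}]
updated = drop_lowest(queue, results, [{"id": "d"}], queue_size=3)
print([entry["id"] for entry in updated])
```

Popping in descending index order is what keeps the earlier indices meaningful. Popping in ascending order would shift later entries and drop the wrong ones.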

29
core/json_utils.py Normal file

@@ -0,0 +1,29 @@
import re
import json
def clean_json_text(text: str):
text = text.strip()
text = text.replace("\u2018", "'").replace("\u2019", "'") # smart single quotes
text = text.replace("\u201c", '"').replace("\u201d", '"') # smart double quotes
text = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", text) # escape stray backslashes
text = re.sub(r",\s*}", "}", text) # trailing comma in object
text = re.sub(r",\s*]", "]", text) # trailing comma in array
text = "".join(ch for ch in text if ch.isprintable()) # strip control chars
text = text.replace("][", ",") # merge adjacent arrays
return text
def safe_parse_json(text: str):
if not text:
return None
text = text.strip()
match = re.search(r"\{.*\}", text, re.DOTALL)
if not match and not text.endswith("}"):
match = re.search(r"\{.*\}", text + "}", re.DOTALL)
if match:
try:
return json.loads(clean_json_text(match.group(0)))
except json.JSONDecodeError:
return None
return None
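
Review note: the kinds of model replies this parser is meant to tolerate can be illustrated with a trimmed-down sketch. It assumes only the trailing-comma cleanup and the missing-brace retry from the diff. `parse_first_object` is a hypothetical name, and the other cleanup steps (smart quotes, stray backslashes, control characters) are left out for brevity.

```python
import json
import re
from typing import Any, Optional


def parse_first_object(text: str) -> Optional[Any]:
    """Extract and parse the outermost {...} span from an LLM reply."""
    if not text:
        return None
    text = text.strip()
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if not match and not text.endswith("}"):
        # The reply may have been cut off before the closing brace.
        match = re.search(r"\{.*\}", text + "}", re.DOTALL)
    if not match:
        return None
    # Remove a trailing comma before the closing brace.
    candidate = re.sub(r",\s*}", "}", match.group(0))
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        return None


print(parse_first_object('Sure! {"ANSWER": "W", }'))
print(parse_first_object('{"ANSWER": "N2"'))
print(parse_first_object("no json here"))
```

The greedy `.*` takes everything from the first `{` to the last `}`, so a reply containing two separate objects will not parse. Callers should treat `None` as "retry or skip".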

58
core/logger.py Normal file

@@ -0,0 +1,58 @@
import os
import yaml
from datetime import datetime
from sklearn.metrics import accuracy_score, f1_score
class Logger:
def __init__(self, log_path: str):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.log_path = f"{log_path}_{timestamp}"
os.makedirs(self.log_path, exist_ok=True)
self.answers = []
def log(self, message: str, filename: str = "log.txt", print_log: bool = True):
if print_log:
print(message)
log_file_path = os.path.join(self.log_path, filename)
base_dir = os.path.dirname(log_file_path)
os.makedirs(base_dir, exist_ok=True)
with open(log_file_path, "a", encoding="utf-8") as f:
f.write(message + "\n")
def log_config(self, config: dict):
message = yaml.dump(config, default_flow_style=False, sort_keys=False).strip()
self.log(message, "config.yaml", print_log=False)
def log_result(self, idx: int, answer: str, ground_truth: str):
self.answers.append(
{
"idx": idx,
"answer": answer,
"ground_truth": ground_truth,
}
)
self.log(
f"[RESULT] {idx}: answer={answer}, ground_truth={ground_truth}",
"result.txt",
)
def report(self, elapsed_seconds: float = None):
if not self.answers:
self.log("[REPORT] No valid answers recorded", "report.txt")
return
answers = [a["answer"] for a in self.answers]
ground_truths = [a["ground_truth"] for a in self.answers]
accuracy = accuracy_score(ground_truths, answers)
f1 = f1_score(ground_truths, answers, average="macro")
n = len(self.answers)
time_str = ""
if elapsed_seconds is not None:
m, s = divmod(int(elapsed_seconds), 60)
h, m = divmod(m, 60)
time_str = f", time={h}h{m:02d}m{s:02d}s"
self.log(
f"[REPORT] accuracy={accuracy:.4f}, f1={f1:.4f}, n={n}{time_str}",
"report.txt",
)


@@ -1,95 +1,133 @@
import asyncio
import requests
import numpy as np
from langchain_ollama import ChatOllama
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
def load_models(models):
model_pool = AsyncModelPool()
for model in models:
model_pool.add_model(Model(model))
model_pool.init_models()
return model_pool
from typing import Dict, List
class Model:
def __init__(self, model, temperature=0.7):
if model.startswith("ollama:"):
model = model.replace("ollama:", "")
if "url:" in model:
model = model.replace("url:", "")
base_url = model.split("/")[0]
if not base_url.startswith("http"):
base_url = "http://" + base_url
model_type = model.split("/")[1]
self.model = ChatOllama(
model=model_type,
base_url=base_url,
temperature=temperature,
num_ctx=12000,
)
else:
self.model = ChatOllama(
model=model.replace("ollama:", ""),
temperature=temperature,
num_ctx=12000,
)
else:
self.model = init_chat_model(
model=model,
temperature=temperature,
)
def __init__(
self,
model_path: str,
temperature: float = 0.0,
num_ctx: int = 131072,
max_tokens: int = -1,
logprobs: bool = True,
top_logprobs: int = 20,
top_p: float = 1.0,
top_k: int = 0,
stream: bool = False,
think: str = "low",
):
self.backend = None
def invoke(self, messages):
if model_path.startswith("ollama:"):
raw = model_path.split("url:")[1]
self.backend = "ollama"
self.base_url = f"http://{raw.split('/')[0]}"
self.model_name = raw.split("/")[1]
self.temperature = temperature
self.num_ctx = num_ctx
self.max_tokens = max_tokens
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.top_p = top_p
self.top_k = top_k
self.stream = stream
self.think = think
else:
raise ValueError(f"Unknown model prefix: {model_path}")
def invoke(self, messages: List[Dict[str, str]]):
try:
response = self.model.invoke(messages)
return response
return self._invoke_ollama(messages)
except Exception as e:
print(f"[Error] Error occurred while invoking LLM: {e}")
return e
print(f"[Error] invoke failed: {e}")
return "", np.array([])
def _invoke_ollama(self, messages: List[Dict[str, str]]):
resp = requests.post(
f"{self.base_url}/api/chat",
json={
"model": self.model_name,
"messages": messages,
"stream": self.stream,
"think": self.think,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs,
# Ollama reads sampling parameters from "options", not the top level.
"options": {
"temperature": self.temperature,
"num_ctx": self.num_ctx,
"num_predict": self.max_tokens,
"top_p": self.top_p,
"top_k": self.top_k,
},
},
timeout=300,
)
resp.raise_for_status()
data = resp.json()
response = data["message"]["content"]
logprobs = []
for lp in data.get("logprobs", []):
logprobs.append([tp["logprob"] for tp in lp.get("top_logprobs", [])])
return response, np.array(logprobs) if logprobs else np.array([])
class AsyncModel:
def __init__(self, model):
def __init__(self, model: Model):
self.model = model
async def invoke(self, content):
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(
async def invoke(self, messages: List[Dict[str, str]]):
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None,
lambda: self.model.invoke(content),
lambda: self.model.invoke(messages),
)
return response
class AsyncModelPool:
def __init__(self):
self.models = []
self._available_models = None
self._model_semaphore = None
self._queue: asyncio.Queue = asyncio.Queue()
def add_model(self, model):
self.models.append(model)
def add_model(self, model: Model):
self._queue.put_nowait(AsyncModel(model))
def init_models(self):
# Initialize the queue and semaphore in the current event loop
self._available_models = asyncio.Queue()
for model in self.models:
async_model = AsyncModel(model)
self._available_models.put_nowait(async_model)
self._model_semaphore = asyncio.Semaphore(len(self.models))
def warmup(self):
for model in self.models:
model.invoke([HumanMessage(content="Hello world!")])
async def invoke(self, content):
if self._available_models is None:
raise RuntimeError("Model pool not initialized. Call init_models() first.")
async_model = await self._available_models.get()
async def invoke(self, messages: List[Dict[str, str]]):
async_model = await self._queue.get()
try:
response = await async_model.invoke(content)
return response
return await async_model.invoke(messages)
finally:
self._available_models.put_nowait(async_model)
self._queue.put_nowait(async_model)
def load_models(
model_paths: List[str],
temperature: float = 0.0,
num_ctx: int = 131072,
max_tokens: int = -1,
logprobs: bool = True,
top_logprobs: int = 20,
top_p: float = 1.0,
top_k: int = 0,
stream: bool = False,
think: str = "low",
):
pool = AsyncModelPool()
for path in model_paths:
pool.add_model(
Model(
path,
temperature=temperature,
num_ctx=num_ctx,
max_tokens=max_tokens,
logprobs=logprobs,
top_logprobs=top_logprobs,
top_p=top_p,
top_k=top_k,
stream=stream,
think=think,
)
)
print(f"[ModelPool] Loaded a model from {path}")
return pool
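
Review note: the pool pattern used here (check a model out of an `asyncio.Queue`, run the blocking call in an executor, return the model in `finally`) can be tried without an Ollama server. The sketch below is illustrative only. `EchoModel` and `Pool` are hypothetical stand-ins, and the queue is created lazily inside the running loop, which is a choice made for this sketch and not something the diff does.

```python
import asyncio
from typing import List, Optional


class EchoModel:
    """Hypothetical stand-in for a blocking model client."""

    def __init__(self, name: str):
        self.name = name

    def invoke(self, prompt: str) -> str:
        return f"{self.name}:{prompt}"


class Pool:
    def __init__(self, models: List[EchoModel]):
        self._models = models
        self._queue: Optional[asyncio.Queue] = None

    async def invoke(self, prompt: str) -> str:
        if self._queue is None:
            # Build the queue on first use, inside the running event loop.
            self._queue = asyncio.Queue()
            for model in self._models:
                self._queue.put_nowait(model)
        model = await self._queue.get()  # waits while every model is busy
        try:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, model.invoke, prompt)
        finally:
            # Always hand the model back, even if the call raised.
            self._queue.put_nowait(model)


async def main() -> List[str]:
    pool = Pool([EchoModel("m0"), EchoModel("m1")])
    return await asyncio.gather(*(pool.invoke(f"q{i}") for i in range(4)))


replies = asyncio.run(main())
print(replies)
```

With two models and four requests, at most two calls run at once. The queue acts as both the registry and the concurrency limit, which is why the separate semaphore from the old version is no longer needed.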

66
core/prompt.py Normal file

@@ -0,0 +1,66 @@
from typing import Any, Dict, List
def gen_system_message(metadata: Dict[str, Any]):
task_info = metadata["task"]
classes_info = [f" - {k}: {v}" for k, v in metadata["class"].items()]
classes_info = "\n".join(classes_info)
data_info = metadata["data"]
feature_info = metadata["feature"]
system_message = (
f"You are an assistant who interprets sensor data to solve a task.\n\n"
f"1. Task:\n"
f"{task_info}\n\n"
f"2. Classes:\n"
f"{classes_info}\n\n"
f"3. Data:\n"
f"{data_info}\n\n"
f"4. Features:\n"
f"{feature_info}\n\n"
"Your goal is to analyze the sensor data and "
"provide a reasoned answer for the task.\n"
"Do not output analysis."
)
return system_message
def gen_task_message(
sample: Dict[str, Any],
example_set: Dict[str, Any],
):
def format_feature(value: Any):
if isinstance(value, float):
if abs(value) >= 1e4 or (abs(value) < 1e-2 and value != 0):
return f"{value:.2e}"
return f"{value:.2f}"
return str(value)
example_info = ""
for cls, examples in example_set.items():
for i, example in enumerate(examples):
example_info += f"Example {i+1} of {cls}:\n"
for k, v in example["features"].items():
example_info += f" - {k}: {format_feature(v)}\n"
example_info += "\n"
example_info = example_info.strip()
test_info = "Current sensor data:\n"
for k, v in sample["features"].items():
test_info += f" - {k}: {format_feature(v)}\n"
test_info = test_info.strip()
classes = list(example_set.keys())
task_message = (
"You have a few labeled examples of sensor data:\n"
f"{example_info}\n\n"
f"And you have the current sensor data:\n"
f"{test_info}\n\n"
f"Please provide your answer among {classes} "
"and the reasoning for your answer.\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<reasoning for the answer>",\n'
f' "ANSWER": "<answer among {classes}>"\n'
"}"
)
return task_message
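
Review note: the number formatting applied to every feature value in the prompt is easiest to check on a few inputs. The sketch below copies the rule from `format_feature` in the diff into a standalone function. The sample values are made up for illustration.

```python
from typing import Any


def format_feature(value: Any) -> str:
    """Format a feature value compactly for inclusion in a prompt."""
    if isinstance(value, float):
        # Very large or very small non-zero floats use scientific notation.
        if abs(value) >= 1e4 or (abs(value) < 1e-2 and value != 0):
            return f"{value:.2e}"
        return f"{value:.2f}"
    # Integers, strings and anything else pass through as text.
    return str(value)


for sample in (12345.678, 0.00042, 0.0, 3.14159, 7):
    print(repr(sample), "->", format_feature(sample))
```

The `value != 0` guard is the behavioral change from the removed `SensingAgent.format_feature`, which rendered an exact zero in scientific notation.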

52
core/recruiter.py Normal file

@@ -0,0 +1,52 @@
import random
import numpy as np
from datasets import Dataset
from typing import Any, Dict, List
from core.logger import Logger
class Recruiter:
def __init__(
self,
source_dataset: Dataset,
source_users: List[str],
classes: List[str],
num_shot: int,
logger: Logger,
):
self.source_dataset = source_dataset
self.num_shot = num_shot
self.source_users = source_users
self.classes = classes
self.logger = logger
def recruit(self, num_example_set: int):
# placeholder for the recruiter strategy
self.logger.log(f"[Recruiter] Recruiting {num_example_set} example set(s)")
# randomly select examples from the user dataset
list_example_sets = []
for _ in range(num_example_set):
# randomly select user
user = random.choice(self.source_users)
user_dataset = self.source_dataset.filter(lambda x: x["user_id"] == user)
example_set = {}
for cls in self.classes:
cls_dataset = user_dataset.filter(lambda x: x["label"] == cls)
if len(cls_dataset) < self.num_shot:
raise ValueError(
f"Not enough examples for class {cls} in user {user}"
)
random_index = np.random.choice(
len(cls_dataset), self.num_shot, replace=False
)
example_set[cls] = cls_dataset.select(random_index)
self.logger.log(
f"[Recruiter] Recruited {self.num_shot} example(s) from user {user}"
)
list_example_sets.append(example_set)
return list_example_sets
def update_strategy(self, results: List[Dict[str, Any]]):
# TODO: implement adaptive recruitment strategy based on results
pass

27
core/scores.py Normal file

@@ -0,0 +1,27 @@
import math
import numpy as np
def self_certainty(logprobs: np.ndarray, vocab_size: int) -> float:
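"""Mean self-certainty (KL divergence from uniform) over token positions.
Estimated from top-k logprobs, with the remaining probability mass spread
uniformly over the other tokens. Returns -inf if logprobs is None or empty.
"""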
if logprobs is None or logprobs.size == 0:
return float("-inf")
# probability mass in top-k
probs = np.exp(logprobs)
k = probs.shape[-1]
topk_sum = probs.sum(axis=-1)
# remaining probability mass
tail_mass = 1.0 - topk_sum
tail_mass = np.clip(tail_mass, 1e-12, None)
# uniform distribution over remaining tokens
tail_prob = tail_mass / (vocab_size - k)
# sum log probabilities
logprob_sum_topk = logprobs.sum(axis=-1)
logprob_sum_tail = (vocab_size - k) * np.log(tail_prob)
logprob_sum = logprob_sum_topk + logprob_sum_tail
# self-certainty score
score = (-1.0 / vocab_size) * logprob_sum - math.log(vocab_size)
return float(np.mean(score))
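
Review note: the score above works out to the KL divergence from the uniform distribution to the model's distribution, -(1/V) * sum(log p_j) - log(V), where the sum runs over the whole vocabulary of size V. A single-position version in plain Python makes the two limiting cases easy to check. This is a sketch under the same uniform-tail assumption. `self_certainty_topk` is a hypothetical name and the probabilities are toy values.

```python
import math
from typing import Sequence


def self_certainty_topk(top_logprobs: Sequence[float], vocab_size: int) -> float:
    """Self-certainty for one token position from its top-k logprobs."""
    k = len(top_logprobs)
    if k == 0:
        return float("-inf")  # same sentinel the diff uses for empty input
    if vocab_size <= k:
        raise ValueError("vocab_size must exceed the number of top-k entries")
    # Probability mass not covered by the top-k entries, floored to avoid log(0).
    tail_mass = max(1.0 - sum(math.exp(lp) for lp in top_logprobs), 1e-12)
    # Assumption: the unseen tokens share the tail mass equally.
    tail_logprob = math.log(tail_mass / (vocab_size - k))
    logprob_sum = sum(top_logprobs) + (vocab_size - k) * tail_logprob
    return -logprob_sum / vocab_size - math.log(vocab_size)


uniform = self_certainty_topk([math.log(0.25)] * 2, vocab_size=4)
peaked = self_certainty_topk([math.log(0.9), math.log(0.05)], vocab_size=4)
print(f"uniform={uniform:.4f}, peaked={peaked:.4f}")
```

A uniform distribution scores approximately zero and a peaked one scores higher. That ordering is what lets the queue treat the score as a confidence signal.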


@@ -1,181 +0,0 @@
import json
import copy
import os
from .agent import Agent
class SensingAgent(Agent):
def __init__(
self,
name,
model_pool,
task_info,
classes_info,
sensor_info,
sample,
examples,
log_path,
):
super().__init__(
name=name,
model_pool=model_pool,
log_path=log_path,
)
self.task_info = task_info
self.classes_info = classes_info
self.sensor_info = sensor_info
self.sample = sample
self.examples = examples
self.init_system_message()
def init_system_message(self):
content = (
f"You are {self.name} agent that interprets sensor data to solve a task.\n"
"You have the following information about the task:\n"
f"{self.task_info}\n\n"
"You have the following information about the sensor data:\n"
f"{self.sensor_info}\n\n"
"Your goal is to analyze the features and "
"provide a reasoned answer using your knowledge."
)
self.set_system_message(content)
def gen_feature_info(self):
feature_info = f"{self.name} features:\n"
if len(self.examples) > 0:
feature_info += f"{self.gen_example_info()}\n\n"
feature_info += "**Current sample features**:\n"
for k, v in self.sample["features"].items():
feature_info += f" - {k}: {self.format_feature(v)}\n"
feature_info = feature_info.strip()
return feature_info
def gen_example_info(self):
example_info = (
"**Examples**\n"
"Sensor values might not always align with your inherent "
"knowledge due to differences in data collection or processing. "
"So, we included a few labeled examples to help your interpretation:\n"
)
for example in self.examples:
example_info += f"*Example of {example['label']}*:\n"
for k, v in example["features"].items():
example_info += f" - {k}: {self.format_feature(v)}\n"
example_info += "\n"
example_info = example_info.strip()
return example_info
def format_feature(self, value):
if isinstance(value, float):
if abs(value) >= 1e4 or abs(value) < 1e-2:
return f"{value:.2e}"
return f"{value:.2f}"
return value
def log_summary(self, message, print_log=True):
path = os.path.join(self.log_path, "summary.txt")
with open(path, "a", encoding="utf-8") as f:
f.write(f"{message}\n")
if print_log:
print(message)
async def solve(self, sample, examples, ground_truth):
self.sample = sample
self.examples = examples
feature_info = self.gen_feature_info()
content = (
f"You have received sensor features from {self.name} modality:\n"
f"{feature_info}\n\n"
f"Please provide your answer for the task among {self.classes_info} "
"and the reasoning for your answer.\n"
"Note that the sensor features might be wrong due to the data collection or processing.\n"
"You can evaluate the quality of the features by checking the examples you have.\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<Reasoning for the answer>",\n'
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(content)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["REASON", "ANSWER"]
)
self.clean_short_term_memory()
self.clean_long_term_memory()
answer = parsed_response["ANSWER"]
self.log_summary(f"Answer: {answer} (Ground truth: {ground_truth})", print_log=True)
return parsed_response
async def interpret(self):
feature_info = self.gen_feature_info()
content = (
f"You have received sensor features from {self.name} modality:\n"
f"{feature_info}\n\n"
f"Please provide your answer for the task among {self.classes_info} "
"and the reasoning for your answer.\n"
"Note that the sensor features might be wrong due to the data collection or processing.\n"
"You can evaluate the quality of the features by checking the examples you have.\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<Reasoning for the answer>",\n'
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(content)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["REASON", "ANSWER"]
)
return parsed_response
async def evaluate(self, target_name, initial_response):
initial_response_info = json.dumps(initial_response, indent=2)
content = (
f"Another agent, <{target_name}>, provided the following answer for the same task:\n"
f"{initial_response_info}\n\n"
"Please evaluate the given reasoning and answer based on your judgement. "
"You may either agree with it or disagree.\n"
"If you agree, explain why the reasoning and answer are valid. "
"If you disagree, explain why the reasoning or answer may be flawed, "
f"and provide constructive feedback on how <{target_name}> can improve its response.\n"
"Respond in the following strict JSON format:\n"
"{\n"
f' "EVALUATION": "<Evaluation to <{target_name}>"\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(content, volatile=True)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["EVALUATION"], volatile=True
)
self.clean_volatile_memory()
return parsed_response
async def reflect(self, evaluations):
evaluations_info = json.dumps(evaluations, indent=2)
content = (
"Other agents have evaluated your answer for the same task:\n"
f"{evaluations_info}\n\n"
"Please reflect on the evaluations and provide a refined answer for the same task.\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<Reasoning for the answer>",\n'
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(content)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["REASON", "ANSWER"]
)
return parsed_response

core/vote.py Normal file

@@ -0,0 +1,24 @@
from typing import Any, Dict, List
from collections import Counter


def borda_vote(results: List[Dict[str, Any]], borda_p: float = 1.0):
    parsed = [
        {"answer": r["answer"], "score": r["score"]}
        for r in results
        if r.get("answer") is not None
    ]
    if not parsed:
        return None, {}

    parsed.sort(key=lambda x: x["score"], reverse=True)
    n = len(parsed)

    tally: Counter = Counter()
    for rank, entry in enumerate(parsed, start=1):
        votes = int((n - rank + 1) ** borda_p)
        tally[entry["answer"]] += votes

    winner, _ = tally.most_common(1)[0]
    return winner, dict(tally.most_common())
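
The effect of `borda_p` in core/vote.py may be easier to see with a small worked example. The sketch below is a condensed restatement of the same ranking rule; the function name `borda_vote_sketch` and the toy answers and scores are made up for illustration. With `borda_p=1.0` the three lower-ranked "W" votes outweigh the single top-ranked "N2", while `borda_p=2.0` weights the top rank heavily enough to flip the result.

```python
from collections import Counter


def borda_vote_sketch(results, borda_p=1.0):
    """Condensed restatement of borda_vote, for illustration only."""
    ranked = sorted(
        (r for r in results if r.get("answer") is not None),
        key=lambda r: r["score"],
        reverse=True,
    )
    if not ranked:
        return None, {}
    n = len(ranked)
    tally = Counter()
    for rank, entry in enumerate(ranked, start=1):
        # Rank 1 gets n ** borda_p votes, the last rank gets 1.
        tally[entry["answer"]] += int((n - rank + 1) ** borda_p)
    return tally.most_common(1)[0][0], dict(tally.most_common())


# Toy inputs; the None answer stands in for a response that failed to parse.
results = [
    {"answer": "N2", "score": 0.9},
    {"answer": "W", "score": 0.8},
    {"answer": "W", "score": 0.7},
    {"answer": "W", "score": 0.1},
    {"answer": None, "score": 0.95},
]
linear_winner, linear_tally = borda_vote_sketch(results, borda_p=1.0)
squared_winner, squared_tally = borda_vote_sketch(results, borda_p=2.0)
print(linear_winner, linear_tally)
print(squared_winner, squared_tally)
```

Because votes are truncated with `int(...)`, non-integer values of `borda_p` round each rank's weight down.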


@@ -0,0 +1,78 @@
"""
Generate per-(method, user) config files.
Usage:
python experiments/gen_configs.py \
--dataset sleepedf \
--data_path /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new \
--users 00,01,02
"""
import os
import yaml
from fire import Fire
METHODS = [
"random_fixed_single",
"random_dynamic_single",
"random_fixed_sc",
"random_dynamic_sc",
"random_dynamic_borda",
"random_fixed_borda",
"ours",
]
DEFAULTS = {
"queue_size": 5,
"num_shot": 1,
"update_size": 3,
"borda_p": 1.0,
"vocab_size": 200064,
"model_paths": [
"ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b",
"ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b",
"ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b",
"ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b",
"ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b",
"ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b",
"ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b",
"ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b",
],
}
def main(dataset: str, data_path: str, users: str):
"""
Args:
dataset: Dataset name (e.g. "sleepedf").
data_path: Absolute path to the dataset.
users: Comma-separated user IDs (e.g. "00,01,02").
"""
users = [u.strip() for u in users.split(",")]
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
for method in METHODS:
for user in users:
config_dir = os.path.join(project_root, "config", method, dataset)
os.makedirs(config_dir, exist_ok=True)
config = dict(DEFAULTS)
config["data_path"] = data_path
config["target_user"] = user
config["log_path"] = (
f"/mnt/sting/hjyoon/projects/llm_personalization/logs"
f"/{method}/{dataset}/{user}"
)
config_path = os.path.join(config_dir, f"user{user}.yaml")
with open(config_path, "w", encoding="utf-8") as f:
yaml.dump(config, f, default_flow_style=False, sort_keys=False)
print(f" {config_path}")
print(f"\nGenerated {len(METHODS) * len(users)} configs.")
if __name__ == "__main__":
Fire(main)
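
As a rough illustration of the layout this generator produces, the sketch below lists the config paths for a given set of methods and users without writing any files. The helper `planned_config_paths` and the `/tmp/project` root are hypothetical; only the `config/<method>/<dataset>/user<id>.yaml` pattern is taken from the script above.

```python
import os


def planned_config_paths(project_root, methods, dataset, users):
    """Hypothetical helper: list the YAML paths the generator would write."""
    return [
        os.path.join(project_root, "config", method, dataset, f"user{user}.yaml")
        for method in methods
        for user in users
    ]


# Example values only; the real script takes these from METHODS and --users.
paths = planned_config_paths(
    "/tmp/project", ["ours", "random_fixed_sc"], "sleepedf", ["00", "01"]
)
for path in paths:
    print(path)
```

One file is planned per (method, user) pair, which matches the `len(METHODS) * len(users)` count the script prints.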

experiments/run.sh Executable file

@@ -0,0 +1,75 @@
#!/bin/bash
#
# Generic experiment runner. Finds configs under config/<method>/<dataset>/
# and runs the corresponding Python script.
#
# Usage:
# bash experiments/run.sh <method> <dataset> # one method, all users
# bash experiments/run.sh <method> <dataset> <config> # one config file
# bash experiments/run.sh all <dataset> # all methods
#
# Examples:
# bash experiments/run.sh ours sleepedf
# bash experiments/run.sh random_fixed_sc sleepedf
# bash experiments/run.sh random_fixed_sc sleepedf user00.yaml
# bash experiments/run.sh all sleepedf
#
set -e
METHOD="${1:?Usage: $0 <method|all> <dataset> [config_file]}"
DATASET="${2:?Usage: $0 <method|all> <dataset> [config_file]}"
CONFIG_FILE="$3"
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
CONFIG_ROOT="$PROJECT_ROOT/config"
run_config() {
local method="$1"
local config="$2"
echo "============================================"
echo " Method: $method | Config: $(basename "$config")"
echo "============================================"
if [ "$method" = "ours" ]; then
python "$PROJECT_ROOT/run.py" --config_path "$config"
else
python "$PROJECT_ROOT/baselines/${method}.py" --config_path "$config"
fi
echo ""
}
if [ "$METHOD" != "all" ] && [ -n "$CONFIG_FILE" ]; then
# Single config
config="$CONFIG_ROOT/$METHOD/$DATASET/$CONFIG_FILE"
if [ ! -f "$config" ]; then
echo "Error: config not found: $config"
exit 1
fi
run_config "$METHOD" "$config"
elif [ "$METHOD" != "all" ]; then
# All configs for one method
config_dir="$CONFIG_ROOT/$METHOD/$DATASET"
if [ ! -d "$config_dir" ]; then
echo "Error: config directory not found: $config_dir"
exit 1
fi
for config in "$config_dir"/*.yaml; do
[ -e "$config" ] || continue # skip the unexpanded pattern when no configs exist
run_config "$METHOD" "$config"
done
else
# All methods
for method_dir in "$CONFIG_ROOT"/*/; do
method="$(basename "$method_dir")"
dataset_dir="$method_dir$DATASET"
[ -d "$dataset_dir" ] || continue
for config in "$dataset_dir"/*.yaml; do
[ -e "$config" ] || continue # skip the unexpanded pattern when no configs exist
run_config "$method" "$config"
done
done
fi
echo "All experiments complete."


@@ -0,0 +1,324 @@
import os
import json
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm
from fire import Fire
from datasets import Dataset
from datetime import timedelta
warnings.filterwarnings("ignore")
GLOBEM_PATH = "/mnt/sting/hjyoon/projects/llm_personalization/dataset/GLOBEM/physionet.org/files/globem/1.1"
OUT_DIR = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/GLOBEM"
PHASES = {
"INS-W": [1, 2, 3, 4],
}
FEATURE_TYPES = ["f_loc", "f_screen", "f_slp", "f_steps"]
FEATURE_COLUMNS = {
"f_loc": [
"phone_locations_barnett_avgflightdur", "phone_locations_barnett_avgflightlen",
"phone_locations_barnett_circdnrtn", "phone_locations_barnett_disttravelled",
"phone_locations_barnett_hometime", "phone_locations_barnett_maxdiam",
"phone_locations_barnett_maxhomedist", "phone_locations_barnett_probpause",
"phone_locations_barnett_rog", "phone_locations_barnett_siglocentropy",
"phone_locations_barnett_siglocsvisited", "phone_locations_barnett_stdflightdur",
"phone_locations_barnett_stdflightlen", "phone_locations_barnett_wkenddayrtn",
"phone_locations_doryab_avglengthstayatclusters", "phone_locations_doryab_avgspeed",
"phone_locations_doryab_homelabel", "phone_locations_doryab_locationentropy",
"phone_locations_doryab_locationvariance", "phone_locations_doryab_loglocationvariance",
"phone_locations_doryab_maxlengthstayatclusters", "phone_locations_doryab_minlengthstayatclusters",
"phone_locations_doryab_movingtostaticratio", "phone_locations_doryab_normalizedlocationentropy",
"phone_locations_doryab_numberlocationtransitions", "phone_locations_doryab_numberofsignificantplaces",
"phone_locations_doryab_outlierstimepercent", "phone_locations_doryab_radiusgyration",
"phone_locations_doryab_stdlengthstayatclusters", "phone_locations_doryab_timeathome",
"phone_locations_doryab_timeattop1location", "phone_locations_doryab_timeattop2location",
"phone_locations_doryab_timeattop3location", "phone_locations_doryab_totaldistance",
"phone_locations_doryab_varspeed",
"phone_locations_locmap_duration_in_locmap_study", "phone_locations_locmap_percent_in_locmap_study",
"phone_locations_locmap_duration_in_locmap_exercise", "phone_locations_locmap_percent_in_locmap_exercise",
"phone_locations_locmap_duration_in_locmap_greens", "phone_locations_locmap_percent_in_locmap_greens",
],
"f_screen": [
"phone_screen_rapids_countepisodeunlock", "phone_screen_rapids_sumdurationunlock",
"phone_screen_rapids_maxdurationunlock", "phone_screen_rapids_mindurationunlock",
"phone_screen_rapids_avgdurationunlock", "phone_screen_rapids_stddurationunlock",
"phone_screen_rapids_firstuseafter00unlock",
"phone_screen_rapids_countepisodeunlock_locmap_exercise", "phone_screen_rapids_sumdurationunlock_locmap_exercise",
"phone_screen_rapids_maxdurationunlock_locmap_exercise", "phone_screen_rapids_mindurationunlock_locmap_exercise",
"phone_screen_rapids_avgdurationunlock_locmap_exercise", "phone_screen_rapids_stddurationunlock_locmap_exercise",
"phone_screen_rapids_firstuseafter00unlock_locmap_exercise",
"phone_screen_rapids_countepisodeunlock_locmap_greens", "phone_screen_rapids_sumdurationunlock_locmap_greens",
"phone_screen_rapids_maxdurationunlock_locmap_greens", "phone_screen_rapids_mindurationunlock_locmap_greens",
"phone_screen_rapids_avgdurationunlock_locmap_greens", "phone_screen_rapids_stddurationunlock_locmap_greens",
"phone_screen_rapids_firstuseafter00unlock_locmap_greens",
"phone_screen_rapids_countepisodeunlock_locmap_living", "phone_screen_rapids_sumdurationunlock_locmap_living",
"phone_screen_rapids_maxdurationunlock_locmap_living", "phone_screen_rapids_mindurationunlock_locmap_living",
"phone_screen_rapids_avgdurationunlock_locmap_living", "phone_screen_rapids_stddurationunlock_locmap_living",
"phone_screen_rapids_firstuseafter00unlock_locmap_living",
"phone_screen_rapids_countepisodeunlock_locmap_study", "phone_screen_rapids_sumdurationunlock_locmap_study",
"phone_screen_rapids_maxdurationunlock_locmap_study", "phone_screen_rapids_mindurationunlock_locmap_study",
"phone_screen_rapids_avgdurationunlock_locmap_study", "phone_screen_rapids_stddurationunlock_locmap_study",
"phone_screen_rapids_firstuseafter00unlock_locmap_study",
"phone_screen_rapids_countepisodeunlock_locmap_home", "phone_screen_rapids_sumdurationunlock_locmap_home",
"phone_screen_rapids_maxdurationunlock_locmap_home", "phone_screen_rapids_mindurationunlock_locmap_home",
"phone_screen_rapids_avgdurationunlock_locmap_home", "phone_screen_rapids_stddurationunlock_locmap_home",
"phone_screen_rapids_firstuseafter00unlock_locmap_home",
],
"f_slp": [
"fitbit_sleep_summary_rapids_sumdurationafterwakeupmain", "fitbit_sleep_summary_rapids_sumdurationasleepmain",
"fitbit_sleep_summary_rapids_sumdurationawakemain", "fitbit_sleep_summary_rapids_sumdurationtofallasleepmain",
"fitbit_sleep_summary_rapids_sumdurationinbedmain", "fitbit_sleep_summary_rapids_avgefficiencymain",
"fitbit_sleep_summary_rapids_avgdurationafterwakeupmain", "fitbit_sleep_summary_rapids_avgdurationasleepmain",
"fitbit_sleep_summary_rapids_avgdurationawakemain", "fitbit_sleep_summary_rapids_avgdurationtofallasleepmain",
"fitbit_sleep_summary_rapids_avgdurationinbedmain", "fitbit_sleep_summary_rapids_countepisodemain",
"fitbit_sleep_summary_rapids_firstbedtimemain", "fitbit_sleep_summary_rapids_lastbedtimemain",
"fitbit_sleep_summary_rapids_firstwaketimemain", "fitbit_sleep_summary_rapids_lastwaketimemain",
"fitbit_sleep_intraday_rapids_avgdurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_avgdurationawakeunifiedmain",
"fitbit_sleep_intraday_rapids_maxdurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_maxdurationawakeunifiedmain",
"fitbit_sleep_intraday_rapids_sumdurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_sumdurationawakeunifiedmain",
"fitbit_sleep_intraday_rapids_countepisodeasleepunifiedmain", "fitbit_sleep_intraday_rapids_countepisodeawakeunifiedmain",
"fitbit_sleep_intraday_rapids_stddurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_stddurationawakeunifiedmain",
"fitbit_sleep_intraday_rapids_mindurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_mindurationawakeunifiedmain",
"fitbit_sleep_intraday_rapids_mediandurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_mediandurationawakeunifiedmain",
"fitbit_sleep_intraday_rapids_ratiocountasleepunifiedwithinmain", "fitbit_sleep_intraday_rapids_ratiocountawakeunifiedwithinmain",
"fitbit_sleep_intraday_rapids_ratiodurationasleepunifiedwithinmain", "fitbit_sleep_intraday_rapids_ratiodurationawakeunifiedwithinmain",
],
"f_steps": [
"fitbit_steps_summary_rapids_maxsumsteps", "fitbit_steps_summary_rapids_minsumsteps",
"fitbit_steps_summary_rapids_avgsumsteps", "fitbit_steps_summary_rapids_mediansumsteps",
"fitbit_steps_summary_rapids_stdsumsteps",
"fitbit_steps_intraday_rapids_sumsteps", "fitbit_steps_intraday_rapids_maxsteps",
"fitbit_steps_intraday_rapids_minsteps", "fitbit_steps_intraday_rapids_avgsteps",
"fitbit_steps_intraday_rapids_stdsteps",
"fitbit_steps_intraday_rapids_countepisodesedentarybout", "fitbit_steps_intraday_rapids_sumdurationsedentarybout",
"fitbit_steps_intraday_rapids_maxdurationsedentarybout", "fitbit_steps_intraday_rapids_mindurationsedentarybout",
"fitbit_steps_intraday_rapids_avgdurationsedentarybout", "fitbit_steps_intraday_rapids_stddurationsedentarybout",
"fitbit_steps_intraday_rapids_countepisodeactivebout", "fitbit_steps_intraday_rapids_sumdurationactivebout",
"fitbit_steps_intraday_rapids_maxdurationactivebout", "fitbit_steps_intraday_rapids_mindurationactivebout",
"fitbit_steps_intraday_rapids_avgdurationactivebout", "fitbit_steps_intraday_rapids_stddurationactivebout",
],
}
TIME_SEGMENT = "allday"
WINDOW_DAYS = 28
LABEL_COL = "dep"
PREDICTION_TARGET = "dep_weekly"
CLASS_DICT = {True: "depressed", False: "not_depressed"}
PRE_SURVEY_COLS = [
"UCLA_10items_PRE", "SocialFit_PRE",
"2waySSS_receiving_emotional_PRE", "2waySSS_giving_emotional_PRE",
"2waySSS_giving_instrumental_PRE", "2waySSS_receiving_instrumental_PRE",
"ERQ_reappraisal_PRE", "ERQ_suppression_PRE",
"BRS_PRE", "CHIPS_PRE", "PSS_10items_PRE", "STAIS_PRE", "MAAS_7items_PRE",
"CESD_9items_PRE", "CESD_10items_PRE",
"BFI10_extroversion_PRE", "BFI10_agreeableness_PRE",
"BFI10_conscientiousness_PRE", "BFI10_neuroticism_PRE", "BFI10_openness_PRE",
]
def get_feature_col_names():
"""Build the list of full column names: f_type:feature_name:time_segment"""
cols = []
for ft in FEATURE_TYPES:
for feat in FEATURE_COLUMNS[ft]:
cols.append(f"{ft}:{feat}:{TIME_SEGMENT}")
return cols
def store_task_metadata(path):
task_metadata = {
"task": (
'Classify the user\'s depression status: ["depressed", "not_depressed"], '
"based on passive sensing data collected from a smartphone and a wearable fitness tracker."
),
"class": {
"depressed": "The user shows depressive symptoms based on self-reported weekly survey responses.",
"not_depressed": "The user does not show depressive symptoms based on self-reported weekly survey responses.",
},
"data": (
"Data were collected over a three-month study period from college students at a university. "
"Participants carried a smartphone and wore a Fitbit fitness tracker 24x7. "
"Passive sensing data includes GPS location, phone screen usage, Fitbit sleep, and Fitbit physical activity. "
"Features were extracted using the RAPIDS toolkit and computed daily over multiple time segments. "
"Each sample represents the last day of a 28-day observation window preceding a depression label date. "
"Each feature is named using the format 'sensor_type:feature_name:time_segment'."
),
"feature": (
"Location features (f_loc) include GPS-based metrics such as home time, distance travelled, "
"radius of gyration, location entropy, number of significant places, and time spent at various locations. "
"Phone usage features (f_screen) include unlock episode counts, durations, and location-specific phone usage patterns. "
"Sleep features (f_slp) include Fitbit-derived metrics such as sleep duration, efficiency, "
"time to fall asleep, bedtime/waketime, and intraday sleep/wake episode statistics. "
"Physical activity features (f_steps) include step counts, sedentary bout statistics, and active bout statistics. "
"All features use the 'allday' time segment (24 hours from midnight to midnight)."
),
}
with open(path, "w", encoding="utf-8") as f:
json.dump(task_metadata, f, indent=2)
def store_user_metadata(globem_path, out_path):
"""Build per-user metadata from platform.csv and pre.csv across all phases."""
user_metadata = {}
for institution, phases in PHASES.items():
for phase in phases:
ds_dir = os.path.join(globem_path, f"{institution}_{phase}")
platform_path = os.path.join(ds_dir, "ParticipantsInfoData", "platform.csv")
pre_path = os.path.join(ds_dir, "SurveyData", "pre.csv")
if not os.path.exists(platform_path):
print(f" Skipping {institution}_{phase}: platform.csv not found")
continue
df_platform = pd.read_csv(platform_path)
if "Unnamed: 0" in df_platform.columns:
df_platform = df_platform.drop(columns=["Unnamed: 0"])
df_platform = df_platform.set_index("pid")
df_pre = None
if os.path.exists(pre_path):
df_pre = pd.read_csv(pre_path)
if "Unnamed: 0" in df_pre.columns:
df_pre = df_pre.drop(columns=["Unnamed: 0"])
df_pre = df_pre.set_index("pid")
for pid in df_platform.index:
user_key = f"{pid}#{institution}_{phase}"
platform = str(df_platform.loc[pid, "platform"]).split(";")[0]
meta = {"platform": platform, "phase": phase}
if df_pre is not None and pid in df_pre.index:
row = df_pre.loc[pid]
for col in PRE_SURVEY_COLS:
if col in row.index:
val = row[col]
meta[col] = round(float(val), 4) if pd.notna(val) else None
user_metadata[user_key] = meta
with open(out_path, "w", encoding="utf-8") as f:
json.dump(user_metadata, f, indent=2)
def process_single_dataset(globem_path, institution, phase, feature_cols):
"""Process one dataset (institution + phase) and return per-user sample lists."""
ds_dir = os.path.join(globem_path, f"{institution}_{phase}")
rapids_path = os.path.join(ds_dir, "FeatureData", "rapids.csv")
label_path = os.path.join(ds_dir, "SurveyData", "dep_weekly.csv")
platform_path = os.path.join(ds_dir, "ParticipantsInfoData", "platform.csv")
for p in [rapids_path, label_path, platform_path]:
if not os.path.exists(p):
print(f" Skipping {institution}_{phase}: {os.path.basename(p)} not found")
return {}
print(f" Loading {institution}_{phase}...")
df_rapids = pd.read_csv(rapids_path, low_memory=False)
if "Unnamed: 0" in df_rapids.columns:
df_rapids = df_rapids.drop(columns=["Unnamed: 0"])
df_rapids["date"] = pd.to_datetime(df_rapids["date"])
df_label = pd.read_csv(label_path)
if "Unnamed: 0" in df_label.columns:
df_label = df_label.drop(columns=["Unnamed: 0"])
df_label["date"] = pd.to_datetime(df_label["date"])
df_label = df_label.dropna(subset=[LABEL_COL])
df_label = df_label.drop_duplicates(["pid", "date"], keep="last")
available_cols = [c for c in feature_cols if c in df_rapids.columns]
if len(available_cols) == 0:
print(f" Skipping {institution}_{phase}: no matching feature columns")
return {}
user_data = {}
phase_str = str(phase)
# keep only participants with at least two labeled responses
response_counts = df_label.groupby("pid")["date"].count()
valid_pids = set(response_counts[response_counts >= 2].index)
for _, row in tqdm(df_label.iterrows(), total=len(df_label),
desc=f" {institution}_{phase}", leave=False):
pid = row["pid"]
if pid not in valid_pids:
continue
date_end = row["date"]
date_start = date_end - timedelta(days=WINDOW_DAYS - 1)
label_val = row[LABEL_COL]
df_window = df_rapids[
(df_rapids["pid"] == pid)
& (df_rapids["date"] >= date_start)
& (df_rapids["date"] <= date_end)
]
if df_window.empty:
continue
last_day = df_window.sort_values("date").iloc[-1]
features = {}
for col in available_cols:
val = last_day[col]
if pd.notna(val):
features[col] = float(val)
else:
features[col] = None
user_key = f"{pid}#{institution}_{phase}"
if user_key not in user_data:
user_data[user_key] = []
user_data[user_key].append(dict(
user_id=user_key,
session_id=phase_str,
idx=len(user_data[user_key]),
date=str(date_end.date()),
label=CLASS_DICT[label_val],
features=features,
))
return user_data
def run(path=GLOBEM_PATH, out_dir=OUT_DIR):
os.makedirs(out_dir, exist_ok=True)
feature_cols = get_feature_col_names()
print(f"Using {len(feature_cols)} feature columns ({TIME_SEGMENT} time segment)")
print("Saving task metadata...")
store_task_metadata(os.path.join(out_dir, "task_metadata.json"))
print("Saving user metadata...")
store_user_metadata(path, os.path.join(out_dir, "user_metadata.json"))
all_user_data = {}
for institution, phases in PHASES.items():
for phase in phases:
user_data = process_single_dataset(path, institution, phase, feature_cols)
for user_key, samples in user_data.items():
all_user_data.setdefault(user_key, []).extend(samples)
print(f"\nSaving datasets for {len(all_user_data)} users...")
total_samples = 0
for user_key, data in sorted(all_user_data.items()):
safe_name = user_key.replace("#", "_")
user_dir = os.path.join(out_dir, safe_name)
dataset = Dataset.from_list(data)
dataset.save_to_disk(user_dir)
total_samples += len(data)
print(f"\nDone. {total_samples} total samples from {len(all_user_data)} users saved to {out_dir}")
if __name__ == "__main__":
Fire(run)
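
The window arithmetic in `process_single_dataset` is inclusive on both ends, which the short sketch below tries to make explicit. The helper `observation_window` and the label date 2021-05-28 are invented for illustration; only `WINDOW_DAYS = 28` and the `days=WINDOW_DAYS - 1` offset come from the script.

```python
from datetime import date, timedelta

WINDOW_DAYS = 28


def observation_window(label_date, window_days=WINDOW_DAYS):
    """Hypothetical helper: inclusive (start, end) dates of the window."""
    return label_date - timedelta(days=window_days - 1), label_date


# Made-up label date, used only to show the arithmetic.
start, end = observation_window(date(2021, 5, 28))
num_days = (end - start).days + 1
print(start, end, num_days)
```

Subtracting `WINDOW_DAYS - 1` rather than `WINDOW_DAYS` keeps the label date itself inside the 28-day window, and the script then takes the latest available day in that range as the sample.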


@@ -4,6 +4,7 @@ import json
import warnings
import numpy as np
import neurokit2 as nk
import pandas as pd
from tqdm import tqdm
from fire import Fire
@@ -21,7 +22,7 @@ warnings.simplefilter("ignore", NeuroKitWarning)
warnings.filterwarnings("ignore", message="All-NaN slice encountered")
SLEEPEDF_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/sleep-cassette/"
OUT_DIR = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
OUT_DIR = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new"
EPOCH_SEC_SIZE = 30
SAMPLING_RATE = 100
@@ -45,8 +46,8 @@ ann2label = {
}
def store_info(info_path):
info = {
def store_task_metadata(path):
task_metadata = {
"task": 'Classify the user\'s sleep stage: ["W", "N1", "N2", "N3", "REM"], based on physiological signals collected from wearable sensors.',
"class": {
"W": "Wakefulness. This includes periods before sleep onset or after final awakening, and short awakenings during the night.",
@@ -72,8 +73,21 @@ def store_info(info_path):
"Ratio features such as delta/theta, theta/alpha, alpha/beta, and (delta+theta)/(alpha+beta) were also included."
),
}
with open(info_path, "w", encoding="utf-8") as f:
json.dump(info, f, indent=2)
with open(path, "w", encoding="utf-8") as f:
json.dump(task_metadata, f, indent=2)
def store_user_metadata(src_path, out_path):
# load xls file
df = pd.read_excel(src_path)
user_metadata = {}
for _, row in df.iterrows():
user_id = str(row["subject"])
age = int(row["age"])
sex = int(row["sex (F=1)"]) - 1  # 0: female, 1: male
user_metadata[user_id] = {"age": age, "sex": sex}
with open(out_path, "w", encoding="utf-8") as f:
json.dump(user_metadata, f, indent=2)
def lowpass_filter(data, cutoff=50, fs=1000, order=4):
@@ -197,122 +211,7 @@ def process_by_mod(modality, data):
features[f"{modality}_theta/alpha_ratio"] = theta_alpha_ratio
features[f"{modality}_alpha/beta_ratio"] = alpha_beta_ratio
features[f"{modality}_(delta+theta)/(alpha+beta)_ratio"] = slow_fast_ratio
# elif "EOG" in modality:
# eog_cleaned = data
# try:
# eog_cleaned = nk.eog_clean(data, sr)
# except IndexError as e:
# print(f"Error processing EOG data for {modality}: {e}")
# return None
# eog_mean = np.mean(eog_cleaned)
# eog_std = np.std(eog_cleaned)
# eog_var = np.var(eog_cleaned)
# features[f"{modality}_mean"] = eog_mean
# features[f"{modality}_std"] = eog_std
# features[f"{modality}_variance"] = eog_var
# dynamic_range = np.max(eog_cleaned) - np.min(eog_cleaned)
# features[f"{modality}_dynamic_range"] = dynamic_range
# peaks = signal.find_peaks(eog_cleaned - eog_mean, height=3 * eog_std)[0]
# features[f"{modality}_num_peaks"] = len(peaks)
# zero_crossings = np.where(np.diff(np.sign(eog_cleaned - eog_mean)))[0]
# features[f"{modality}_num_zero_crossings"] = len(zero_crossings)
# differences = eog_cleaned[1:] - eog_cleaned[:-1]
# difference_variance = np.var(differences)
# features[f"{modality}_difference_variance"] = difference_variance
# features[f"{modality}_num_large_eye_movements"] = count_large_eye_movements(
# eog_cleaned, sr, amp_thresh=120, time_thresh=1.5
# )
# eog_large_movement_removed = remove_large_eye_movements(
# eog_cleaned, fs=sr, amp_thresh=120, time_thresh=1.5, pad=0.75
# )
# differences = eog_large_movement_removed[1:] - eog_large_movement_removed[:-1]
# difference_variance = np.var(differences)
# features[f"{modality}_difference_variance_without_large_movements"] = (
# difference_variance
# )
# freqs, psd = signal.welch(eog_cleaned, fs=sr, nperseg=sr * 2)
# total_idx = np.logical_and(freqs >= 0.5, freqs <= 30)
# total_power = np.trapezoid(psd[total_idx], freqs[total_idx])
# slow_idx = np.logical_and(freqs >= 0.5, freqs <= 2)
# rapid_idx = np.logical_and(freqs >= 2, freqs <= 5)
# slow_power = np.trapezoid(psd[slow_idx], freqs[slow_idx])
# rapid_power = np.trapezoid(psd[rapid_idx], freqs[rapid_idx])
# slow_power_ratio = slow_power / total_power if total_power > 0 else 0
# rapid_power_ratio = rapid_power / total_power if total_power > 0 else 0
# features[f"{modality}_slow_movement_power_ratio"] = slow_power_ratio
# features[f"{modality}_rapid_movement_power_ratio"] = rapid_power_ratio
# elif "Resp" in modality:
# rsp_signals = data
# try:
# rsp_signals, _ = nk.rsp_process(data, sampling_rate=sr, method="biosppy")
# except IndexError as e:
# print(f"Error processing respiration data for {modality}: {e}")
# return None
# clean = rsp_signals["RSP_Clean"]
# phase = rsp_signals["RSP_Phase"]
# rate = rsp_signals["RSP_Rate"]
# amplitude = rsp_signals["RSP_Amplitude"]
# peaks = np.where(rsp_signals["RSP_Peaks"] == 1)[0]
# troughs = np.where(rsp_signals["RSP_Troughs"] == 1)[0]
# inhale_durations = []
# for t in troughs:
# next_peaks = peaks[peaks > t]
# if len(next_peaks) == 0:
# continue
# inhale_durations.append((next_peaks[0] - t) / sr)
# inhale_durations = np.array(inhale_durations)
# exhale_durations = []
# for p in peaks:
# next_troughs = troughs[troughs > p]
# if len(next_troughs) == 0:
# continue
# exhale_durations.append((next_troughs[0] - p) / sr)
# exhale_durations = np.array(exhale_durations)
# features[f"{modality}_inhale_duration_mean"] = np.mean(inhale_durations)
# features[f"{modality}_inhale_duration_std"] = np.std(inhale_durations)
# features[f"{modality}_exhale_duration_mean"] = np.mean(exhale_durations)
# features[f"{modality}_exhale_duration_std"] = np.std(exhale_durations)
# features[f"{modality}_inhale_exhale_ratio"] = (
# np.mean(inhale_durations) / np.mean(exhale_durations)
# if np.mean(exhale_durations) > 0
# else np.nan
# )
# features[f"{modality}_stretch"] = np.max(clean) - np.min(clean)
# inhale_mask = phase == 1
# features[f"{modality}_inspiration_volume"] = np.trapezoid(
# amplitude[inhale_mask], dx=1 / sr
# )
# features[f"{modality}_respiration_rate"] = np.mean(rate)
# resp_durations = np.diff(troughs) / sr
# features[f"{modality}_respiration_duration"] = np.mean(resp_durations)
# elif "EMG" in modality:
# emg_mean = np.mean(data)
# emg_std = np.std(data)
# features[f"{modality}_mean"] = emg_mean
# features[f"{modality}_std"] = emg_std
# features[f"{modality}_dynamic_range"] = np.max(data) - np.min(data)
# features[f"{modality}_absolute_integral"] = np.sum(np.abs(data)) / sr
# features[f"{modality}_median"] = np.median(data)
# features[f"{modality}_10th_percentile"] = np.percentile(data, 10)
# features[f"{modality}_90th_percentile"] = np.percentile(data, 90)
# peaks, _ = signal.find_peaks(data, height=3 * emg_std)
# peak_values = data[peaks]
# features[f"{modality}_num_peaks"] = len(peaks)
# features[f"{modality}_peak_amplitude_mean"] = (
# np.mean(peak_values) if len(peak_values) > 0 else 0
# )
# features[f"{modality}_peak_amplitude_std"] = (
# np.std(peak_values) if len(peak_values) > 0 else 0
# )
# features[f"{modality}_peak_amplitude_sum"] = (
# np.sum(peak_values) if len(peak_values) > 0 else 0
# )
# features[f"{modality}_peak_amplitude_norm_sum"] = (
# np.sum(peak_values) / np.sum(np.abs(data))
# if np.sum(np.abs(data)) > 0
# else 0
# )
return features
@@ -450,9 +349,17 @@ def run(path=SLEEPEDF_PATH, out_dir=OUT_DIR, num_examples=1, num_workers=32, see
np.random.seed(seed)
if not os.path.exists(out_dir):
os.makedirs(out_dir)
info_path = os.path.join(out_dir, "info.json")
store_info(info_path)
print(f"Saved info to {info_path}")
task_metadata_out_path = os.path.join(out_dir, "task_metadata.json")
store_task_metadata(task_metadata_out_path)
print(f"Saved task metadata to {task_metadata_out_path}")
user_metadata_out_path = os.path.join(out_dir, "user_metadata.json")
user_metadata_src_path = os.path.join(path, "..", "SC-subjects.xls")
store_user_metadata(user_metadata_src_path, user_metadata_out_path)
print(f"Saved user metadata to {user_metadata_out_path}")
psg_file_paths = glob(os.path.join(path, "*PSG.edf"))
ann_file_paths = glob(os.path.join(path, "*Hypnogram.edf"))
psg_file_paths.sort()
@@ -466,16 +373,20 @@ def run(path=SLEEPEDF_PATH, out_dir=OUT_DIR, num_examples=1, num_workers=32, see
elif basename.startswith("SC41"):
filtered_2013_file_paths.append(file_path)
user_data = {}
with Pool(processes=num_workers) as pool:
for data in pool.imap_unordered(preprocess, filtered_2013_file_paths):
if len(data) == 0:
continue
user_id = data[0]["user_id"]
session_id = data[0]["session_id"]
dataset = Dataset.from_list(data)
test_dir = os.path.join(out_dir, f"{user_id}", f"{session_id}")
dataset.save_to_disk(test_dir)
print(f"Saved dataset to {test_dir}")
user_data.setdefault(user_id, []).extend(data)
for user_id, data in user_data.items():
dataset = Dataset.from_list(data)
test_dir = os.path.join(out_dir, f"{user_id}")
dataset.save_to_disk(test_dir)
print(f"Saved dataset to {test_dir} ({len(data)} samples, "
f"{len(set(d['session_id'] for d in data))} session(s))")
if __name__ == "__main__":

370
run.py

@@ -1,294 +1,114 @@
import os
import re
import asyncio
import yaml
import json
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any, Optional
import time
from glob import glob
import yaml
from fire import Fire
from core.model import load_models
from core.data_loader import DataLoader
from core.sensing_agent import SensingAgent
from core.embedding_index import EmbeddingIndex, create_embedding_index
_EMBEDDING_INDEX: Optional[EmbeddingIndex] = None
def init_embedding_index(embedding_path: Optional[str]) -> Optional[EmbeddingIndex]:
"""Initialize global embedding index for similarity-based selection."""
global _EMBEDDING_INDEX
if embedding_path is None:
_EMBEDDING_INDEX = None
return None
if _EMBEDDING_INDEX is not None:
return _EMBEDDING_INDEX
_EMBEDDING_INDEX = create_embedding_index(embedding_path)
return _EMBEDDING_INDEX
from core.example_queue import ExampleQueue
from core.model import load_models
from core.recruiter import Recruiter
from core.agent import Agent
from core.logger import Logger
from core.prompt import gen_system_message, gen_task_message
from core.json_utils import safe_parse_json
from core.scores import self_certainty
from core.vote import borda_vote
def load_user_data_sync(
data_path: str,
user: str,
seed: int,
log_path_base: str,
selection_criteria: str = "out_random",
num_examples: int = 1,
embedding_index: Optional[EmbeddingIndex] = None,
) -> List[Dict[str, Any]]:
"""Load data for a single user with specified selection criteria."""
print(f"[DATA LOADING] Starting: user={user}, seed={seed}, criteria={selection_criteria}")
# Set random seed for reproducibility
np.random.seed(seed)
data_loader = DataLoader(
data_path,
user,
selection_criteria=selection_criteria,
num_examples=num_examples,
embedding_index=embedding_index,
async def run(config_path: str):
print("[Main] Loading config")
with open(config_path, "r", encoding="utf-8") as f:
config = yaml.load(f, Loader=yaml.SafeLoader)
logger = Logger(config.get("log_path"))
logger.log_config(config)
logger.log("[Main] Loaded config")
logger.log("[Main] Loading data loader")
dataloader = DataLoader(config.get("data_path"), config.get("target_user"))
logger.log("[Main] Loaded data loader")
logger.log("[Main] Initializing example queue")
example_q = ExampleQueue(
queue_size=config.get("queue_size"),
logger=logger,
)
if not data_loader.is_valid:
print(f"[DATA LOADING] Skipping invalid user: {user}")
return []
tasks = []
idx = 0
dataset_size = len(data_loader)
print(f"[DATA LOADING] User {user} has {dataset_size} samples")
for sample, examples in data_loader:
if idx % 10 != 0:
idx += 1
continue
log_path = os.path.join(log_path_base, user, f"{idx:02d}", str(seed))
os.makedirs(log_path, exist_ok=True)
kwargs = {
"task_info": data_loader.get_task_info(),
"classes_info": data_loader.get_classes_info(),
"sensor_info": data_loader.get_sensor_info(),
"sample": sample,
"examples": examples,
"log_path": log_path,
"ground_truth": sample["label"],
}
tasks.append(kwargs)
idx += 1
print(f"[DATA LOADING] Completed user {user}, seed {seed}: {len(tasks)} tasks")
return tasks
recruiter = Recruiter(
source_dataset=dataloader.get_source_dataset(),
source_users=dataloader.get_source_users(),
num_shot=config.get("num_shot"),
classes=dataloader.get_classes(),
logger=logger,
)
list_examples = recruiter.recruit(config.get("queue_size"))
example_q.update([], list_examples)
logger.log("[Main] Initialized example queue")
logger.log("[Main] Loading model pool")
model_pool = load_models(config.get("model_paths"))
logger.log("[Main] Loaded model pool")
async def load_data_parallel(config: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Load data for all users in parallel with configurable selection criteria."""
print("[DATA LOADING] Starting parallel data loading...")
# Get configuration parameters
selection_criteria = config.get("selection_criteria", "out_random")
embedding_path = config.get("embedding_path", None)
num_examples = config.get("num_examples", 1)
# Initialize embedding index if needed for similarity-based selection
embedding_index = None
if selection_criteria in ["out_similar", "in_similar"]:
if embedding_path is None:
print(f"[WARNING] {selection_criteria} requires embedding_path in config")
print("[WARNING] Falling back to random selection")
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
else:
embedding_index = init_embedding_index(embedding_path)
if embedding_index is None:
print(f"[WARNING] Failed to load embeddings, falling back to random")
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
user_paths = glob(os.path.join(config["data_path"], "*"))
user_paths = [path for path in user_paths if os.path.isdir(path)]
users = [path.split("/")[-1] for path in user_paths]
print(f"[DATA LOADING] Found {len(users)} users: {users}")
print(f"[DATA LOADING] Selection criteria: {selection_criteria}")
print(f"[DATA LOADING] Num examples per class: {num_examples}")
max_workers = config.get("data_workers", 96)
print(f"[DATA LOADING] Using {max_workers} workers for data loading")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
loop = asyncio.get_event_loop()
futures = []
for seed in range(config["num_seeds"]):
for user in users:
future = loop.run_in_executor(
executor,
load_user_data_sync,
config["data_path"],
user,
seed,
config["log_path"],
selection_criteria,
num_examples,
embedding_index,
)
futures.append(future)
print(f"[DATA LOADING] Created {len(futures)} parallel data loading tasks")
results = await asyncio.gather(*futures, return_exceptions=True)
# Flatten results and filter out exceptions
all_tasks = []
for result in results:
if isinstance(result, Exception):
print(f"[DATA LOADING] Error loading data: {result}")
else:
all_tasks.extend(result)
logger.log("[Main] Initializing agent")
agent = Agent(
model_pool=model_pool,
logger=logger,
)
system_message = gen_system_message(metadata=dataloader.get_task_metadata())
agent.set_system_message(system_message)
logger.log("[Main] Initialized agent")
print(f"[DATA LOADING] Total tasks: {len(all_tasks)}")
return all_tasks
async def process_example_set(sample_idx, example_idx, example_set, sample):
logger.log(f"[Main] Processing {sample_idx} - {example_idx} (queue index)")
try:
task_message = gen_task_message(sample, example_set)
response, logprobs = await agent.solve(
task_message, sample_idx, example_idx
)
score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
parsed = safe_parse_json(response)
answer = parsed.get("ANSWER") if parsed else None
logger.log(
f"[Main] Done {sample_idx} - {example_idx}: "
f"answer={answer}, score={score:.4f}"
)
return {"example_set": example_set, "answer": answer, "score": score}
except Exception as e:
logger.log(
f"[Main] Error {sample_idx} - {example_idx}: {e}",
filename="errors.txt",
)
return {
"example_set": example_set,
"answer": None,
"score": float("-inf"),
}
async def run_parallel(kwargs_list: List[Dict[str, Any]], model_pool, config: Dict[str, Any]) -> None:
"""Run classification tasks in parallel using the model pool."""
print(f"[EXECUTION] Starting {len(kwargs_list)} classification tasks...")
tasks = []
for kwargs in kwargs_list:
agent = SensingAgent(
name="EEG sensing",
model_pool=model_pool,
task_info=kwargs["task_info"],
classes_info=kwargs["classes_info"],
sensor_info=kwargs["sensor_info"],
sample=kwargs["sample"],
examples=kwargs["examples"],
log_path=kwargs["log_path"],
)
task = asyncio.create_task(agent.solve(kwargs["sample"], kwargs["examples"], kwargs["ground_truth"]))
tasks.append(task)
await asyncio.gather(*tasks)
print("[EXECUTION] All tasks completed")
def run(config_path: str) -> None:
"""
Main entry point for running experiments.
Args:
config_path: Path to YAML configuration file
"""
print(f"[MAIN] Loading config from: {config_path}")
config = yaml.load(open(config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
print("=" * 60)
print("EXPERIMENT CONFIGURATION")
print("=" * 60)
print(f" Data path: {config.get('data_path', 'N/A')}")
print(f" Log path: {config.get('log_path', 'N/A')}")
print(f" Selection criteria: {config.get('selection_criteria', 'out_random')}")
print(f" Num examples: {config.get('num_examples', 1)}")
print(f" Num seeds: {config.get('num_seeds', 1)}")
print(f" Embedding path: {config.get('embedding_path', 'N/A')}")
print(f" Num models: {len(config.get('models', []))}")
print("=" * 60)
model_pool = load_models(config["models"])
kwargs_list = asyncio.run(load_data_parallel(config))
if len(kwargs_list) == 0:
print("[ERROR] No valid tasks to run. Check data paths and configuration.")
return
# Warmup models
print("[MAIN] Warming up models...")
model_pool.warmup()
# Run experiments
print("[MAIN] Starting experiments...")
logger.log("[Main] Starting main loop")
start_time = time.time()
asyncio.run(run_parallel(kwargs_list, model_pool, config))
elapsed = time.time() - start_time
print(f"[MAIN] Experiment completed in {elapsed:.2f} seconds")
print(f"[MAIN] Results saved to: {config['log_path']}")
for idx, sample in enumerate(dataloader):
logger.log(f"[Main] Processing sample {idx} / {len(dataloader)} (sample index)")
tasks = [
process_example_set(idx, example_idx, example_set, sample)
for example_idx, example_set in enumerate(example_q)
]
results = await asyncio.gather(*tasks)
winner, tally = borda_vote(results, config.get("borda_p", 1.0))
ground_truth = sample["label"]
if winner is not None:
tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
logger.log(f"[Vote] votes={{ {tally_str} }}")
logger.log_result(idx, winner, ground_truth)
else:
logger.log(f"[Vote] Sample {idx} | no valid answer parsed, skipping")
recruiter.update_strategy(results)
list_examples = recruiter.recruit(num_example_set=config.get("update_size"))
example_q.update(results, list_examples)
logger.report(elapsed_seconds=time.time() - start_time)
def run_comparison(
base_config_path: str,
criteria_list: str = "out_random,in_random,out_similar,in_similar",
embedding_path: str = None,
) -> None:
"""
Run experiments comparing multiple selection criteria.
Args:
base_config_path: Path to base YAML config file
criteria_list: Comma-separated list of selection criteria to compare
embedding_path: Path to embeddings (required for *_similar criteria)
Example:
python run.py compare config/sleepedf.yaml \\
--criteria_list="out_random,out_similar" \\
--embedding_path="./embeddings_full"
"""
base_config = yaml.load(open(base_config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
criteria = [c.strip() for c in criteria_list.split(",")]
print("=" * 60)
print("COMPARISON EXPERIMENT")
print("=" * 60)
print(f" Selection criteria to compare: {criteria}")
print(f" Embedding path: {embedding_path}")
print("=" * 60)
for criterion in criteria:
print(f"\n{'='*60}")
print(f"Running experiment: {criterion}")
print(f"{'='*60}")
config = base_config.copy()
config["selection_criteria"] = criterion
base_log_path = config["log_path"]
if not base_log_path.endswith(criterion):
config["log_path"] = f"{base_log_path}_{criterion}"
if criterion in ["out_similar", "in_similar"]:
if embedding_path:
config["embedding_path"] = embedding_path
else:
print(f"[WARNING] {criterion} requires --embedding_path argument")
continue
model_pool = load_models(config["models"])
kwargs_list = asyncio.run(load_data_parallel(config))
if len(kwargs_list) == 0:
print(f"[WARNING] No tasks for {criterion}, skipping")
continue
model_pool.warmup()
asyncio.run(run_parallel(kwargs_list, model_pool, config))
print(f"[DONE] Results for {criterion} saved to: {config['log_path']}")
print("\n" + "=" * 60)
print("ALL COMPARISON EXPERIMENTS COMPLETED")
print("=" * 60)
def main(config_path: str):
asyncio.run(run(config_path))
if __name__ == "__main__":
Fire({
"run": run,
"compare": run_comparison,
})
Fire(main)
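The new main loop aggregates the per-example-set answers with `borda_vote(results, config.get("borda_p", 1.0))`, but `core/vote.py` is not part of this excerpt. The following is a minimal sketch of one plausible rank-based vote, not the project's implementation. It assumes that each result is ranked by its `score`, that rank `r` (0-based, best first) earns `(n - r) ** p` points, and that results with an unparsed answer are ignored. The name `borda_vote_sketch` is hypothetical.

```python
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple


def borda_vote_sketch(
    results: List[Dict[str, Any]], p: float = 1.0
) -> Tuple[Optional[str], Dict[str, float]]:
    """Hypothetical stand-in for core.vote.borda_vote (assumed behaviour)."""
    # Drop results whose answer could not be parsed.
    valid = [r for r in results if r.get("answer") is not None]
    if not valid:
        return None, {}
    # Rank by score, best first; rank r (0-based) earns (n - r) ** p points.
    ranked = sorted(valid, key=lambda r: r["score"], reverse=True)
    n = len(ranked)
    tally: Dict[str, float] = defaultdict(float)
    for rank, result in enumerate(ranked):
        tally[result["answer"]] += float(n - rank) ** p
    # On a tie, the answer that entered the tally first wins.
    winner = max(tally, key=tally.get)
    return winner, dict(tally)
```

Under these assumptions, `p` controls how strongly the top-ranked answer dominates: with `p = 1.0` several mid-ranked votes for the same class can outweigh a single confident answer, while a larger `p` favours the highest-scoring response.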

File diff suppressed because it is too large


@@ -1,412 +0,0 @@
"""
Self-Consistency Results Analysis Script
This module provides analysis tools for Self-Consistency experiment results:
- Load and aggregate results from experiment directories
- Compute detailed statistics and metrics
- Generate visualizations and reports
- Compare results across different configurations
Usage:
# Analyze single experiment
python -m sc.analysis.analyze_sc_results analyze /path/to/results
# Compare multiple experiments
python -m sc.analysis.analyze_sc_results compare /path/to/exp1 /path/to/exp2
# Generate summary report
python -m sc.analysis.analyze_sc_results report /path/to/results --output report.md
Author: NMSL Research Team
Date: 2026-01-21
"""
import os
import json
import yaml
import numpy as np
import pandas as pd
from glob import glob
from typing import List, Dict, Any, Optional, Tuple
from collections import defaultdict
from fire import Fire
class SCResultsAnalyzer:
"""
Analyzer for Self-Consistency experiment results.
Loads experiment results and provides various analysis methods
including accuracy computation, consistency analysis, and
per-class/per-user breakdowns.
Attributes:
results_path: Path to experiment results directory
results: List of result dictionaries
config: Experiment configuration
stats: Computed statistics
"""
def __init__(self, results_path: str):
"""
Initialize analyzer with results directory.
Args:
results_path: Path to experiment results directory
Should contain all_results.json and statistics.json
"""
self.results_path = results_path
self.results = []
self.config = {}
self.stats = {}
self._load_results()
def _load_results(self) -> None:
"""Load results, config, and statistics from files."""
# Load all results
results_file = os.path.join(self.results_path, "all_results.json")
if os.path.exists(results_file):
with open(results_file, "r", encoding="utf-8") as f:
self.results = json.load(f)
print(f"[LOAD] Loaded {len(self.results)} results from {results_file}")
else:
print(f"[WARNING] Results file not found: {results_file}")
# Load config
config_file = os.path.join(self.results_path, "config.yaml")
if os.path.exists(config_file):
with open(config_file, "r", encoding="utf-8") as f:
self.config = yaml.safe_load(f)
print(f"[LOAD] Loaded config from {config_file}")
# Load pre-computed statistics
stats_file = os.path.join(self.results_path, "statistics.json")
if os.path.exists(stats_file):
with open(stats_file, "r", encoding="utf-8") as f:
self.stats = json.load(f)
print(f"[LOAD] Loaded statistics from {stats_file}")
def get_dataframe(self) -> pd.DataFrame:
"""
Convert results to pandas DataFrame.
Returns:
DataFrame with one row per result
"""
if not self.results:
return pd.DataFrame()
return pd.DataFrame(self.results)
def compute_accuracy(self) -> Dict[str, float]:
"""
Compute overall and per-class accuracy.
Returns:
Dictionary with 'overall' and per-class accuracies
"""
if not self.results:
return {"overall": 0.0}
df = self.get_dataframe()
# Overall accuracy
overall = df["is_correct"].mean()
# Per-class accuracy
accuracy = {"overall": overall}
for cls in df["ground_truth"].unique():
cls_df = df[df["ground_truth"] == cls]
accuracy[cls] = cls_df["is_correct"].mean()
return accuracy
def compute_consistency_analysis(self) -> Dict[str, Any]:
"""
Analyze relationship between consistency and accuracy.
Returns:
Dictionary with consistency statistics and correlations
"""
if not self.results:
return {}
df = self.get_dataframe()
# Consistency distribution
consistency_mean = df["consistency"].mean()
consistency_std = df["consistency"].std()
# High consistency accuracy
high_cons = df[df["consistency"] >= 0.8]
low_cons = df[df["consistency"] < 0.8]
high_cons_acc = high_cons["is_correct"].mean() if len(high_cons) > 0 else 0
low_cons_acc = low_cons["is_correct"].mean() if len(low_cons) > 0 else 0
# Consistency bins
bins = [0.0, 0.4, 0.6, 0.8, 1.0]
labels = ["0.0-0.4", "0.4-0.6", "0.6-0.8", "0.8-1.0"]
df["consistency_bin"] = pd.cut(df["consistency"], bins=bins, labels=labels, include_lowest=True)
bin_stats = {}
for label in labels:
bin_df = df[df["consistency_bin"] == label]
bin_stats[label] = {
"count": len(bin_df),
"accuracy": bin_df["is_correct"].mean() if len(bin_df) > 0 else 0,
}
return {
"consistency_mean": consistency_mean,
"consistency_std": consistency_std,
"high_consistency_count": len(high_cons),
"high_consistency_accuracy": high_cons_acc,
"low_consistency_count": len(low_cons),
"low_consistency_accuracy": low_cons_acc,
"bin_statistics": bin_stats,
}
def compute_per_user_accuracy(self) -> Dict[str, Dict[str, Any]]:
"""
Compute accuracy breakdown by user.
Returns:
Dictionary mapping user_id to accuracy metrics
"""
if not self.results:
return {}
df = self.get_dataframe()
user_stats = {}
for user_id in df["user_id"].unique():
user_df = df[df["user_id"] == user_id]
user_stats[user_id] = {
"count": len(user_df),
"accuracy": user_df["is_correct"].mean(),
"avg_consistency": user_df["consistency"].mean(),
"avg_confidence": user_df["confidence"].mean(),
}
return user_stats
def compute_confusion_matrix(self) -> Tuple[np.ndarray, List[str]]:
"""
Compute confusion matrix.
Returns:
Tuple of (confusion_matrix, class_labels)
"""
if not self.results:
return np.array([]), []
df = self.get_dataframe()
classes = sorted(df["ground_truth"].unique())
matrix = np.zeros((len(classes), len(classes)), dtype=int)
class_to_idx = {cls: i for i, cls in enumerate(classes)}
for _, row in df.iterrows():
gt_idx = class_to_idx[row["ground_truth"]]
pred = row["answer"]
if pred in class_to_idx:
pred_idx = class_to_idx[pred]
matrix[gt_idx, pred_idx] += 1
return matrix, classes
def generate_report(self) -> str:
"""
Generate comprehensive markdown report.
Returns:
Markdown formatted report string
"""
report = []
report.append("# Self-Consistency Experiment Results\n")
# Config summary
report.append("## Experiment Configuration\n")
if self.config:
report.append(f"- Selection Criteria: {self.config.get('selection_criteria', 'N/A')}")
report.append(f"- Num ICL Examples: {self.config.get('num_examples', 'N/A')}")
report.append(f"- Num SC Samples: {self.config.get('num_sc_samples', 'N/A')}")
report.append(f"- Temperature: {self.config.get('temperature', 'N/A')}")
report.append(f"- Num Seeds: {self.config.get('num_seeds', 'N/A')}")
report.append("")
# Overall statistics
report.append("## Overall Statistics\n")
accuracy = self.compute_accuracy()
report.append(f"- **Overall Accuracy**: {accuracy['overall']:.4f}")
report.append(f"- **Total Samples**: {len(self.results)}")
if self.stats:
report.append(f"- **Avg Confidence**: {self.stats.get('avg_confidence', 0):.4f}")
report.append(f"- **Avg Consistency**: {self.stats.get('avg_consistency', 0):.4f}")
report.append("")
# Per-class accuracy
report.append("## Per-Class Accuracy\n")
report.append("| Class | Accuracy |")
report.append("|-------|----------|")
for cls, acc in sorted(accuracy.items()):
if cls != "overall":
report.append(f"| {cls} | {acc:.4f} |")
report.append("")
# Consistency analysis
report.append("## Consistency Analysis\n")
cons_analysis = self.compute_consistency_analysis()
if cons_analysis:
report.append(f"- **Mean Consistency**: {cons_analysis['consistency_mean']:.4f}")
report.append(f"- **Consistency Std**: {cons_analysis['consistency_std']:.4f}")
report.append(f"- **High Consistency (≥0.8) Accuracy**: {cons_analysis['high_consistency_accuracy']:.4f}")
report.append(f"- **Low Consistency (<0.8) Accuracy**: {cons_analysis['low_consistency_accuracy']:.4f}")
report.append("\n### Accuracy by Consistency Bin\n")
report.append("| Consistency Range | Count | Accuracy |")
report.append("|-------------------|-------|----------|")
for bin_label, stats in cons_analysis["bin_statistics"].items():
report.append(f"| {bin_label} | {stats['count']} | {stats['accuracy']:.4f} |")
report.append("")
# Confusion matrix
report.append("## Confusion Matrix\n")
matrix, classes = self.compute_confusion_matrix()
if len(classes) > 0:
header = "| | " + " | ".join(classes) + " |"
separator = "|---" * (len(classes) + 1) + "|"
report.append(header)
report.append(separator)
for i, cls in enumerate(classes):
row = f"| **{cls}** | " + " | ".join(str(x) for x in matrix[i]) + " |"
report.append(row)
report.append("")
return "\n".join(report)
def save_report(self, output_path: str) -> None:
"""
Save report to file.
Args:
output_path: Path to save the report
"""
report = self.generate_report()
with open(output_path, "w", encoding="utf-8") as f:
f.write(report)
print(f"[SAVE] Report saved to: {output_path}")
def compare_experiments(
paths: List[str],
output_path: Optional[str] = None
) -> pd.DataFrame:
"""
Compare results across multiple experiments.
Args:
paths: List of paths to experiment result directories
output_path: Optional path to save comparison CSV
Returns:
DataFrame with comparison metrics
"""
comparison = []
for path in paths:
analyzer = SCResultsAnalyzer(path)
accuracy = analyzer.compute_accuracy()
cons = analyzer.compute_consistency_analysis()
comparison.append({
"experiment": os.path.basename(path),
"selection_criteria": analyzer.config.get("selection_criteria", "N/A"),
"num_samples": len(analyzer.results),
"accuracy": accuracy["overall"],
"avg_consistency": cons.get("consistency_mean", 0),
"high_cons_accuracy": cons.get("high_consistency_accuracy", 0),
})
df = pd.DataFrame(comparison)
if output_path:
df.to_csv(output_path, index=False)
print(f"[SAVE] Comparison saved to: {output_path}")
return df
# =============================================================================
# CLI Commands
# =============================================================================
def analyze(results_path: str) -> None:
"""
Analyze single experiment results.
Args:
results_path: Path to experiment results directory
"""
analyzer = SCResultsAnalyzer(results_path)
print("\n" + "=" * 60)
print("EXPERIMENT ANALYSIS")
print("=" * 60)
# Accuracy
accuracy = analyzer.compute_accuracy()
print(f"\nOverall Accuracy: {accuracy['overall']:.4f}")
print("\nPer-Class Accuracy:")
for cls, acc in sorted(accuracy.items()):
if cls != "overall":
print(f" {cls}: {acc:.4f}")
# Consistency
cons = analyzer.compute_consistency_analysis()
print(f"\nConsistency Analysis:")
print(f" Mean: {cons['consistency_mean']:.4f}")
print(f" Std: {cons['consistency_std']:.4f}")
print(f" High (≥0.8) Accuracy: {cons['high_consistency_accuracy']:.4f}")
print("=" * 60)
def report(results_path: str, output: str = None) -> None:
"""
Generate and optionally save analysis report.
Args:
results_path: Path to experiment results directory
output: Optional path to save report (default: results_path/report.md)
"""
analyzer = SCResultsAnalyzer(results_path)
if output is None:
output = os.path.join(results_path, "report.md")
analyzer.save_report(output)
print(f"Report generated: {output}")
def compare(*paths: str, output: str = None) -> None:
"""
Compare multiple experiment results.
Args:
paths: Paths to experiment result directories
output: Optional path to save comparison CSV
"""
df = compare_experiments(list(paths), output)
print("\n" + df.to_string(index=False))
if __name__ == "__main__":
Fire({
"analyze": analyze,
"report": report,
"compare": compare,
})
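The consistency-bin breakdown in `compute_consistency_analysis` depends on how bin edges are treated. `pd.cut` with its default arguments (`right=True`, `include_lowest=False`) uses right-closed intervals and leaves a value exactly at the lowest edge unbinned. The method also counts 0.8 as high consistency (`>= 0.8`), whereas right-closed bins place 0.8 in `0.6-0.8`. The sketch below reproduces the binning in plain Python, with the lowest edge included, so the edge cases are easy to inspect. The helper names `consistency_bin` and `accuracy_by_bin` are hypothetical, and the input is assumed to be a list of `(consistency, is_correct)` pairs.

```python
from typing import Dict, List, Optional, Tuple

BIN_EDGES = [0.0, 0.4, 0.6, 0.8, 1.0]
BIN_LABELS = ["0.0-0.4", "0.4-0.6", "0.6-0.8", "0.8-1.0"]


def consistency_bin(value: float) -> Optional[str]:
    """Label of the right-closed bin containing value (lowest edge included)."""
    if value < BIN_EDGES[0] or value > BIN_EDGES[-1]:
        return None
    for upper, label in zip(BIN_EDGES[1:], BIN_LABELS):
        if value <= upper:
            return label
    return None  # Reached only for NaN, which fails every comparison.


def accuracy_by_bin(rows: List[Tuple[float, bool]]) -> Dict[str, Dict[str, float]]:
    """Count and accuracy per consistency bin for (consistency, is_correct) rows."""
    stats = {label: {"count": 0, "correct": 0} for label in BIN_LABELS}
    for consistency, is_correct in rows:
        label = consistency_bin(consistency)
        if label is None:
            continue  # Out-of-range or NaN values are skipped.
        stats[label]["count"] += 1
        stats[label]["correct"] += int(is_correct)
    return {
        label: {
            "count": s["count"],
            "accuracy": s["correct"] / s["count"] if s["count"] else 0.0,
        }
        for label, s in stats.items()
    }
```

Empty bins report an accuracy of 0.0, matching the fallback used in the analysis script.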


@@ -1,73 +0,0 @@
# ==============================================================================
# Sleep Stage Classification - Confidence-based Queue Policy Experiment
# ==============================================================================
# ------------------------------------------------------------------------------
# Data Configuration
# ------------------------------------------------------------------------------
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
data_workers: 16
# ------------------------------------------------------------------------------
# Experiment Settings (to be overridden by CLI)
# ------------------------------------------------------------------------------
# These will be set via command line arguments
user_id: 5 # Target user (5 or 10)
shuffle_seed: 42 # Shuffle seed (42, 123, or 456)
# Queue and ICL settings
queue_size: 5
num_icl_shots: 5 # Number of ICL examples per agent
# Self-Consistency settings
num_sc_samples: 8 # Number of SC sampling agents
# Example pool selection: "out" (different users) or "in" (same user)
example_pool: "out"
# Process all samples (no sampling)
sample_rate: 1
# Tracking window size for rolling window accuracy
tracking_window: 20
# Model context window size
num_ctx: 15000
# Temperature for LLM
temperature: 0.0
# ------------------------------------------------------------------------------
# Queue Policy
# ------------------------------------------------------------------------------
queue_policy: "confidence"
# Options: "confidence", "consistency", "random"
# This config is for CONFIDENCE-based queue updates
# ------------------------------------------------------------------------------
# Model Configuration (8 agents)
# ------------------------------------------------------------------------------
models:
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
# ------------------------------------------------------------------------------
# Output Configuration
# ------------------------------------------------------------------------------
log_path: "/mnt/sting/ssum/sleepedf_sc_experiment/confidence"
# ------------------------------------------------------------------------------
# Sleep Stages
# ------------------------------------------------------------------------------
stages:
- W
- N1
- N2
- N3
- REM


@@ -1,73 +0,0 @@
# ==============================================================================
# Sleep Stage Classification - Consistency-based Queue Policy Experiment
# ==============================================================================
# ------------------------------------------------------------------------------
# Data Configuration
# ------------------------------------------------------------------------------
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
data_workers: 16
# ------------------------------------------------------------------------------
# Experiment Settings (to be overridden by CLI)
# ------------------------------------------------------------------------------
# These will be set via command line arguments
user_id: 5 # Target user (5 or 10)
shuffle_seed: 42 # Shuffle seed (42, 123, or 456)
# Queue and ICL settings
queue_size: 5
num_icl_shots: 5 # Number of ICL examples per agent
# Self-Consistency settings
num_sc_samples: 8 # Number of SC sampling agents
# Example pool selection: "out" (different users) or "in" (same user)
example_pool: "out"
# Process all samples (no sampling)
sample_rate: 1
# Tracking window size for rolling window accuracy
tracking_window: 20
# Model context window size
num_ctx: 15000
# Temperature for LLM
temperature: 0.0
# ------------------------------------------------------------------------------
# Queue Policy
# ------------------------------------------------------------------------------
queue_policy: "consistency"
# Options: "confidence", "consistency", "random"
# This config is for CONSISTENCY-based queue updates
# ------------------------------------------------------------------------------
# Model Configuration (8 agents)
# ------------------------------------------------------------------------------
models:
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
# ------------------------------------------------------------------------------
# Output Configuration
# ------------------------------------------------------------------------------
log_path: "/mnt/sting/ssum/sleepedf_sc_experiment/consistency"
# ------------------------------------------------------------------------------
# Sleep Stages
# ------------------------------------------------------------------------------
stages:
- W
- N1
- N2
- N3
- REM


@@ -1,86 +0,0 @@
# ==============================================================================
# Sleep Stage Classification - Queue Random Baseline Experiment
# ==============================================================================
#
# This is an ABLATION STUDY baseline:
# - Queue structure is maintained (size=5)
# - BUT all 5 elements are refreshed with random samples EVERY step
# - No cumulative learning or retention of good examples
#
# Purpose: Test whether performance gains come from:
# 1. Queue structure itself (using 5 ICL examples)
# 2. Cumulative learning (retaining high-quality examples over time)
#
# Expected Result:
# - If Queue Random ≈ Confidence/Consistency → Queue structure is the key
# - If Queue Random < Confidence/Consistency → Cumulative learning is the key
# ==============================================================================
# ------------------------------------------------------------------------------
# Data Configuration
# ------------------------------------------------------------------------------
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
data_workers: 16
# ------------------------------------------------------------------------------
# Experiment Settings (to be overridden by CLI)
# ------------------------------------------------------------------------------
user_id: 5 # Target user (5, 10, or 15)
shuffle_seed: 42 # Shuffle seed (42 or 123)
# Queue settings (same as other policies for fair comparison)
queue_size: 5 # Number of ICL example sets
num_icl_shots: 5 # Number of ICL examples per agent
# Self-Consistency settings
num_sc_samples: 8 # Number of SC sampling agents (same as other policies)
# Example pool selection: "out" (different users) or "in" (same user)
example_pool: "out"
# Process all samples (no sampling)
sample_rate: 1
# Tracking window size for rolling window accuracy
tracking_window: 20
# Model context window size
num_ctx: 15000
# Temperature for LLM
temperature: 0.0
# Sleep stages for classification
stages:
- "W"
- "N1"
- "N2"
- "N3"
- "REM"
# ------------------------------------------------------------------------------
# Queue Policy
# ------------------------------------------------------------------------------
queue_policy: "queue_random"
# This config is for QUEUE RANDOM baseline:
# - Queue structure exists (5 slots)
# - ALL slots are refreshed with random samples every step
# - No retention of good examples
# ------------------------------------------------------------------------------
# Output Configuration
# ------------------------------------------------------------------------------
log_path: "/mnt/sting/ssum/sleepedf_sc_experiment/queue_random"
# ------------------------------------------------------------------------------
# Model Configuration (8 agents for Self-Consistency)
# ------------------------------------------------------------------------------
models:
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
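
The refresh rule described in this config can be sketched as below. This is only an illustration of the ablation idea, not the repository's runner: `refresh_queue_random`, `pool_size`, and the seeded `random.Random` instance are hypothetical names introduced for the example.

```python
import random
from typing import List


def refresh_queue_random(pool_size: int, queue_size: int, rng: random.Random) -> List[int]:
    """Return a fresh queue of distinct example indices, discarding the old one."""
    if queue_size > pool_size:
        raise ValueError("queue_size cannot exceed the example pool size")
    return rng.sample(range(pool_size), queue_size)


rng = random.Random(42)  # mirrors shuffle_seed in the config (assumed usage)
queue: List[int] = []
for step in range(3):
    # Queue Random: every slot is replaced at every step, so no example is
    # retained across steps, unlike the confidence/consistency policies.
    queue = refresh_queue_random(pool_size=100, queue_size=5, rng=rng)
    print(step, queue)
```

Because the queue is rebuilt from scratch each step, any accuracy gap between this baseline and the retention-based policies can be attributed to retention rather than to the number of ICL examples.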

@@ -1,32 +0,0 @@
# PPGBP Recruiter (MAB-based example set selection) config
# Usage: python -m sc.run_recruiter sc/config/ppgbp_recruiter.yaml
# When model_path is null: use ground-truth scorer (arm_id -> score), no LLM; run gt_steps to test MAB convergence.
data_path: $PPGBP_DATA_ROOT
feature_root: $DATA_INDEX_ROOT/metadata_onehot # index.json + embeddings.npy
log_path: ./logs/ppgbp_recruiter
model_path: null # set to a HF model path to use real LLM + self_certainty
queue_size: 5
recruit_size: 2
n_way: 2
k_shot: 3
num_arms: 50
# When model_path is null: deterministic reward = Gaussian over arm index, peak at gt_best_arm_id
gt_steps: 200
gt_best_arm_id: 20
gt_sigma: 5.0
n_das: 5
delta: 0.05
epsilon: 0.11
temperature: 0.7
max_new_tokens: 256
num_seeds: 1
seed: 42
task_info: "Classify the subject from PPG-BP metadata."
classes_info: ["Normal", "Prehypertension", "Stage 1 hypertension", "Stage 2 hypertension"]
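
A minimal sketch of the ground-truth scorer used when `model_path` is null. It assumes an unnormalized Gaussian with peak value 1.0; the repository's scorer may scale or offset the reward differently. `gt_reward` is a hypothetical name, and its defaults are taken from `gt_best_arm_id` and `gt_sigma` above.

```python
import math


def gt_reward(arm_id: int, best_arm_id: int = 20, sigma: float = 5.0) -> float:
    """Deterministic reward that peaks at best_arm_id and decays with distance."""
    return math.exp(-((arm_id - best_arm_id) ** 2) / (2.0 * sigma ** 2))


# Score all 50 arms (num_arms) and locate the arm with the highest reward.
rewards = [gt_reward(a) for a in range(50)]
best = max(range(50), key=rewards.__getitem__)
print(best, rewards[best])
```

With a deterministic reward like this, a bandit that converges should concentrate its pulls on the arms nearest `gt_best_arm_id` within the `gt_steps` budget.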

@@ -1,80 +0,0 @@
# ==============================================================================
# Sleep-EDF Self-Consistency Experiment Configuration
# ==============================================================================
# ------------------------------------------------------------------------------
# Data Configuration
# ------------------------------------------------------------------------------
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
data_workers: 16
# ------------------------------------------------------------------------------
# Experiment Settings
# ------------------------------------------------------------------------------
num_seeds: 1
num_examples: 1
# Sample rate: process every Nth sample for faster experiments
sample_rate: 10
# Example pool selection: "out" (different users) or "in" (same user)
example_pool: "out"
# Continuous mode: if True, process samples in order; if False, shuffle
continuous: true
# Queue size for example selection (capacity of the example queue)
queue_size: 5
# Model context window size
num_ctx: 15000
# ------------------------------------------------------------------------------
# Selection Criteria
# Available options:
# - out_random: Random selection from different users (baseline)
# - in_random: Random selection from same user (personalization baseline)
# - out_similar: Chronos-2 embedding similarity-based selection
# - out_metadata: Gower distance-based selection (gender, age)
# ------------------------------------------------------------------------------
selection_criteria: "out_random"
# ------------------------------------------------------------------------------
# out_similar Configuration (Chronos-2 Embedding)
# Uncomment when using out_similar selection criteria
# ------------------------------------------------------------------------------
# embedding_path: "/home/ssum/tsllm_personalization_icl/embeddings_full"
# ------------------------------------------------------------------------------
# out_metadata Configuration (Gower Distance)
# Uncomment when using out_metadata selection criteria
# ------------------------------------------------------------------------------
# metadata_path: "/home/ssum/tsllm_personalization_icl/preprocess/SC-subjects.xls"
# weight_gender: 1.0 # Gender distance weight (same=0, different=1)
# weight_age: 1.0 # Age distance weight (normalized: |age1-age2|/range)
# ------------------------------------------------------------------------------
# Self-Consistency Settings
# ------------------------------------------------------------------------------
# Number of sampling iterations for Self-Consistency
num_sc_samples: 5
temperature: 0.0
# ------------------------------------------------------------------------------
# Model Configuration
# ------------------------------------------------------------------------------
# Multiple model instances can provide diversity even with temperature 0.0 (the entry below is a HuggingFace model, not an Ollama endpoint)
models:
- hf:openai/gpt-oss-20b
# Add more entries for model pool (same model, multiple instances for self-consistency)
# - hf:/path/to/gpt-oss-20b
# ------------------------------------------------------------------------------
# Output Configuration
# ------------------------------------------------------------------------------
log_path: "/mnt/sting/ssum/sleepedf_sc_result_test"
# Previous experiment result paths (for reference):
# chronos_base: "/mnt/sting/ssum/sleepedf_chronos_base_result"
# out_random: "/mnt/sting/ssum/sleepedf_chronos_result_outrandom"
# in_random: "/mnt/sting/ssum/sleepedf_chronos_result"
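
The `out_metadata` distance described in the comments above can be sketched as follows. This assumes the two per-feature distances are combined as a weighted mean, which is the usual Gower formulation; the function name `gower_distance` and the `age_range` argument (max minus min age over the subject pool) are hypothetical.

```python
def gower_distance(
    gender_a: str,
    gender_b: str,
    age_a: float,
    age_b: float,
    age_range: float,
    weight_gender: float = 1.0,
    weight_age: float = 1.0,
) -> float:
    """Weighted mean of per-feature distances, each in [0, 1]."""
    total_weight = weight_gender + weight_age
    if total_weight <= 0:
        raise ValueError("at least one weight must be positive")
    # Categorical feature: 0 if identical, 1 otherwise.
    gender_d = 0.0 if gender_a == gender_b else 1.0
    # Numeric feature: absolute difference normalized by the observed range.
    age_d = abs(age_a - age_b) / age_range if age_range > 0 else 0.0
    return (weight_gender * gender_d + weight_age * age_d) / total_weight


print(gower_distance("F", "M", 30, 30, age_range=50))
```

Setting `weight_gender` or `weight_age` to 0.0 reduces the measure to a single-feature distance, which is one way to check how much each attribute contributes to the selection.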

@@ -1,142 +0,0 @@
"""
Base agent for single-turn LLM inference with structured JSON parsing.
Each invoke call is stateless: system prompt + user prompt -> (text, logits).
Message format: {"role": "user"|"assistant"|"system", "content": str}
"""
import os
import re
import sys
import json
from typing import List, Optional, Tuple
import torch
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.model import AsyncModelPool, ChatMessage
class Agent:
"""Wraps an AsyncModelPool for single-turn inference."""
def __init__(
self,
name: str,
model_pool: AsyncModelPool,
log_path: str,
system_message: str = "",
) -> None:
self.name = name
self.model_pool = model_pool
self.system_message = system_message
# Logging directories
self.root_log_path = log_path
os.makedirs(log_path, exist_ok=True)
self.agent_log_path = os.path.join(log_path, name)
os.makedirs(self.agent_log_path, exist_ok=True)
# -----------------------------------------------------------------
# Logging
# -----------------------------------------------------------------
def log(self, message: ChatMessage) -> None:
"""Append a message to the global and agent-local log."""
role = message.get("role", "unknown")
content = message.get("content", "").strip()
entry = f"[{self.name}] [{role}]\n{content}\n\n\n"
for path in (
os.path.join(self.root_log_path, "log.txt"),
os.path.join(self.agent_log_path, "log.txt"),
):
with open(path, "a", encoding="utf-8") as f:
f.write(entry)
# -----------------------------------------------------------------
# JSON helpers
# -----------------------------------------------------------------
@staticmethod
def _clean_json_text(text: str) -> str:
"""Normalise common LLM quirks so the string is valid JSON."""
text = text.strip()
text = text.replace("\u2018", "'").replace("\u2019", "'") # smart single quotes
text = text.replace("\u201c", "'").replace("\u201d", "'") # smart double quotes
text = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", text) # escape stray backslashes
text = re.sub(r",\s*}", "}", text) # trailing comma in object
text = re.sub(r",\s*]", "]", text) # trailing comma in array
text = "".join(ch for ch in text if ch.isprintable()) # strip control chars
text = text.replace("][", ",") # merge adjacent arrays
return text
def safe_parse_json(self, text: str) -> Optional[dict]:
"""Best-effort extraction of a JSON object from an LLM response."""
if not text:
return None
text = text.strip()
match = re.search(r"\{.*\}", text, re.DOTALL)
if not match and not text.endswith("}"):
match = re.search(r"\{.*\}", text + "}", re.DOTALL)
if match:
try:
return json.loads(self._clean_json_text(match.group(0)))
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
print("[!] JSON parse failed: no object found")
return None
# -----------------------------------------------------------------
# Response validation
# -----------------------------------------------------------------
async def validate_response(
self, response: Optional[dict], fields: List[str]
) -> Optional[dict]:
"""Validate that *response* is a dict containing all *fields*; retry once if not."""
if not isinstance(response, dict) or not all(f in response for f in fields):
print("[!] JSON validation failed. Retrying...")
text, _ = await self.invoke(
"Failed to parse the JSON from the previous response. Please try again.",
)
response = self.safe_parse_json(text)
if not isinstance(response, dict) or not all(f in response for f in fields):
print("[!] Retry failed.")
return None
return response
# -----------------------------------------------------------------
# Core invoke
# -----------------------------------------------------------------
async def invoke(
self,
content: str,
) -> Tuple[Optional[str], Optional[torch.Tensor]]:
"""
Single-turn call: [system_message] + user content -> (text, logits).
Returns:
text: Generated reply string (or None on error).
logits: Tensor of shape (1, num_tokens, vocab_size) on CPU (or None).
"""
messages: List[ChatMessage] = []
if self.system_message:
messages.append({"role": "system", "content": self.system_message})
user_msg: ChatMessage = {"role": "user", "content": content}
messages.append(user_msg)
try:
text, logits = await self.model_pool.invoke(messages)
assistant_msg: ChatMessage = {"role": "assistant", "content": text}
self.log(user_msg)
self.log(assistant_msg)
return text.strip(), logits
except Exception as e:
print(f"[Error] invoke failed: {e}")
return None, None
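
The extraction step in `safe_parse_json` can be illustrated with a standalone sketch. It is a simplified stand-in, not the class method: `extract_json_object` is a hypothetical name, and it handles only trailing commas rather than the full set of clean-ups in `_clean_json_text`. The REASON/CONFIDENCE/ANSWER keys are used here only as sample fields.

```python
import json
import re
from typing import Optional


def extract_json_object(text: str) -> Optional[dict]:
    """Pull the outermost {...} span out of free-form text and parse it."""
    match = re.search(r"\{.*\}", text.strip(), re.DOTALL)
    if match is None:
        return None
    candidate = match.group(0)
    # Drop trailing commas before a closing brace or bracket. This is a
    # heuristic: it would also alter a ",}" sequence inside a string value.
    candidate = re.sub(r",\s*([}\]])", r"\1", candidate)
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


reply = 'Here is my answer:\n{"REASON": "spindles present", "CONFIDENCE": 0.8, "ANSWER": "N2",}'
print(extract_json_object(reply))
```

The greedy `.*` with `re.DOTALL` spans from the first `{` to the last `}`, so nested objects survive, but two separate objects in one reply would be captured as a single invalid span and rejected.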

@@ -1,137 +0,0 @@
import os
import asyncio
class AgentPool:
def __init__(self, log_path):
self.agents = {}
os.makedirs(log_path, exist_ok=True)
self.log_path = log_path
def add_agent(self, agent):
self.agents[agent.index] = agent
def log_summary(self, message, print_log=True):
path = os.path.join(self.log_path, "summary.txt")
with open(path, "a", encoding="utf-8") as f:
f.write(f"{message}\n")
if print_log:
print(message)
def get_last_responses(self):
responses = {}
for index, agent in self.agents.items():
response = agent.get_last_response()
if response:
responses[index] = response
return responses
def vote(self, responses, mode="majority_vote"):
"""
Vote for the final answer from agent responses.
Args:
responses: Dictionary mapping agent indices to response dicts
Format: {"agent_name": {"ANSWER": "...", "CONFIDENCE": 0.8, ...}, ...}
mode: Voting mode - "majority_vote", "highest_confidence", or "confidence_vote"
Returns:
The winning answer string
"""
if not responses:
return None
if mode == "highest_confidence":
# Find response with highest confidence
best_response = max(responses.values(), key=lambda x: x.get("CONFIDENCE", 0))
return best_response.get("ANSWER")
elif mode == "majority_vote":
# Count votes for each answer
answer_cnt = {}
for response in responses.values():
answer = response.get("ANSWER")
if answer:
answer_cnt[answer] = answer_cnt.get(answer, 0) + 1
if not answer_cnt:
return None
max_val = max(answer_cnt.values())
max_key = [k for k, v in answer_cnt.items() if v == max_val][0]
return max_key
elif mode == "confidence_vote":
# Weight votes by confidence
cnts = {}
for response in responses.values():
answer = response.get("ANSWER")
confidence = response.get("CONFIDENCE", 0)
if answer:
cnts[answer] = cnts.get(answer, 0) + confidence
if not cnts:
return None
return max(cnts, key=cnts.get)
else:
raise ValueError(f"Invalid mode: {mode}")
async def run_parallel_interpretation(self):
tasks = []
for _, agent in self.agents.items():
tasks.append(asyncio.create_task(agent.interpret()))
results = await asyncio.gather(*tasks)
print(results)
if None in results:
self.log_summary("[Error] Failed to interpret")
return None
for _, agent in self.agents.items():
agent.update_memory()
responses = self.get_last_responses()
for index, response in responses.items():
answer = response.get("ANSWER", "UNKNOWN")
confidence = response.get("CONFIDENCE", 0.0)
self.log_summary(f"[Interpretation] <{index}> provided answer: {answer} with confidence: {confidence}")
voted_result = self.vote(responses, mode="majority_vote")
queue_idcs = self.filter_examples(responses)
# Calculate confidence and consistency for the voted result
all_answers = [r.get("ANSWER") for r in responses.values()]
majority_responses = [r for r in responses.values() if r.get("ANSWER") == voted_result]
# Avg confidence of majority answer
avg_confidence = sum(r.get("CONFIDENCE", 0) for r in majority_responses) / len(majority_responses) if majority_responses else 0
# Consistency = ratio of majority votes
consistency = len(majority_responses) / len(responses) if responses else 0
return voted_result, queue_idcs, avg_confidence, consistency, responses
def filter_examples(self, responses):
"""
Filter examples based on highest confidence responses.
Args:
responses: Dictionary mapping agent indices to response dicts
Format: {"agent_name": {"ANSWER": "...", "CONFIDENCE": 0.8, "_example_idx": 0}, ...}
Returns:
List of queue indices with highest confidence
"""
queue_idcs = []
max_confidence = 0
for index, response in responses.items():
confidence = response.get("CONFIDENCE", 0)
if confidence > max_confidence:
max_confidence = confidence
queue_idcs = [index]
elif confidence == max_confidence:
queue_idcs.append(index)
# Debug: print why each case was selected
print("\n[Selection] Confidence Summary:")
for index, response in sorted(responses.items()):
conf = response.get("CONFIDENCE", 0)
ans = response.get("ANSWER", "?")
marker = " ★ SELECTED" if index in queue_idcs else ""
print(f" Case #{index}: {ans} (conf: {conf}){marker}")
print(f"[Selection] Max Confidence: {max_confidence} → Selected Case(s): {queue_idcs}")
return queue_idcs
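
To make the three voting modes concrete, here is a small standalone sketch on a hand-written response set. The functions mirror the logic of `AgentPool.vote` but are not the class itself, and the response values are made up for illustration.

```python
from collections import Counter, defaultdict
from typing import Dict, Optional

Responses = Dict[int, dict]

responses: Responses = {
    0: {"ANSWER": "N2", "CONFIDENCE": 0.6},
    1: {"ANSWER": "N2", "CONFIDENCE": 0.5},
    2: {"ANSWER": "REM", "CONFIDENCE": 0.95},
    3: {"ANSWER": "REM", "CONFIDENCE": 0.3},
    4: {"ANSWER": "N2", "CONFIDENCE": 0.4},
}


def majority_vote(responses: Responses) -> Optional[str]:
    """Most frequent answer; ties go to the answer seen first."""
    counts = Counter(r["ANSWER"] for r in responses.values() if r.get("ANSWER"))
    return counts.most_common(1)[0][0] if counts else None


def highest_confidence(responses: Responses) -> Optional[str]:
    """Answer of the single most confident response."""
    if not responses:
        return None
    return max(responses.values(), key=lambda r: r.get("CONFIDENCE", 0)).get("ANSWER")


def confidence_vote(responses: Responses) -> Optional[str]:
    """Answer with the largest summed confidence."""
    totals: Dict[str, float] = defaultdict(float)
    for r in responses.values():
        if r.get("ANSWER"):
            totals[r["ANSWER"]] += r.get("CONFIDENCE", 0)
    return max(totals, key=totals.get) if totals else None


winner = majority_vote(responses)
# Consistency is the share of responses that agree with the majority answer.
consistency = sum(r["ANSWER"] == winner for r in responses.values()) / len(responses)
print(winner, highest_confidence(responses), confidence_vote(responses), consistency)
```

In this example the modes disagree: one very confident REM response wins under `highest_confidence`, while the three moderate N2 responses win under both count-based and confidence-weighted voting.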

@@ -1,630 +0,0 @@
"""
Data loader for personalized ICL experiments.
Loads a target user's test split and example data from all other users.
Expects the following directory structure:
<path>/
info.json # dataset metadata (task, classes, features)
<user_id>/
1/ # examples (train split) - HuggingFace dataset on disk
2/ # test split - HuggingFace dataset on disk
<other_user>/
1/
...
--------------------------------------------------------------------------------
Integrating a new dataset
--------------------------------------------------------------------------------
The pipeline expects:
1) Directory layout (for DataLoader / ShuffledDataLoader):
- <data_path>/info.json
- <data_path>/<user_id>/1/ and <data_path>/<user_id>/2/ (HuggingFace datasets saved with datasets.Dataset.save_to_disk)
2) info.json schema:
- "task": str (e.g. "Sleep stage classification from EEG")
- "class": dict mapping label -> description (e.g. {"W": "Wake", "N1": "Non-REM 1", ...})
- "feature": str (sensor/feature description for the LLM)
3) Each sample (in test and example datasets) must be a dict with:
- "label": str (class name, must be one of the keys in info.json["class"])
- "features": dict of str -> value (e.g. {"Fpz-Cz": "0.12, -0.05, ...", "Pz-Oz": "..."}); values are formatted for the prompt
Option A - Convert your dataset to this format:
- Create info.json and per-user dirs; build HuggingFace Dataset with columns "label" and "features", then save_to_disk to <user_id>/1 and <user_id>/2.
- Use prepare_dataset_for_sc() in this module to write from in-memory lists.
Option B - Use in-memory data without changing directory layout:
- Use InMemoryDataLoader(metadata, test_samples, example_samples) which implements the same interface as DataLoader.
- Use it in run_sc by constructing it and passing to run_single_task; for run_confidence_based / run_consistency_based / run_sc_queue_random you currently need the directory layout (or extend those runners to accept a loader instance).
"""
import os
import json
from glob import glob
from typing import Any, Dict, List, Optional, Union
import datasets
class DataLoader:
"""Loads a target user's test set and ICL examples from other users."""
def __init__(
self,
path: Optional[str] = None,
user_id: Optional[str] = None,
shuffle: bool = False,
seed: int = 0,
data_path: Optional[str] = None,
**kwargs: Any,
) -> None:
# Support both path= and data_path= for compatibility with run_sc
root = path or data_path
if root is None or user_id is None:
raise ValueError("path (or data_path) and user_id are required")
metadata_path = os.path.join(root, "info.json")
target_user_path = os.path.join(root, user_id, "2")
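if not os.path.exists(metadata_path) or not os.path.exists(target_user_path):
# Fail loudly: a silent return would leave the loader without metadata or datasets
raise FileNotFoundError(
f"Expected both {metadata_path} and {target_user_path} to exist"
)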
# Metadata
with open(metadata_path, "r", encoding="utf-8") as f:
self.metadata: Dict[str, Any] = json.load(f)
# Test set for the target user
self.test_dataset = datasets.load_from_disk(target_user_path)
# Example set: train splits from all *other* users
example_paths = [
os.path.join(user_path, "1")
for user_path in glob(os.path.join(root, "*"))
if os.path.isdir(user_path) and os.path.basename(user_path) != user_id
]
example_parts = [datasets.load_from_disk(p) for p in example_paths]
if example_parts:
self.example_dataset = datasets.concatenate_datasets(example_parts)
else:
self.example_dataset = datasets.Dataset.from_list([])
# Shuffle
if shuffle:
self.test_dataset = self.test_dataset.shuffle(seed=seed)
self.example_dataset = self.example_dataset.shuffle(seed=seed)
def __len__(self) -> int:
return len(self.test_dataset)
def __getitem__(self, idx: int) -> Dict[str, Any]:
return self.test_dataset[idx]
def __iter__(self):
yield from self.test_dataset
def get_examples(self) -> datasets.Dataset:
return self.example_dataset
def get_metadata(self) -> Dict[str, Any]:
return self.metadata
def get_sensor_info(self) -> str:
return self.metadata["feature"]
def get_task_info(self) -> str:
"""Return a formatted string describing the task and its classes."""
classes_info = "\n".join(
f" - {k}: {v}" for k, v in self.metadata["class"].items()
)
return f"**Task**:\n{self.metadata['task']}\n\n**Classes**:\n{classes_info}"
def get_classes_info(self) -> List[str]:
return list(self.metadata["class"].keys())
class InMemoryDataLoader:
"""
DataLoader interface backed by in-memory lists. Use this to run the SC
pipeline on a new dataset without writing the directory layout to disk.
Implements the same interface as DataLoader: __len__, __getitem__, __iter__,
get_examples(), get_metadata(), get_sensor_info(), get_task_info(), get_classes_info().
"""
def __init__(
self,
metadata: Dict[str, Any],
test_samples: List[Dict[str, Any]],
example_samples: List[Dict[str, Any]],
) -> None:
"""
Args:
metadata: Must have "task", "class" (dict label -> description), "feature" (str).
test_samples: List of dicts with "label" and "features" (dict).
example_samples: Same schema; used as ICL example pool.
"""
self.metadata = metadata
self.test_dataset = datasets.Dataset.from_list(test_samples)
self.example_dataset = datasets.Dataset.from_list(example_samples)
def __len__(self) -> int:
return len(self.test_dataset)
def __getitem__(self, idx: int) -> Dict[str, Any]:
return self.test_dataset[idx]
def __iter__(self):
yield from self.test_dataset
def get_examples(self) -> datasets.Dataset:
return self.example_dataset
def get_metadata(self) -> Dict[str, Any]:
return self.metadata
def get_sensor_info(self) -> str:
return self.metadata["feature"]
def get_task_info(self) -> str:
classes_info = "\n".join(
f" - {k}: {v}" for k, v in self.metadata["class"].items()
)
return f"**Task**:\n{self.metadata['task']}\n\n**Classes**:\n{classes_info}"
def get_classes_info(self) -> List[str]:
return list(self.metadata["class"].keys())
def prepare_dataset_for_sc(
output_path: str,
metadata: Dict[str, Any],
per_user_splits: Dict[str, Dict[str, List[Dict[str, Any]]]],
) -> None:
"""
Write a new dataset to the directory layout expected by DataLoader and
ShuffledDataLoader. Call this once; then point config data_path to output_path.
Args:
output_path: Root directory to create (e.g. /path/to/MyDataset).
metadata: info.json contents: "task", "class" (dict), "feature" (str).
per_user_splits: { user_id: { "train": [samples], "test": [samples] } }.
Each sample is a dict with "label" and "features".
"""
os.makedirs(output_path, exist_ok=True)
with open(os.path.join(output_path, "info.json"), "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
for user_id, splits in per_user_splits.items():
user_dir = os.path.join(output_path, str(user_id))
os.makedirs(os.path.join(user_dir, "1"), exist_ok=True)
os.makedirs(os.path.join(user_dir, "2"), exist_ok=True)
train_ds = datasets.Dataset.from_list(splits.get("train", []))
test_ds = datasets.Dataset.from_list(splits.get("test", []))
train_ds.save_to_disk(os.path.join(user_dir, "1"))
test_ds.save_to_disk(os.path.join(user_dir, "2"))
# PPGBP data loader
"""
PPGBP Dataset Loader
A data loader/iterator for the PPG-BP dataset that reads xlsx metadata
and corresponding signal text files, with optional batching support.
Metadata embeddings live under PPGBP_METADATA_EMBEDDINGS_ROOT (one .npy per
subject ID). If missing, they are generated automatically in PPGBPLoader.__init__
using SBERT. Set PPGBP_METADATA_EMBEDDINGS_ROOT in .env or environment.
"""
import os
import sys
import numpy as np
import pandas as pd
from typing import Iterator, Tuple, Dict, List, Optional, Union
try:
from dotenv import load_dotenv
_ppgbp_env_loaded = False
def _ensure_ppgbp_env():
global _ppgbp_env_loaded
if not _ppgbp_env_loaded:
load_dotenv()
_ppgbp_env_loaded = True
except ImportError:
def _ensure_ppgbp_env():
pass
def _get_ppgbp_embedding_root() -> str:
_ensure_ppgbp_env()
root = os.environ.get("PPGBP_METADATA_EMBEDDINGS_ROOT")
if not root:
raise ValueError(
"PPGBP_METADATA_EMBEDDINGS_ROOT is not set. Set it in .env or environment "
"(e.g. PPGBP_METADATA_EMBEDDINGS_ROOT=/path/to/metadata_embeddings)."
)
return root
def _get_ppgbp_embedding_root_optional() -> Optional[str]:
"""Return PPGBP_METADATA_EMBEDDINGS_ROOT if set, else None. Used when embeddings are not required (e.g. recruiter with feature_root)."""
_ensure_ppgbp_env()
root = os.environ.get("PPGBP_METADATA_EMBEDDINGS_ROOT") or ""
return root.strip() or None
def _get_missing_embedding_ids(embedding_root: str, subject_ids: np.ndarray) -> Optional[List[int]]:
"""
Return list of subject_ids that do not have a .npy file in embedding_root.
Return None if embedding_root is not a directory (so caller can create and generate).
"""
if not os.path.isdir(embedding_root):
return None
missing = []
for sid in subject_ids:
path = os.path.join(embedding_root, f"{int(sid)}.npy")
if not os.path.isfile(path):
missing.append(int(sid))
return missing
def _check_embedding_root_populated(embedding_root: str, subject_ids: np.ndarray) -> None:
"""Raise if embedding_root is missing or does not contain .npy for every subject_id."""
missing_or_none = _get_missing_embedding_ids(embedding_root, subject_ids)
if missing_or_none is None:
raise FileNotFoundError(
f"Metadata embeddings root is not a directory or does not exist: {embedding_root}."
)
if missing_or_none:
raise FileNotFoundError(
f"Metadata embeddings root {embedding_root} is missing .npy files for subject IDs: "
f"{missing_or_none[:10]}{'...' if len(missing_or_none) > 10 else ''}."
)
def _generate_and_save_ppgbp_metadata_embeddings(
embedding_root: str,
subject_path: str,
subject_ids: np.ndarray,
) -> None:
"""
Generate SBERT metadata embeddings for the given subject IDs and save as
{subject_id}.npy under embedding_root. Uses tqdm for progress.
"""
try:
from tqdm import tqdm
except ImportError:
tqdm = lambda x, **kw: x # noqa: E731
# Lazy import to keep SBERT/gen_plot deps out of normal import path
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from analysis.user_similarity.sbert_metadata_ppgbp.gen_plot import (
load_subject_metadata,
SBERT_Metadata,
_normalize_sex_for_sbert,
_normalize_hypertension_for_sbert,
)
os.makedirs(embedding_root, exist_ok=True)
subject_metadata = load_subject_metadata(subject_path)
embedder = SBERT_Metadata()
ids_to_generate = [int(sid) for sid in subject_ids]
for user_id in tqdm(ids_to_generate, desc="PPGBP metadata embeddings", unit="subject"):
user_id_str = str(user_id)
meta = subject_metadata.get(user_id_str, {})
sex = _normalize_sex_for_sbert(meta.get("sex"))
age = meta.get("age")
height = meta.get("height")
weight = meta.get("weight")
sbp = meta.get("sbp")
dbp = meta.get("dbp")
hr = meta.get("hr")
bmi = meta.get("bmi")
hypertension = _normalize_hypertension_for_sbert(meta.get("hypertension"))
if age is not None and isinstance(age, float) and np.isnan(age):
age = None
emb = embedder.compute_embedding_from_metadata(
[sex], [age], [height], [weight], [sbp], [dbp], [hr], [bmi], [hypertension]
)
path = os.path.join(embedding_root, f"{user_id}.npy")
np.save(path, emb[0])
def load_ppgbp_metadata_embedding(embedding_root: str, subject_id: int) -> np.ndarray:
"""Load a single metadata embedding for subject_id from embedding_root (utility)."""
path = os.path.join(embedding_root, f"{int(subject_id)}.npy")
if not os.path.isfile(path):
raise FileNotFoundError(f"Metadata embedding not found: {path}")
return np.load(path)
def save_ppgbp_metadata_embeddings(
embedding_root: str,
subject_path: str,
subject_ids: Optional[np.ndarray] = None,
) -> None:
"""
Generate and save PPGBP metadata embeddings under embedding_root (one .npy per subject).
If subject_ids is None, uses all subject IDs present in the metadata file at subject_path.
Uses tqdm for progress.
"""
if subject_ids is not None and len(subject_ids) == 0:
return
# If subject_ids is None we need to get all from metadata; _generate_* expects an array
if subject_ids is None:
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from analysis.user_similarity.sbert_metadata_ppgbp.gen_plot import load_subject_metadata
subject_metadata = load_subject_metadata(subject_path)
subject_ids = np.array(sorted(int(k) for k in subject_metadata.keys()))
_generate_and_save_ppgbp_metadata_embeddings(embedding_root, subject_path, subject_ids)
class PPGBPLoader:
"""
Data loader for the PPG-BP dataset.
Reads metadata from xlsx file and corresponding PPG signal text files.
Supports iteration over single samples or batches.
Subject list is built from 0_subject/ and matched with metadata;
train/test is 80/20 at subject level (random, fixed by seed).
"""
COLUMN_MAPPING = {
"Sex(M/F)": "sex",
"Age(year)": "age",
"Systolic Blood Pressure(mmHg)": "sysbp",
"Diastolic Blood Pressure(mmHg)": "diasbp",
"Heart Rate(b/m)": "hr",
"BMI(kg/m^2)": "bmi"
}
def __init__(
self,
base_dir: str,
split: str = 'all',
batch_size: Optional[int] = None,
shuffle: bool = False,
num_segments: int = 3,
seed: int = 42,
return_metadata_embeddings: bool = False,
):
self.base_dir = base_dir
self.split = split
self.batch_size = batch_size
self.shuffle = shuffle
self.num_segments = num_segments
self.seed = seed
self.return_metadata_embeddings = return_metadata_embeddings
self.rng = np.random.default_rng(seed)
self.signal_dir = os.path.join(base_dir, "0_subject")
self.xlsx_path = self._find_xlsx_file()
self.metadata = self._load_metadata()
self._all_subject_ids = self._get_unified_subject_ids()
self._train_ids, self._test_ids = self._get_train_test_split(self._all_subject_ids)
self.subject_ids = self._get_split_ids()
self.metadata = self.metadata[self.metadata['subject_ID'].isin(self.subject_ids)]
self._embedding_root = _get_ppgbp_embedding_root_optional()
if self._embedding_root is not None:
missing = _get_missing_embedding_ids(self._embedding_root, self._all_subject_ids)
if len(self._all_subject_ids) > 0 and (missing is None or len(missing) > 0):
_generate_and_save_ppgbp_metadata_embeddings(
self._embedding_root, self.xlsx_path, self._all_subject_ids
)
_check_embedding_root_populated(self._embedding_root, self.subject_ids)
self._current_idx = 0
self._indices = np.arange(len(self.subject_ids))
if self.shuffle:
self.rng.shuffle(self._indices)
def _find_xlsx_file(self) -> str:
for f in os.listdir(self.base_dir):
if f.endswith('.xlsx'):
return os.path.join(self.base_dir, f)
raise FileNotFoundError(f"No xlsx file found in {self.base_dir}")
def _load_metadata(self) -> pd.DataFrame:
# Use header=1: xlsx has a title row, then row with column names (subject_ID, etc.)
df = pd.read_excel(self.xlsx_path, header=1)
df = df.rename(columns=self.COLUMN_MAPPING)
df = df.fillna(0)
return df
def _get_unified_subject_ids(self) -> np.ndarray:
"""Subject IDs present in both 0_subject/ (from listdir) and metadata."""
if not os.path.isdir(self.signal_dir):
return np.array([], dtype=np.int64)
# Files are named {subject_id}_{segment}.txt
ids_from_files = set()
for f in os.listdir(self.signal_dir):
if f.endswith(".txt"):
try:
sid = int(f.replace(".txt", "").split("_")[0])
ids_from_files.add(sid)
except (ValueError, IndexError):
continue
meta_ids = set(self.metadata["subject_ID"].astype(int).values)
unified = sorted(ids_from_files & meta_ids)
return np.array(unified)
def _get_train_test_split(self, ids: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""80% train, 20% test at subject level; shuffle fixed by seed."""
if len(ids) == 0:
return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
perm = self.rng.permutation(len(ids))
ids_perm = ids[perm]
n_train = max(1, int(round(0.8 * len(ids))))
train_ids = ids_perm[:n_train]
test_ids = ids_perm[n_train:]
return train_ids, test_ids
def _get_split_ids(self) -> np.ndarray:
if self.split == "train":
return self._train_ids
elif self.split == "test":
return self._test_ids
elif self.split == "all":
return self._all_subject_ids
else:
raise ValueError(f"Unknown split: {self.split}. Use 'train', 'test', or 'all'.")
@property
def train_ids(self) -> np.ndarray:
"""Subject IDs in the train split (80%)."""
return self._train_ids
@property
def test_ids(self) -> np.ndarray:
"""Subject IDs in the test split (20%)."""
return self._test_ids
def _load_signal(self, subject_id: int) -> np.ndarray:
segments = []
for s in range(1, self.num_segments + 1):
filepath = os.path.join(self.signal_dir, f"{subject_id}_{s}.txt")
signal = pd.read_csv(filepath, sep='\t', header=None)
signal = signal.values.squeeze()
if len(signal) > 1:
signal = signal[:-1]
segments.append(signal)
return np.array(segments, dtype=object)
def _get_metadata_dict(self, subject_id: int) -> Dict:
row = self.metadata[self.metadata['subject_ID'] == subject_id].iloc[0]
return row.to_dict()
def _load_metadata_embedding(self, subject_id: int) -> np.ndarray:
return load_ppgbp_metadata_embedding(self._embedding_root, subject_id)
def __len__(self) -> int:
return len(self.subject_ids)
def __iter__(self) -> 'PPGBPLoader':
self._current_idx = 0
if self.shuffle:
self.rng.shuffle(self._indices)
return self
def __next__(self) -> Union[
Tuple[Dict, np.ndarray],
Tuple[Dict, np.ndarray, np.ndarray],
Tuple[List[Dict], List[np.ndarray]],
Tuple[List[Dict], List[np.ndarray], List[np.ndarray]],
]:
if self._current_idx >= len(self.subject_ids):
raise StopIteration
if self.batch_size is None:
idx = self._indices[self._current_idx]
subject_id = self.subject_ids[idx]
metadata = self._get_metadata_dict(subject_id)
signal = self._load_signal(subject_id)
self._current_idx += 1
if self.return_metadata_embeddings and self._embedding_root is not None:
return metadata, signal, self._load_metadata_embedding(subject_id)
return metadata, signal
else:
end_idx = min(self._current_idx + self.batch_size, len(self.subject_ids))
batch_indices = self._indices[self._current_idx:end_idx]
metadata_list = []
signal_list = []
embedding_list = [] if (self.return_metadata_embeddings and self._embedding_root is not None) else None
for idx in batch_indices:
subject_id = self.subject_ids[idx]
metadata_list.append(self._get_metadata_dict(subject_id))
signal_list.append(self._load_signal(subject_id))
if self.return_metadata_embeddings and self._embedding_root is not None:
embedding_list.append(self._load_metadata_embedding(subject_id))
self._current_idx = end_idx
if embedding_list is not None:
return metadata_list, signal_list, embedding_list
return metadata_list, signal_list
def __getitem__(self, idx: int) -> Union[Tuple[Dict, np.ndarray], Tuple[Dict, np.ndarray, np.ndarray]]:
if idx < 0 or idx >= len(self.subject_ids):
raise IndexError(f"Index {idx} out of range")
subject_id = self.subject_ids[idx]
metadata = self._get_metadata_dict(subject_id)
signal = self._load_signal(subject_id)
if self.return_metadata_embeddings and self._embedding_root is not None:
return metadata, signal, self._load_metadata_embedding(subject_id)
return metadata, signal
def get_by_subject_id(self, subject_id: int) -> Union[Tuple[Dict, np.ndarray], Tuple[Dict, np.ndarray, np.ndarray]]:
if subject_id not in self.subject_ids:
raise ValueError(f"Subject ID {subject_id} not found in this split.")
metadata = self._get_metadata_dict(subject_id)
signal = self._load_signal(subject_id)
if self.return_metadata_embeddings and self._embedding_root is not None:
return metadata, signal, self._load_metadata_embedding(subject_id)
return metadata, signal
def reset(self):
self._current_idx = 0
if self.shuffle:
self.rng.shuffle(self._indices)
def iter_batches(
self, batch_size: int
) -> Iterator[Union[Tuple[List[Dict], List[np.ndarray]], Tuple[List[Dict], List[np.ndarray], List[np.ndarray]]]]:
indices = np.arange(len(self.subject_ids))
if self.shuffle:
self.rng.shuffle(indices)
for start_idx in range(0, len(self.subject_ids), batch_size):
end_idx = min(start_idx + batch_size, len(self.subject_ids))
batch_indices = indices[start_idx:end_idx]
metadata_list = []
signal_list = []
embedding_list = [] if (self.return_metadata_embeddings and self._embedding_root is not None) else None
for idx in batch_indices:
subject_id = self.subject_ids[idx]
metadata_list.append(self._get_metadata_dict(subject_id))
signal_list.append(self._load_signal(subject_id))
if self.return_metadata_embeddings and self._embedding_root is not None:
embedding_list.append(self._load_metadata_embedding(subject_id))
if embedding_list is not None:
yield metadata_list, signal_list, embedding_list
else:
yield metadata_list, signal_list
def get_loaders(
base_dir: str,
batch_size: Optional[int] = None,
shuffle_train: bool = True,
seed: int = 42
) -> Tuple[PPGBPLoader, PPGBPLoader]:
"""Convenience function to get train and test loaders (80/20 split)."""
train_loader = PPGBPLoader(
base_dir=base_dir, split="train",
batch_size=batch_size, shuffle=shuffle_train, seed=seed
)
test_loader = PPGBPLoader(
base_dir=base_dir, split="test",
batch_size=batch_size, shuffle=False, seed=seed
)
return train_loader, test_loader
if __name__ == "__main__":
print("PPGBPLoader ready to use.")
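
The subject-level 80/20 split performed by `_get_train_test_split` can be reproduced in isolation. The sketch below is illustrative only: `split_subject_ids` is a hypothetical helper, and it assumes the loader's `self.rng` behaves like a seeded NumPy `Generator`, which the excerpt above does not confirm.

```python
import numpy as np


def split_subject_ids(ids, seed=42, train_frac=0.8):
    """Hypothetical helper: shuffle subject IDs with a fixed seed, then cut at train_frac."""
    ids = np.asarray(ids)
    if len(ids) == 0:
        empty = np.array([], dtype=np.int64)
        return empty, empty.copy()
    # Assumption: a seeded Generator stands in for the loader's self.rng.
    rng = np.random.default_rng(seed)
    shuffled = ids[rng.permutation(len(ids))]
    # Keep at least one training subject, mirroring max(1, ...) in the loader.
    n_train = max(1, int(round(train_frac * len(ids))))
    return shuffled[:n_train], shuffled[n_train:]


train_ids, test_ids = split_subject_ids(np.arange(1, 11))
```

Because the split is made on subject IDs rather than on individual segments, no subject contributes data to both partitions, and reusing the same seed yields the same partition.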

@@ -1,188 +0,0 @@
from collections import deque
import random
class Queue:
def __init__(self, dataset_size, capacity=5):
self.dataset_size = dataset_size
self.classes = list(dataset_size.keys())
initial_cases = [
[random.choice(dataset_size[cls]) for cls in self.classes]
for _ in range(capacity)
]
self._queue = deque(initial_cases, maxlen=capacity)
self._input_time = {tuple(case): 0 for case in initial_cases}
self._current_time = 0
self._eviction_history = []
self._usage_count = {tuple(case): 0 for case in initial_cases}
def set_current_time(self, sample_idx: int):
self._current_time = sample_idx
def push(self, index):
if list(index) in [list(c) for c in self._queue]:
return None
evicted = None
if len(self._queue) == self._queue.maxlen:
evicted = self._queue.popleft()
self._record_eviction_stats(evicted)
self._queue.append(list(index))
self._register_stats(index)
return evicted
def pop(self):
return self._queue.popleft() if self._queue else None
def __iter__(self):
"""
Make Queue iterable by yielding all elements in the queue.
This allows the queue to be used in for loops:
for ex_index in ex_queue:
# process ex_index
"""
for idx in self._queue:
yield idx
def _register_stats(self, case):
key = tuple(case)
self._input_time[key] = self._current_time
self._usage_count[key] = 0
def _record_eviction_stats(self, case):
key = tuple(case)
if key in self._input_time:
duration = self._current_time - self._input_time[key]
self._eviction_history.append({
"case": case,
"duration": duration,
"usage": self._usage_count.get(key, 0),
"evicted_at": self._current_time
})
del self._input_time[key]
if key in self._usage_count:
del self._usage_count[key]
def update_by_confidence(self, confidence_map):
"""Update queue based on confidence scores (keyed by queue position)."""
if not confidence_map:
return None
current_items = list(self._queue)
items_with_score = []
for i, item in enumerate(current_items):
items_with_score.append({
"item": item,
"score": confidence_map.get(i, -1.0)
})
items_with_score.sort(key=lambda x: x["score"], reverse=True)
self._queue = deque([x["item"] for x in items_with_score], maxlen=self._queue.maxlen)
evicted = None
if len(self._queue) == self._queue.maxlen:
evicted = self._queue.pop()
self._record_eviction_stats(evicted)
new_case = [random.choice(self.dataset_size[cls]) for cls in self.classes]
self._queue.append(new_case)
self._register_stats(new_case)
return evicted
def update_by_consistency(self, consistency_map):
"""Update queue based on consistency scores (agreement ratio among agents)."""
if not consistency_map:
return None
current_items = list(self._queue)
items_with_score = []
for i, item in enumerate(current_items):
items_with_score.append({
"item": item,
"score": consistency_map.get(i, -1.0)
})
items_with_score.sort(key=lambda x: x["score"], reverse=True)
self._queue = deque([x["item"] for x in items_with_score], maxlen=self._queue.maxlen)
evicted = None
if len(self._queue) == self._queue.maxlen:
evicted = self._queue.pop()
self._record_eviction_stats(evicted)
new_case = [random.choice(self.dataset_size[cls]) for cls in self.classes]
self._queue.append(new_case)
self._register_stats(new_case)
return evicted
def increment_usage(self, queue_indices):
"""Increment usage count for specified queue indices."""
for idx in queue_indices:
if 0 <= idx < len(self._queue):
key = tuple(self._queue[idx])
self._usage_count[key] = self._usage_count.get(key, 0) + 1
def get_instance_id(self):
return id(self)
def get_state_with_stats(self):
result = []
for case in self._queue:
key = tuple(case)
age = self._current_time - self._input_time.get(key, self._current_time)
result.append({
"case": case,
"age": age,
"usage": self._usage_count.get(key, 0)
})
return result
def get_survival_stats(self):
return self._eviction_history
def get_survival_summary(self):
if not self._eviction_history:
return {
"total_evicted": 0,
"avg_survival": 0,
"max_survival": 0,
"avg_usage": 0
}
durations = [x["duration"] for x in self._eviction_history]
usages = [x["usage"] for x in self._eviction_history]
return {
"total_evicted": len(durations),
"avg_survival": sum(durations) / len(durations),
"max_survival": max(durations),
"avg_usage": sum(usages) / len(usages) if usages else 0
}
def update_with_recruiter(self, results, new_example_sets, L: int = 1):
"""
Evict the L worst sets (highest self-certainty score from results) and push up to L new sets.
results: list of (case, response_dict_or_text, self_certainty_score) in queue order.
new_example_sets: list of L new cases (each case = list of indices, same as queue items).
Lower self_certainty = better; we evict the L with highest (worst) score.
"""
if not results or L <= 0:
return
current = list(self._queue)
if len(current) == 0:
for case in new_example_sets[: self._queue.maxlen]:
self._queue.append(list(case))
self._register_stats(case)
return
scores = []
for i, item in enumerate(current):
sc = -1.0
if i < len(results) and len(results[i]) >= 3:
sc = float(results[i][2])
scores.append((i, sc))
scores.sort(key=lambda x: x[1], reverse=True)
to_evict_idx = {scores[j][0] for j in range(min(L, len(scores)))}
new_queue = [current[i] for i in range(len(current)) if i not in to_evict_idx]
for idx in sorted(to_evict_idx, reverse=True):
if idx < len(current):
self._record_eviction_stats(current[idx])
for case in new_example_sets[:L]:
if len(new_queue) >= self._queue.maxlen:
break
new_queue.append(list(case))
self._register_stats(case)
self._queue = deque(new_queue, maxlen=self._queue.maxlen)
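
The eviction step in `update_with_recruiter` can be sketched on its own. This is a simplified illustration, not the class method: `evict_worst` is a hypothetical helper that skips the survival statistics, and it assumes one score per queue position, with higher scores treated as worse.

```python
from collections import deque


def evict_worst(queue, scores, new_cases, n_evict=1):
    """Hypothetical helper: drop the n_evict highest-scoring items, then append replacements."""
    items = list(queue)
    # Rank queue positions by score, worst (highest) first.
    ranked = sorted(range(len(items)), key=lambda i: scores[i], reverse=True)
    to_evict = set(ranked[:n_evict])
    kept = [item for i, item in enumerate(items) if i not in to_evict]
    evicted = [items[i] for i in sorted(to_evict)]
    # Refill with at most n_evict new cases, never exceeding the capacity.
    for case in new_cases[:n_evict]:
        if len(kept) >= queue.maxlen:
            break
        kept.append(list(case))
    return deque(kept, maxlen=queue.maxlen), evicted


queue = deque([[1, 2], [3, 4], [5, 6]], maxlen=3)
new_queue, evicted = evict_worst(queue, scores=[0.2, 0.9, 0.5], new_cases=[[7, 8]])
```

Surviving items keep their relative order, so the replacement cases always land at the tail of the queue.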

@@ -1,162 +0,0 @@
"""
Self-Consistency Agent for Sleep Stage Classification
This module implements an agent that uses Self-Consistency methodology:
- Sample N times with the same prompt (configurable temperature)
- Output REASON, CONFIDENCE, ANSWER in JSON format
- Aggregate final answer via Majority Voting
The agent extends the base Agent class and adds Self-Consistency
specific functionality including multi-sampling and voting.
"""
import os
import sys
import json
import asyncio
from typing import List, Dict, Any, Optional
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
from sc.core.agent import Agent
class JudgeAgent(Agent):
def __init__(
self,
name: str,
index: int,
model_pool,
task_info: str,
classes_info: List[str],
sensor_info: str,
sample: Dict[str, Any],
examples: List[Dict[str, Any]],
log_path: str,
):
super().__init__(
name=name + f"_{index}",
model_pool=model_pool,
log_path=log_path,
)
self.task_info = task_info
self.classes_info = classes_info
self.sensor_info = sensor_info
self.sample = sample
self.examples = examples
self.index = index # Store index for later retrieval
self._init_system_message()
def _init_system_message(self) -> None:
content = (
f"You are a {self.name} agent that judges the answers of other agents.\n"
f"You have the following information about the task:\n"
f"{self.task_info}\n\n"
f"You have the following information about the sensor data:\n"
f"{self.sensor_info}\n\n"
"Your goal is to analyze the features and "
"provide a reasoned answer with your confidence level."
)
self.set_system_message(content)
def _format_feature(self, value: Any) -> str:
"""
Format a feature value for display.
Uses scientific notation for very large or very small numbers,
and two decimal places for regular floats.
Args:
value: Feature value to format
Returns:
Formatted string representation
"""
if isinstance(value, float):
if abs(value) >= 1e4 or (abs(value) < 1e-2 and value != 0):
return f"{value:.2e}"
return f"{value:.2f}"
return str(value)
def _gen_example_info(self) -> str:
"""
Generate string representation of ICL examples.
Returns:
Formatted string with example features and labels,
or empty string if no examples provided
"""
if not self.examples:
return ""
example_info = (
"**Examples**\n"
"Sensor values might not always align with your inherent "
"knowledge due to differences in data collection or processing. "
"So, we included a few labeled examples to help your interpretation:\n"
)
for example in self.examples:
example_info += f"*Example of {example['label']}*:\n"
for k, v in example["features"].items():
example_info += f" - {k}: {self._format_feature(v)}\n"
example_info += "\n"
return example_info.strip()
def _gen_feature_info(self) -> str:
"""
Generate string representation of current sample features.
Combines ICL example information with current sample features
to provide full context for classification.
Returns:
Formatted string with example and sample features
"""
feature_info = f"{self.name} features:\n"
# Add ICL example information if available
example_info = self._gen_example_info()
if example_info:
feature_info += f"{example_info}\n\n"
# Add current sample features
feature_info += "**Current sample features**:\n"
for k, v in self.sample["features"].items():
feature_info += f" - {k}: {self._format_feature(v)}\n"
return feature_info.strip()
async def interpret(self) -> Optional[Dict[str, Any]]:
feature_info = self._gen_feature_info()
prompt = (
f"You have received sensor features from {self.name} modality:\n"
f"{feature_info}\n\n"
f"Please provide your answer for the task among {self.classes_info} "
"and the reasoning for your answer.\n"
"Also, please provide your confidence level for the answer as a float between 0.0 and 1.0.\n"
"Note that the sensor features might be wrong due to the data collection or processing.\n"
"You can evaluate the quality of the features by checking the examples you have.\n\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<Your detailed reasoning for the classification>",\n'
' "CONFIDENCE": <Your confidence as a float between 0.0 and 1.0>,\n'
f' "ANSWER": "<Your answer among {self.classes_info}>"\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(prompt)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["REASON", "CONFIDENCE", "ANSWER"]
)
# Add example index to response for queue update
if parsed_response:
parsed_response["_example_idx"] = self.index
return parsed_response
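
The number formatting applied to every feature before it enters the prompt is easy to check in isolation. The sketch below copies the rule from `_format_feature` into a free function named `format_feature` (a name chosen for illustration only) and applies it to a few arbitrary sample values.

```python
def format_feature(value):
    """Standalone copy of the formatting rule, for illustration only."""
    if isinstance(value, float):
        # Scientific notation for very large or very small non-zero magnitudes.
        if abs(value) >= 1e4 or (abs(value) < 1e-2 and value != 0):
            return f"{value:.2e}"
        # Two decimal places for ordinary floats, including exact zero.
        return f"{value:.2f}"
    # Non-float values (ints, strings) are passed through str() unchanged.
    return str(value)


formatted = [format_feature(v) for v in (12345.678, 0.001234, 3.14159, 0.0, 7)]
# formatted == ['1.23e+04', '1.23e-03', '3.14', '0.00', '7']
```

Integers bypass the float branch entirely, so a large integer such as 1000000 is rendered in full rather than in scientific notation.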

@@ -1,180 +0,0 @@
"""
Majority Voting Module for Self-Consistency
Aggregates multiple LLM responses to determine the final answer via majority voting.
"""
from collections import Counter
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
import numpy as np
@dataclass
class VotingResult:
answer: str # Final answer (majority vote)
reason: str # Representative reasoning (from the majority-answer response with highest confidence)
avg_confidence: float # Average confidence (for majority answer)
consistency: float # Consistency score (ratio of majority votes)
vote_distribution: Dict[str, int] # Vote distribution {answer: count}
num_samples: int # Total number of valid samples
all_responses: List[Dict[str, Any]] = field(default_factory=list) # All responses
@property
def is_unanimous(self) -> bool:
"""Check if all votes are unanimous."""
return self.consistency == 1.0
@property
def majority_count(self) -> int:
"""Return the count of majority votes."""
return self.vote_distribution.get(self.answer, 0)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"answer": self.answer,
"reason": self.reason,
"avg_confidence": self.avg_confidence,
"consistency": self.consistency,
"vote_distribution": self.vote_distribution,
"num_samples": self.num_samples,
"is_unanimous": self.is_unanimous,
"majority_count": self.majority_count,
}
class MajorityVoting:
def __init__(self, valid_classes: Optional[List[str]] = None):
self.valid_classes = valid_classes
def aggregate(self, responses: List[Dict[str, Any]]) -> VotingResult:
valid_responses = self._filter_valid_responses(responses) # Filter valid responses only
if not valid_responses:
return self._empty_result()
answers = [r["ANSWER"] for r in valid_responses]
vote_counter = Counter(answers)
majority_answer = self._resolve_majority(vote_counter, valid_responses)
majority_count = vote_counter[majority_answer]
consistency = majority_count / len(valid_responses)
majority_responses = [
r for r in valid_responses if r["ANSWER"] == majority_answer
]
avg_confidence = np.mean([r["CONFIDENCE"] for r in majority_responses])
best_response = max(majority_responses, key=lambda x: x["CONFIDENCE"])
best_reason = best_response.get("REASON", "")
return VotingResult(
answer=majority_answer,
reason=best_reason,
avg_confidence=float(avg_confidence),
consistency=float(consistency),
vote_distribution=dict(vote_counter),
num_samples=len(valid_responses),
all_responses=valid_responses,
)
def _filter_valid_responses(
self,
responses: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
valid = [] # Valid responses
for r in responses:
if not isinstance(r, dict):
continue
if "ANSWER" not in r:
continue
# Check valid class
if self.valid_classes and r["ANSWER"] not in self.valid_classes:
continue
# Normalize CONFIDENCE
conf = r.get("CONFIDENCE", 0.5)
# Coerce to float; fall back to 0.5 for None or non-numeric values
try:
conf = float(conf)
except (TypeError, ValueError):
conf = 0.5
r["CONFIDENCE"] = max(0.0, min(1.0, conf))
# Set default REASON
if "REASON" not in r:
r["REASON"] = ""
valid.append(r)
return valid
def _resolve_majority(
self,
vote_counter: Counter,
responses: List[Dict[str, Any]]
) -> str:
"""
Determine majority answer (use average confidence for tie-breaking).
"""
max_count = vote_counter.most_common(1)[0][1]
tied_answers = [a for a, c in vote_counter.items() if c == max_count]
if len(tied_answers) == 1:
return tied_answers[0]
# Tie-breaking: select answer with highest average confidence
best_answer = None
best_avg_conf = -1.0
for answer in tied_answers:
answer_responses = [r for r in responses if r["ANSWER"] == answer]
avg_conf = np.mean([r["CONFIDENCE"] for r in answer_responses])
if avg_conf > best_avg_conf:
best_avg_conf = avg_conf
best_answer = answer
return best_answer
def _empty_result(self) -> VotingResult:
"""Return empty result when no valid responses exist."""
return VotingResult(
answer="UNKNOWN",
reason="No valid responses received",
avg_confidence=0.0,
consistency=0.0,
vote_distribution={},
num_samples=0,
all_responses=[],
)
@staticmethod
def compute_agreement_matrix(
responses: List[Dict[str, Any]]
) -> np.ndarray:
"""
Compute pairwise agreement matrix between responses (for analysis).
Creates an N x N matrix where entry (i, j) is 1.0 if responses i and j
have the same ANSWER, and 0.0 otherwise. Useful for analyzing
consistency patterns across samples.
Args:
responses: List of response dictionaries with ANSWER keys
"""
n = len(responses)
matrix = np.zeros((n, n))
for i in range(n):
for j in range(n):
if responses[i].get("ANSWER") == responses[j].get("ANSWER"):
matrix[i, j] = 1.0
return matrix
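
The voting rule implemented by `MajorityVoting.aggregate` and `_resolve_majority` reduces to a few lines. The following is a minimal sketch under stated assumptions: `majority_vote` is a hypothetical helper, the responses are already validated and non-empty, and the class labels are placeholders.

```python
from collections import Counter


def majority_vote(responses):
    """Hypothetical helper: return (answer, consistency) with confidence tie-breaking."""
    counts = Counter(r["ANSWER"] for r in responses)
    top = max(counts.values())
    tied = [answer for answer, count in counts.items() if count == top]

    def mean_conf(answer):
        confs = [r["CONFIDENCE"] for r in responses if r["ANSWER"] == answer]
        return sum(confs) / len(confs)

    # With a single leader this is a no-op; with a tie the most confident answer wins.
    winner = max(tied, key=mean_conf)
    return winner, counts[winner] / len(responses)


# Placeholder labels and confidences, chosen so that two answers tie on vote count.
responses = [
    {"ANSWER": "Wake", "CONFIDENCE": 0.6},
    {"ANSWER": "REM", "CONFIDENCE": 0.9},
    {"ANSWER": "Wake", "CONFIDENCE": 0.5},
    {"ANSWER": "REM", "CONFIDENCE": 0.7},
    {"ANSWER": "N2", "CONFIDENCE": 0.99},
]
answer, consistency = majority_vote(responses)
```

Here "Wake" and "REM" both receive two votes, and "REM" wins on mean confidence (0.8 against 0.55). The single high-confidence "N2" vote does not affect the outcome because confidence is consulted only among tied leaders.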

@@ -1,160 +0,0 @@
"""
HuggingFace Causal LM wrapper for inference with logits.
Loads a model via transformers, generates text, and returns both
the generated text and the per-step scores (processed logits) collected during generation.
Message format: [{"role": "user"|"assistant"|"system", "content": str}, ...]
"""
from __future__ import annotations
import asyncio
from typing import Dict, List, Optional, Tuple
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# {"role": "user"|"assistant"|"system", "content": str}
ChatMessage = Dict[str, str]
# -----------------------------------------------------------------------------
# Model
# -----------------------------------------------------------------------------
class Model:
"""HuggingFace causal LM. Returns generated text + full logits."""
def __init__(
self,
model_path: str,
temperature: float = 0.0,
max_new_tokens: int = 1024,
enable_thinking: bool = False,
) -> None:
self._tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
self._model = AutoModelForCausalLM.from_pretrained(
model_path,
dtype=torch.bfloat16,
device_map="auto",
).eval()
self.temperature = temperature
self.max_new_tokens = max_new_tokens
self.enable_thinking = enable_thinking
def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
"""
Generate a reply and return (text, logits).
Returns:
text: Generated text string.
logits: Logits tensor (1, num_generated_tokens, vocab_size) on CPU.
"""
# Build prompt from chat messages
prompt = self._tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=self.enable_thinking,
)
inputs = self._tokenizer(prompt, return_tensors="pt").to(self._model.device)
prompt_len = inputs["input_ids"].shape[1]
# Generate with scores in a single pass
with torch.inference_mode():
gen_out = self._model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
do_sample=self.temperature > 0,
temperature=self.temperature if self.temperature > 0 else 1.0,
return_dict_in_generate=True,
output_scores=True,
)
# Decode generated tokens
gen_ids = gen_out.sequences[0][prompt_len:]
text = self._tokenizer.decode(gen_ids, skip_special_tokens=True)
# Strip internal thinking tags (e.g. "analysis...assistantfinal...")
if "assistantfinal" in text:
text = text.split("assistantfinal", 1)[1].strip()
# Stack per-step scores into logits: (1, num_generated_tokens, vocab_size)
logits = torch.stack(gen_out.scores, dim=1) # scores: tuple of (batch, vocab)
return text, logits.cpu()
# -----------------------------------------------------------------------------
# Async wrappers
# -----------------------------------------------------------------------------
class AsyncModel:
"""Runs Model.invoke in a thread executor for async usage."""
def __init__(self, model: Model) -> None:
self.model = model
async def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None,
lambda: self.model.invoke(messages),
)
class AsyncModelPool:
"""Round-robin pool of async models."""
def __init__(self) -> None:
self.models: List[Model] = []
self._queue: Optional[asyncio.Queue[AsyncModel]] = None
def add_model(self, model: Model) -> None:
self.models.append(model)
def init_models(self) -> None:
print(f"Initializing {len(self.models)} model(s)...")
self._queue = asyncio.Queue()
for model in self.models:
self._queue.put_nowait(AsyncModel(model))
async def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
"""Acquire a model, invoke, release. Returns (text, logits)."""
if self._queue is None:
raise RuntimeError("Call init_models() first.")
async_model = await self._queue.get()
try:
return await async_model.invoke(messages)
finally:
self._queue.put_nowait(async_model)
# -----------------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------------
def load_models(
models: List[str],
temperature: float = 0.0,
max_new_tokens: int = 1024,
enable_thinking: bool = False,
) -> AsyncModelPool:
"""Create and initialize a model pool from a list of model paths."""
pool = AsyncModelPool()
for path in models:
pool.add_model(
Model(
path,
temperature=temperature,
max_new_tokens=max_new_tokens,
enable_thinking=enable_thinking,
)
)
pool.init_models()
return pool
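
The acquire/invoke/release pattern used by `AsyncModelPool` can be exercised without loading any weights. In the sketch below, `EchoModel` and `Pool` are hypothetical stand-ins: the echo model replaces the HuggingFace wrapper, and the pool creates its queue lazily inside the running event loop, which is a design choice of this sketch rather than of the original code.

```python
import asyncio


class EchoModel:
    """Stand-in for a real model; not part of the original code."""

    def __init__(self, name):
        self.name = name

    def invoke(self, prompt):
        return f"{self.name}:{prompt}"


class Pool:
    """Hypothetical pool: each model serves at most one request at a time."""

    def __init__(self, models):
        self._models = list(models)
        self._queue = None

    async def invoke(self, prompt):
        if self._queue is None:
            # Build the queue on first use, from inside the running event loop.
            self._queue = asyncio.Queue()
            for model in self._models:
                self._queue.put_nowait(model)
        model = await self._queue.get()  # waits while all models are busy
        try:
            loop = asyncio.get_running_loop()
            # Run the blocking call in the default thread executor.
            return await loop.run_in_executor(None, model.invoke, prompt)
        finally:
            self._queue.put_nowait(model)  # always hand the model back


async def main():
    pool = Pool([EchoModel("a"), EchoModel("b")])
    return await asyncio.gather(*(pool.invoke(f"p{i}") for i in range(4)))


results = asyncio.run(main())
```

`asyncio.gather` returns results in request order, but which model serves which request depends on scheduling, so only the prompt part of each result is deterministic.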

@@ -1,188 +0,0 @@
"""
Model Utilities for Self-Consistency Experiments
This module provides model loading and management utilities with
configurable temperature support for Self-Consistency experiments.
- Temperature-configurable model loading
- Async model pool for parallel inference
- Support for Ollama and other LangChain-compatible models
"""
import asyncio
from typing import List, Any
from langchain_ollama import ChatOllama
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, BaseMessage
def load_models_with_temperature(
models: List[str],
temperature: float = 0.0
) -> "AsyncModelPoolSC":
"""
Load models with specified temperature into an async pool.
Creates a pool of models that can be used for parallel async inference.
Each model in the pool is initialized with the same temperature setting.
Args:
models: List of model specification strings.
Supported formats:
- "ollama:url:host:port/model_name" for remote Ollama
- "ollama:model_name" for local Ollama
- Standard LangChain model strings
temperature: LLM sampling temperature (default: 0.0)
"""
model_pool = AsyncModelPoolSC()
for model_str in models:
model_pool.add_model(ModelSC(model_str, temperature=temperature))
model_pool.init_models()
return model_pool
class ModelSC:
"""
Self-Consistency model wrapper with temperature support.
Wraps LangChain chat models with configurable temperature for
Self-Consistency experiments. Supports Ollama (local and remote)
and other LangChain-compatible models.
Attributes:
model: Underlying LangChain chat model
temperature: Configured sampling temperature
"""
def __init__(self, model: str, temperature: float = 0.0):
"""
Initialize model with specified temperature.
Args:
model: Model specification string
Formats:
- "ollama:url:host:port/model" - Remote Ollama instance
- "ollama:model" - Local Ollama instance
- Other strings - Passed to LangChain init_chat_model
temperature: Sampling temperature (0.0 = deterministic)
Raises:
ValueError: If model string format is invalid
"""
self.temperature = temperature
self._model_str = model
if model.startswith("ollama:"):
self._init_ollama_model(model, temperature)
else:
self._init_langchain_model(model, temperature)
def _init_ollama_model(self, model: str, temperature: float) -> None:
"""
Initialize Ollama model (local or remote).
Args:
model: Ollama model string (ollama:url:host:port/model or ollama:model)
temperature: Sampling temperature
"""
model = model.replace("ollama:", "")
if "url:" in model:
model = model.replace("url:", "")
parts = model.split("/")
base_url = parts[0]
if not base_url.startswith("http"):
base_url = "http://" + base_url
model_type = parts[1] if len(parts) > 1 else "llama2"
self.model = ChatOllama(
model=model_type,
base_url=base_url,
temperature=temperature,
num_ctx=12000,
)
else:
self.model = ChatOllama(
model=model,
temperature=temperature,
num_ctx=12000,
)
def _init_langchain_model(self, model: str, temperature: float) -> None:
self.model = init_chat_model(
model=model,
temperature=temperature,
)
def invoke(self, messages: List[BaseMessage]) -> Any:
response = self.model.invoke(messages)
return response
def __repr__(self) -> str:
"""String representation of the model."""
return f"ModelSC(model={self._model_str}, temperature={self.temperature})"
class AsyncModelSC:
def __init__(self, model: ModelSC):
self.model = model
async def invoke(self, messages: List[BaseMessage]) -> Any:
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
None,
lambda: self.model.invoke(messages),
)
return response
class AsyncModelPoolSC:
def __init__(self):
self.models: List[ModelSC] = []
self._available_models: asyncio.Queue = None
self._model_semaphore: asyncio.Semaphore = None
def add_model(self, model: ModelSC) -> None:
self.models.append(model)
def init_models(self) -> None:
if not self.models:
raise RuntimeError("No models added. Call add_model() first.")
self._available_models = asyncio.Queue()
for model in self.models:
async_model = AsyncModelSC(model)
self._available_models.put_nowait(async_model)
self._model_semaphore = asyncio.Semaphore(len(self.models))
def warmup(self) -> None:
print(f"[ModelPool] Warming up {len(self.models)} models...")
for i, model in enumerate(self.models):
try:
model.invoke([HumanMessage(content="Hello world!")])
print(f"[ModelPool] Model {i+1}/{len(self.models)} warmed up")
except Exception as e:
print(f"[ModelPool] Model {i+1} warmup failed: {e}")
print("[ModelPool] Warmup finished")
async def invoke(self, messages: List[BaseMessage]) -> Any:
if self._available_models is None:
raise RuntimeError("Model pool not initialized. Call init_models() first.")
async_model = await self._available_models.get()
try:
response = await async_model.invoke(messages)
return response
finally:
# Always return model to pool
self._available_models.put_nowait(async_model)
@property
def size(self) -> int:
return len(self.models)
def __repr__(self) -> str:
return f"AsyncModelPoolSC(size={self.size})"
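
The two Ollama specification formats accepted by `ModelSC` can be illustrated with a pure parsing function. This is a sketch, not the original logic: `parse_ollama_spec` is a hypothetical helper, the model name `llama3` is a placeholder, and the sketch uses `str.partition`, so any further slashes stay in the model name, whereas the original `split("/")` keeps only the second component. Like the original, it assumes the remote host is given without a scheme.

```python
def parse_ollama_spec(spec, default_model="llama2"):
    """Hypothetical helper: split an 'ollama:...' string into (base_url, model_name)."""
    if not spec.startswith("ollama:"):
        raise ValueError(f"Not an Ollama spec: {spec}")
    rest = spec[len("ollama:"):]
    if not rest.startswith("url:"):
        # "ollama:model" form: local instance, no explicit base URL.
        return None, rest
    rest = rest[len("url:"):]
    # "ollama:url:host:port/model" form: everything before the first slash is the host.
    base_url, _, model_name = rest.partition("/")
    if not base_url.startswith("http"):
        base_url = "http://" + base_url
    return base_url, model_name or default_model


remote = parse_ollama_spec("ollama:url:localhost:11434/llama3")
local = parse_ollama_spec("ollama:llama3")
```

Keeping the parsing separate from the `ChatOllama` construction makes the string handling testable without a running Ollama server.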

@@ -1,487 +0,0 @@
"""
RecruiterAgent: MAB-based example set selection for personalized ICL.
Components:
- self_certainty(logits): KL-div of LLM output from uniform (reward signal)
- borda_vote(results): aggregate answers across example sets
- FeatureIndex: loads index.json + embeddings.npy for O(1) lookup
- MABSelector: contextual bandit adapted from CASE/sample_case.py top_m_arm
- ArmPool: manages candidate example sets and feature vectors
- RecruiterAgent: orchestrates everything
Feature file schema:
<feature_root>/index.json + <feature_root>/embeddings.npy
index.json:
{
"version": "1.0",
"embedding_type": "sbert_metadata",
"embedding_dim": 384,
"samples": [
{"row": 0, "user_id": "001", "sample_id": 42, "label": "Hypertension", "session": "1"},
...
]
}
embeddings.npy: float32 array shape (N_samples, embedding_dim)
Switch embedding type by pointing to a different feature_root directory.
"""
from __future__ import annotations
import collections
import json
import math
import os
import random
from random import sample as random_sample
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
try:
import torch
import torch.nn.functional as F
_TORCH = True
except ImportError:
_TORCH = False
# ── Utilities ──────────────────────────────────────────────────────────────────
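def self_certainty(logits) -> float:
"""Mean per-token KL(U || p) between uniform U and the softmax p of the logits.
The value is 0 when p is uniform and grows as p becomes more peaked.
Callers in this module rank lower scores as better. Negate for MAB."""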
if logits is None or not _TORCH:
return 0.0
logits = logits.squeeze(0).float()
log_probs = F.log_softmax(logits, dim=-1)
score = (-1.0 / logits.shape[-1]) * log_probs.sum(dim=-1) - math.log(logits.shape[-1])
return score.mean().item()
def borda_vote(
results: List[Tuple[int, Optional[Dict[str, Any]], float]],
valid_classes: Optional[List[str]] = None,
) -> Optional[str]:
"""
Borda-count vote over (arm_idx, response_dict, certainty_score).
Lower certainty_score = more certain = higher Borda rank.
"""
valid = [
(r[1]["ANSWER"], r[2]) for r in results
if r[1] is not None and "ANSWER" in r[1]
and (valid_classes is None or r[1]["ANSWER"] in valid_classes)
]
if not valid:
return None
sorted_v = sorted(valid, key=lambda x: x[1])
M = len(sorted_v)
borda: Dict[str, float] = {}
for rank, (ans, _) in enumerate(sorted_v):
borda[ans] = borda.get(ans, 0.0) + (M - rank)
top = max(borda.values())
cands = [a for a, s in borda.items() if s == top]
if len(cands) == 1:
return cands[0]
means = {a: float(np.mean([sc for x, sc in valid if x == a])) for a in cands}
return min(means, key=means.get)
# ── FeatureIndex ───────────────────────────────────────────────────────────────
class FeatureIndex:
"""Loads embeddings from disk; O(1) lookup by (user_id, sample_id)."""
def __init__(self, feature_root: str) -> None:
idx_p = os.path.join(feature_root, "index.json")
emb_p = os.path.join(feature_root, "embeddings.npy")
if not os.path.isfile(idx_p):
raise FileNotFoundError(f"index.json not found: {idx_p}")
if not os.path.isfile(emb_p):
raise FileNotFoundError(f"embeddings.npy not found: {emb_p}")
with open(idx_p, "r", encoding="utf-8") as f:
meta = json.load(f)
self.embedding_type: str = meta.get("embedding_type", "unknown")
self.embedding_dim: int = meta["embedding_dim"]
self.embeddings: np.ndarray = np.load(emb_p).astype(np.float32)
self.samples: List[Dict[str, Any]] = meta["samples"]
self._lookup: Dict[Tuple[str, int], int] = {
(str(e["user_id"]), int(e["sample_id"])): int(e["row"])
for e in self.samples
}
print(f"[FeatureIndex] {len(self.embeddings)} embeddings "
f"(dim={self.embedding_dim}, type={self.embedding_type})")
def get(self, user_id: str, sample_id: int) -> Optional[np.ndarray]:
row = self._lookup.get((str(user_id), int(sample_id)))
return self.embeddings[row] if row is not None else None
def get_user_embedding(self, user_id: str) -> Optional[np.ndarray]:
"""Mean embedding across all samples for a user."""
rows = [self.embeddings[e["row"]] for e in self.samples
if str(e["user_id"]) == str(user_id)]
return np.mean(rows, axis=0).astype(np.float32) if rows else None
# ── MABSelector ────────────────────────────────────────────────────────────────
class MABSelector:
"""
Contextual bandit for top-m arm identification.
Adapted from CASE / sample_case.py top_m_arm.
Key differences from original:
- Rewards injected via observe(arm, reward) — no LLM calls, no file IO
- update() does one rank-1 theta update per call
- No initialization() warm-up loop required
X: np.ndarray shape (embedding_dim, num_arms)
m: number of top arms to track (= queue_size)
"""
def __init__(
self, X: np.ndarray, m: int,
n_das: int = 5, delta: float = 0.05,
epsilon: float = 0.11, sigma: float = 0.5,
) -> None:
assert X.ndim == 2, "X must be (embedding_dim, num_arms)"
self.X = X.astype(np.float64)
self.N, self.K = X.shape
self.m = min(m, self.K)
self.n_das = n_das
self.delta = delta
self.epsilon = epsilon
self.sigma = sigma
self.arms = list(range(self.K))
self._it = -1
self._Bdi: Dict[Tuple[int, int], float] = {}
self._reset()
def _reset(self) -> None:
self.t = 0
self.rewards: List[float] = []
self.pulled_arms: List[int] = []
self.means = np.zeros(self.K)
self.na = np.zeros(self.K)
self.B_inv = np.eye(self.N, dtype=np.float64)
self.b = np.zeros(self.N, dtype=np.float64)
self.theta = np.random.normal(0, 1, size=(1, self.N))
self.J: List[int] = []
self.notJ: List[int] = list(range(self.K))
self.N_t: List[int] = []
self.best_arm: Optional[int] = None
self.worst_arm: Optional[int] = None
self.challenger: Optional[int] = None
self.c_t: Optional[int] = None
self.patience = 0
self.previous_J: List[int] = []
# ── math helpers ────────────────────────────────────────────────────────────
def _mn(self, x: np.ndarray, A: np.ndarray) -> float:
x = np.asarray(x).flatten()
return float(np.sqrt(max(float(x @ A @ x), 0.0)))
def _iu(self, A: np.ndarray, x: np.ndarray) -> np.ndarray:
"""Rank-1 inverse update: A - A x xT A / (1 + xT A x)."""
x = np.asarray(x).flatten()
d = 1.0 + self._mn(x, A) ** 2
return A - np.outer(A @ x, x @ A) / d
def _beta(self) -> float:
return math.log((math.log(max(self.t, 2)) + 1) / max(self.delta, 1e-12))
def _gap(self, i: int, j: Optional[int] = None) -> float:
xi = self.X[:, i] - (self.X[:, j] if j is not None else 0.0)
return float(self.theta.flatten() @ xi)
def _var(self, i: int, j: Optional[int] = None) -> float:
S = (self.sigma ** 2) * self.B_inv
xi = self.X[:, i] - (self.X[:, j] if j is not None else 0.0)
return self._mn(xi, S) * math.sqrt(2 * self._beta())
def _Bij(self, i: int, j: int) -> float:
t = max(self.t, 1)
if t != self._it:
self._Bdi = {}
self._it = t
k = (i, j)
if k not in self._Bdi:
self._Bdi[k] = self._gap(i, j) + self._var(i, j)
return self._Bdi[k]
def _randf(self, x: List[float], f) -> int:
arr = np.array(x, dtype=float)
val = f(arr)
c = np.argwhere(arr == val).flatten().tolist()
return random_sample(c, 1)[0]
def _mmax(self, x: List[float], m: int) -> List[int]:
x = list(x)
ids = []
for _ in range(min(m, len(x))):
idx = self._randf(x, np.max)
ids.append(idx)
x[idx] = -float("inf")
return ids
# ── sample step ─────────────────────────────────────────────────────────────
def _sel(self, nt: Optional[int]) -> Optional[int]:
if nt is None or self.c_t is None:
return None
if self.means[self.c_t] > self.means[nt]:
s = self.c_t
if s in self.N_t:
self.N_t.remove(s)
return s
return None
def _Jt(self) -> List[int]:
if self.t == 0:
return self._mmax(self.means.tolist(), self.m)
nt = self.worst_arm
sel = self._sel(nt)
self.previous_J = list(self.J)
if sel is not None:
self.J = [a for a in self.J if a != nt] + [sel]
if nt is not None and nt not in self.N_t:
self.N_t.append(nt)
order = np.argsort(self.means[self.J])[::-1].tolist()
self.J = [self.J[i] for i in order]
self.patience = 0 if self.J != self.previous_J else self.patience + 1
return self.J
def sample(self) -> List[int]:
"""Select next arm(s) to evaluate. Returns list of arm indices."""
self.J = self._Jt()
self.notJ = [a for a in self.arms if a not in self.J]
if self.t == 0 or not self.N_t:
n = min(self.n_das, len(self.notJ))
self.N_t = list(np.random.choice(self.notJ, n, replace=False)) if self.notJ else []
else:
Qt = list(np.random.choice(self.notJ, min(self.n_das, len(self.notJ)), replace=False)) if self.notJ else []
self.N_t = list({*self.N_t, *Qt} - set(self.J))
if self.N_t:
top_n = min(self.n_das, len(self.N_t))
tv = np.sort(self.means[self.N_t])[::-1][:top_n]
picked: List[int] = []
for mv in tv:
c = [a for a in self.N_t if self.means[a] == mv and a not in picked]
if c:
picked.append(random.choice(c))
self.N_t = picked
if self.N_t and self.J:
jm = self.means[self.J]
ntc = [a for a in self.J if self.means[a] == jm.min()]
self.worst_arm = random.choice(ntc)
bti = [self._Bij(a, self.worst_arm) for a in self.J]
self.best_arm = self.J[self._randf(bti, np.max)]
chi = [self._Bij(a, self.best_arm) for a in self.N_t]
self.challenger = self.N_t[self._randf(chi, np.max)]
nm = self.means[self.N_t]
c = [a for a in self.N_t if self.means[a] == nm.max()]
self.c_t = random.choice(c)
else:
self.worst_arm = self.J[0] if self.J else None
self.best_arm = self.challenger = self.c_t = None
# Greedy arm pull
if self.best_arm is not None and self.challenger is not None:
d = self.X[:, self.best_arm] - self.X[:, self.challenger]
u = [self._mn(d, self._iu(self.B_inv, self.X[:, i])) for i in self.arms]
return [self.arms[self._randf(u, np.min)]]
return [random.choice(self.arms)]
def observe(self, arm: int, reward: float) -> None:
"""Record reward for arm."""
self.rewards.append(reward)
self.pulled_arms.append(arm)
self.na[arm] += 1
self.t += 1
def update(self) -> None:
"""Rank-1 theta/B_inv update from the last observation."""
if not self.pulled_arms:
return
x = self.X[:, self.pulled_arms[-1]].flatten()
self.B_inv = self._iu(self.B_inv, x)
self.b += self.rewards[-1] * x
self.theta = (self.B_inv @ self.b).reshape(1, self.N)
self.means = (self.theta @ self.X).flatten()
def recommend(self, n: int) -> List[int]:
"""Return n arm indices with highest estimated mean rewards."""
return self._mmax(self.means.tolist(), min(n, self.K))
def stopping_rule(self) -> bool:
if None in (self.best_arm, self.worst_arm, self.challenger):
return False
return round(self._Bij(self.challenger, self.best_arm), 2) <= self.epsilon
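# Illustrative sketch (not part of the original module): the rank-1 inverse
# update used by _iu() and the theta update in update(), written with plain lists
# instead of NumPy so it runs without dependencies. The names
# `rank1_inverse_update` and `ridge_step` are hypothetical; the arithmetic follows
# A - (A x)(x^T A) / (1 + x^T A x), then b += r * x and theta = B_inv b.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def rank1_inverse_update(A_inv: Matrix, x: Vector) -> Matrix:
    """Return the inverse of (A + x x^T), given the inverse of A."""
    n = len(x)
    Ax = [sum(A_inv[i][k] * x[k] for k in range(n)) for i in range(n)]  # A x
    xA = [sum(x[k] * A_inv[k][j] for k in range(n)) for j in range(n)]  # x^T A
    denom = 1.0 + sum(x[i] * Ax[i] for i in range(n))
    return [[A_inv[i][j] - Ax[i] * xA[j] / denom for j in range(n)]
            for i in range(n)]


def ridge_step(B_inv: Matrix, b: Vector, x: Vector,
               reward: float) -> Tuple[Matrix, Vector, Vector]:
    """One observation: update B_inv and b, then recompute theta = B_inv b."""
    B_inv = rank1_inverse_update(B_inv, x)
    b = [bi + reward * xi for bi, xi in zip(b, x)]
    n = len(b)
    theta = [sum(B_inv[i][k] * b[k] for k in range(n)) for i in range(n)]
    return B_inv, b, theta


# Start from B_inv = I and b = 0, as _reset() does, then feed two observations.
B_inv: Matrix = [[1.0, 0.0], [0.0, 1.0]]
b: Vector = [0.0, 0.0]
B_inv, b, theta = ridge_step(B_inv, b, [1.0, 0.0], 2.0)
B_inv, b, theta = ridge_step(B_inv, b, [0.0, 1.0], -1.0)
```

# After both steps B equals I + x1 x1^T + x2 x2^T = diag(2, 2), so B_inv is
# diag(0.5, 0.5) and no explicit matrix inversion was needed.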
# ── ArmPool ────────────────────────────────────────────────────────────────────
# TODO: Add a way to sample arms from the larger pool with replacement, instead of fixing the pool to num_arms arms up front.
class ArmPool:
"""
Manages candidate example sets (arms).
Each arm = list of sample dicts (N-way x k_shot).
Arm feature = mean embedding of its constituent samples.
"""
def __init__(
self,
source_samples: List[Dict[str, Any]],
classes: List[str],
k_shot: int,
feature_index: Optional[FeatureIndex],
num_arms: int,
rng: np.random.Generator,
) -> None:
self.classes = classes
self.k_shot = k_shot
self.feature_index = feature_index
self._rng = rng
self._by: Dict[str, List[Dict[str, Any]]] = {c: [] for c in classes}
for s in source_samples:
lbl = s.get("label")
if lbl in self._by:
self._by[lbl].append(s)
self.arms: List[List[Dict[str, Any]]] = []
self.arm_feats: List[Optional[np.ndarray]] = []
for _ in range(num_arms):
arm, feat = self._make()
self.arms.append(arm)
self.arm_feats.append(feat)
print(f"[ArmPool] {len(self.arms)} arms ({len(classes)}-way {k_shot}-shot)")
def _make(self) -> Tuple[List[Dict[str, Any]], Optional[np.ndarray]]:
s: List[Dict[str, Any]] = []
for c in self.classes:
pool = self._by.get(c, [])
if pool:
n = min(self.k_shot, len(pool))
idxs = self._rng.choice(len(pool), size=n, replace=False)
s.extend([pool[i] for i in idxs])
return s, self._feat(s)
def _feat(self, samples: List[Dict[str, Any]]) -> Optional[np.ndarray]:
if self.feature_index is None:
return None
vecs = []
for s in samples:
v = self.feature_index.get(
str(s.get("user_id", "")),
int(s.get("sample_id", s.get("idx", -1))),
)
if v is not None:
vecs.append(v)
return np.mean(vecs, axis=0).astype(np.float32) if vecs else None
def get_arm(self, i: int) -> List[Dict[str, Any]]:
return self.arms[i]
def build_X(self) -> np.ndarray:
"""Feature matrix X of shape (embedding_dim, num_arms). Falls back to identity."""
feats = [f for f in self.arm_feats if f is not None]
if not feats:
return np.eye(len(self.arms), dtype=np.float32)
X = np.zeros((feats[0].shape[0], len(self.arms)), dtype=np.float32)
for i, f in enumerate(self.arm_feats):
if f is not None:
X[:, i] = f
return X
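# Illustrative sketch (not part of the original module): how one arm, i.e. one
# N-way x k_shot example set, can be drawn. This version uses the standard
# library's random.Random rather than the NumPy Generator used by ArmPool, and the
# name `make_arm` and the sample dicts are hypothetical.

```python
import random
from typing import Any, Dict, List


def make_arm(samples: List[Dict[str, Any]], classes: List[str],
             k_shot: int, rng: random.Random) -> List[Dict[str, Any]]:
    """Draw up to k_shot samples per class, without replacement."""
    by_label: Dict[str, List[Dict[str, Any]]] = {c: [] for c in classes}
    for s in samples:
        if s.get("label") in by_label:
            by_label[s["label"]].append(s)
    arm: List[Dict[str, Any]] = []
    for c in classes:
        pool = by_label[c]
        # A class with fewer than k_shot samples contributes what it has.
        arm.extend(rng.sample(pool, min(k_shot, len(pool))))
    return arm


source = ([{"label": "W", "sample_id": i} for i in range(3)]
          + [{"label": "REM", "sample_id": 3}]
          + [{"label": "other", "sample_id": 4}])
arm = make_arm(source, ["W", "REM"], k_shot=2, rng=random.Random(0))
arm_labels = [s["label"] for s in arm]
```

# Samples whose label is not in `classes` are ignored, and the arm keeps the
# class order, so the result here is two "W" samples followed by one "REM".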
# ── RecruiterAgent ────────────────────────────────────────────────────────────
class RecruiterAgent:
"""
MAB-driven recruiter for personalized ICL example set selection.
Usage:
recruiter = RecruiterAgent(source_samples, classes, ...)
initial = recruiter.recruit(queue_size) # cold start (random)
for test_sample in target_stream:
results = [(arm_idx, response_dict, score), ...]
recruiter.update(results) # MAB update lives here
new_sets = recruiter.recruit(L) # MAB-guided new sets
queue.update_with_recruiter(results, new_sets)
MAB update logic inside update():
1. reward = -score (negate: lower certainty = better = higher reward)
2. mab.observe(arm, reward)
3. mab.update() -> rank-1 theta/B_inv update
4. mab.sample() -> select next arms to explore (stored for next recruit())
"""
def __init__(
self,
source_samples: List[Dict[str, Any]],
classes: List[str],
feature_index: Optional[FeatureIndex] = None,
target_user_id: Optional[str] = None,
k_shot: int = 1,
queue_size: int = 5,
num_arms: int = 50,
n_das: int = 5,
delta: float = 0.05,
epsilon: float = 0.11,
sigma: float = 0.5,
seed: int = 42,
) -> None:
self.classes = classes
self.k_shot = k_shot
self.queue_size = queue_size
self.target_user_id = target_user_id
rng = np.random.default_rng(seed)
random.seed(seed)
np.random.seed(seed)
self.pool = ArmPool(source_samples, classes, k_shot, feature_index, num_arms, rng)
self.mab = MABSelector(
self.pool.build_X(), m=queue_size,
n_das=n_das, delta=delta, epsilon=epsilon, sigma=sigma,
)
self._next: Optional[List[int]] = None
self._initialized = False
print(f"[RecruiterAgent] {len(self.pool.arms)} arms, "
f"queue={queue_size}, classes={list(set(classes))}")
def recruit(self, n: int) -> List[Tuple[int, List[Dict[str, Any]]]]:
"""
Return n (arm_idx, example_set) pairs for insertion into the queue.
Cold start: random. Subsequent calls: MAB-recommended.
"""
if not self._initialized:
indices = random.sample(range(len(self.pool.arms)), min(n, len(self.pool.arms)))
self._initialized = True
elif self._next is not None:
indices = self._next[:n]
if len(indices) < n:
rest = [i for i in range(len(self.pool.arms)) if i not in indices]
indices += random.sample(rest, min(n - len(indices), len(rest)))
else:
indices = self.mab.recommend(n)
return [(i, self.pool.get_arm(i)) for i in indices]
def update(self, results: List[Tuple[int, Optional[Dict[str, Any]], float]]) -> None:
"""
Update MAB from (arm_idx, response_dict, self_certainty_score).
Steps:
1. reward = -score (lower certainty = better = higher reward)
2. mab.observe(arm, reward)
3. mab.update() -> rank-1 theta/B_inv update
4. mab.sample() -> selects next arms to explore
"""
if not results:
return
for arm_idx, _, score in results:
if 0 <= arm_idx < len(self.pool.arms):
self.mab.observe(arm_idx, -float(score))
self.mab.update()
try:
next_cands = self.mab.sample()
top = self.mab.recommend(self.queue_size)
self._next = list(dict.fromkeys(next_cands + top))
except Exception as e:
print(f"[RecruiterAgent] MAB sample error ({e}), using recommend()")
self._next = self.mab.recommend(self.queue_size)
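# Illustrative sketch (not part of the original module): the bookkeeping around
# the MAB in update() and recruit(), isolated from the bandit itself. It shows the
# reward negation, the order-preserving merge of exploration and exploitation
# candidates, and the random fill-up when fewer than n candidates exist. The
# function names are hypothetical.

```python
import random
from typing import List


def to_reward(score: float) -> float:
    """Negate the score so that a lower score yields a higher reward."""
    return -float(score)


def merge_candidates(explore: List[int], exploit: List[int]) -> List[int]:
    """Concatenate and drop duplicates while keeping first-seen order."""
    return list(dict.fromkeys(explore + exploit))


def take_n(candidates: List[int], n: int, num_arms: int,
           rng: random.Random) -> List[int]:
    """Take the first n candidates; pad with random unused arms if short."""
    picked = candidates[:n]
    if len(picked) < n:
        rest = [i for i in range(num_arms) if i not in picked]
        picked += rng.sample(rest, min(n - len(picked), len(rest)))
    return picked


merged = merge_candidates([7], [3, 7, 1])
first_two = take_n(merged, 2, num_arms=10, rng=random.Random(0))
padded = take_n([7], 3, num_arms=10, rng=random.Random(0))
```

# Because the explored arm comes first in the merge, it is always recruited
# before the current top arms when n is small.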


@@ -1,193 +0,0 @@
"""
Self-Consistency Agent for Sleep Stage Classification
This module implements an agent that uses Self-Consistency methodology:
- Sample N times with the same prompt (configurable temperature)
- Output REASON, CONFIDENCE, ANSWER in JSON format
- Aggregate final answer via Majority Voting
The agent extends the base Agent class and adds Self-Consistency
specific functionality including multi-sampling and voting.
"""
import os
import sys
import json
import asyncio
from typing import List, Dict, Any, Optional
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
from sc.core.agent import Agent
class SCAgent(Agent):
"""
Self-Consistency based sleep stage classification agent.
This agent samples N times for the same input and uses majority voting
to determine the final answer. It supports confidence-based tie-breaking
and provides detailed metrics about consistency.
Attributes:
task_info: Description of the classification task
classes_info: List of valid class labels
sensor_info: Information about sensor data
sample: Current sample being classified
examples: ICL examples for context
index: Index of the example set assigned to this agent (returned as `_example_idx`)
"""
def __init__(
self,
name: str,
index: int,
model_pool,
task_info: str,
classes_info: List[str],
sensor_info: str,
sample: Dict[str, Any],
examples: List[Dict[str, Any]],
log_path: str,
):
"""
Initialize Self-Consistency Agent.
Args:
name: Agent identifier name
index: Index of the agent in the dataset
model_pool: Async model pool for LLM inference
task_info: Description of the classification task
classes_info: List of valid class labels
sensor_info: Information about sensor data format
sample: Sample to be classified
examples: ICL examples for few-shot learning
log_path: Directory path for saving logs
"""
super().__init__(
name=name + f"_{index}",
model_pool=model_pool,
log_path=log_path,
)
self.task_info = task_info
self.classes_info = classes_info
self.sensor_info = sensor_info
self.sample = sample
self.examples = examples
self.index = index # Store index for later retrieval
self._init_system_message()
def _init_system_message(self) -> None:
"""Initialize the system message that defines the agent's role."""
content = (
f"You are a {self.name} agent that interprets sensor data to solve a task.\n"
f"You have the following information about the task:\n"
f"{self.task_info}\n\n"
f"You have the following information about the sensor data:\n"
f"{self.sensor_info}\n\n"
"Your goal is to analyze the features and "
"provide a reasoned answer with your confidence level."
)
self.set_system_message(content)
def _format_feature(self, value: Any) -> str:
"""
Format a feature value for display.
Uses scientific notation for very large or very small numbers,
and two decimal places for regular floats.
Args:
value: Feature value to format
Returns:
Formatted string representation
"""
if isinstance(value, float):
if abs(value) >= 1e4 or (abs(value) < 1e-2 and value != 0):
return f"{value:.2e}"
return f"{value:.2f}"
return str(value)
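# Illustrative sketch (not part of the original module): the formatting rule of
# _format_feature() as a free function, applied to a few made-up values. The name
# `format_feature` and the inputs are hypothetical.

```python
from typing import Any, List


def format_feature(value: Any) -> str:
    """Scientific notation for very large or very small floats, else 2 decimals."""
    if isinstance(value, float):
        if abs(value) >= 1e4 or (abs(value) < 1e-2 and value != 0):
            return f"{value:.2e}"
        return f"{value:.2f}"
    return str(value)  # ints, strings, etc. are passed through unchanged


formatted: List[str] = [
    format_feature(v) for v in (12345.678, 0.00123, 3.14159, 0.0, 7)
]
# formatted == ["1.23e+04", "1.23e-03", "3.14", "0.00", "7"]
```

# Exact zero is excluded from the small-number branch, so it prints as "0.00".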
def _gen_example_info(self) -> str:
"""
Generate string representation of ICL examples.
Returns:
Formatted string with example features and labels,
or empty string if no examples provided
"""
if not self.examples:
return ""
example_info = (
"**Examples**\n"
"Sensor values might not always align with your inherent "
"knowledge due to differences in data collection or processing. "
"So, we included a few labeled examples to help your interpretation:\n"
)
for example in self.examples:
example_info += f"*Example of {example['label']}*:\n"
for k, v in example["features"].items():
example_info += f" - {k}: {self._format_feature(v)}\n"
example_info += "\n"
return example_info.strip()
def _gen_feature_info(self) -> str:
"""
Generate string representation of current sample features.
Combines ICL example information with current sample features
to provide full context for classification.
Returns:
Formatted string with example and sample features
"""
feature_info = f"{self.name} features:\n"
# Add ICL example information if available
example_info = self._gen_example_info()
if example_info:
feature_info += f"{example_info}\n\n"
# Add current sample features
feature_info += "**Current sample features**:\n"
for k, v in self.sample["features"].items():
feature_info += f" - {k}: {self._format_feature(v)}\n"
return feature_info.strip()
async def interpret(self) -> Optional[Dict[str, Any]]:
feature_info = self._gen_feature_info()
prompt = (
f"You have received sensor features from {self.name} modality:\n"
f"{feature_info}\n\n"
f"Please provide your answer for the task among {self.classes_info} "
"and the reasoning for your answer.\n"
"Also, please provide your confidence level for the answer as a float between 0.0 and 1.0.\n"
"Note that the sensor features might be wrong due to the data collection or processing.\n"
"You can evaluate the quality of the features by checking the examples you have.\n\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<Your detailed reasoning for the classification>",\n'
' "CONFIDENCE": <Your confidence as a float between 0.0 and 1.0>,\n'
f' "ANSWER": "<Your answer among {self.classes_info}>"\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(prompt)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["REASON", "CONFIDENCE", "ANSWER"]
)
# Add example index to response for queue update
if parsed_response:
parsed_response["_example_idx"] = self.index
return parsed_response
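# Illustrative sketch (not part of the original module): one way to check a reply
# against the strict JSON format requested in the prompt above. The base-class
# helpers safe_parse_json() and validate_response() are not shown in this file, so
# this is an assumption about the kind of checks involved, not their actual
# behavior. The name `parse_sc_response` is hypothetical.

```python
import json
from typing import Any, Dict, List, Optional

REQUIRED_KEYS = ("REASON", "CONFIDENCE", "ANSWER")


def parse_sc_response(raw: str, classes: List[str]) -> Optional[Dict[str, Any]]:
    """Return the parsed reply, or None if it violates the requested format."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None
    if not isinstance(data, dict) or any(k not in data for k in REQUIRED_KEYS):
        return None
    try:
        confidence = float(data["CONFIDENCE"])
    except (TypeError, ValueError):
        return None
    if not 0.0 <= confidence <= 1.0 or data["ANSWER"] not in classes:
        return None
    data["CONFIDENCE"] = confidence
    return data


stages = ["W", "N1", "N2", "N3", "REM"]
ok = parse_sc_response(
    '{"REASON": "low delta power", "CONFIDENCE": 0.8, "ANSWER": "N2"}', stages
)
bad_label = parse_sc_response(
    '{"REASON": "x", "CONFIDENCE": 0.8, "ANSWER": "N9"}', stages
)
not_json = parse_sc_response("Sure! The answer is N2.", stages)
```

# Returning None for every malformed reply matches the `if parsed_response:`
# guard above, which already treats the result as optional.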


@@ -1,426 +0,0 @@
def _enabled(config=None):
if config is None:
return True
return config.get("debug", True)
def log(message, config=None):
if not _enabled(config):
return
print(message)
def warn_no_examples():
print("[WARN] No examples found for dataloader. Skipping task.")
def log_queue_state_before(processed_count, user_info, ex_queue, config=None):
if not _enabled(config):
return
print(f"\n{'#'*60}")
print(
f"[Sample {processed_count}] {user_info} - Queue State BEFORE Processing "
f"(Instance ID: {ex_queue.get_instance_id()}):"
)
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
print(
f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})"
)
print(f"{'#'*60}\n")
def warn_no_agents(processed_count):
print(f"[WARN] No agents added for sample {processed_count}. Skipping.")
def warn_interpretation_failed(processed_count):
print(f"[WARN] Interpretation failed for sample {processed_count}. Skipping.")
def log_tracking(
processed_count,
user_info,
answer,
ground_truth,
is_correct,
avg_confidence,
consistency,
cumulative_accuracy,
cumulative_correct,
window_accuracy,
recent_results,
avg_confidence_so_far,
config=None,
):
if not _enabled(config):
return
print(f"\n{'='*60}")
print(f"[TRACKING] Sample {processed_count} | {user_info}")
print(
f" Answer: {answer} | GT: {ground_truth} | "
f"{'✓ CORRECT' if is_correct else '✗ WRONG'}"
)
print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
print(" ─────────────────────────────────────────────────────")
print(
f" Cumulative Accuracy: {cumulative_accuracy:.4f} "
f"({cumulative_correct}/{processed_count + 1})"
)
print(
f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}"
)
print(f" Avg Confidence (so far): {avg_confidence_so_far:.4f}")
print(f"{'='*60}\n")
def log_queue_state_after(processed_count, ex_queue, config=None):
if not _enabled(config):
return
print(f"\n[Sample {processed_count}] Queue State AFTER Update:")
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
print(
f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})"
)
print()
def warn_no_responses(processed_count):
print(f"[WARN] No responses returned for sample {processed_count}, falling back to basic update.")
def log_final_queue_stats(user_info, survival_summary, config=None):
if not _enabled(config):
return
print(f"\n{'#'*60}")
print(f"[FINAL] Queue Survival Statistics for {user_info}")
print(f" Total Evicted Cases: {survival_summary['total_evicted']}")
print(f" Avg Survival: {survival_summary['avg_survival']:.2f} samples")
print(f" Max Survival: {survival_summary['max_survival']} samples")
print(f" Min Survival: {survival_summary['min_survival']} samples")
print(f" Avg Usage Count: {survival_summary['avg_usage']:.2f}")
print(f" Max Usage Count: {survival_summary['max_usage']}")
print(f"{'#'*60}\n")
def warn_no_user_dirs(data_path):
print(f"[WARN] No user directories found in {data_path}")
def log_found_users(users, config=None):
if not _enabled(config):
return
print(
f"[INFO] Found {len(users)} users: {users[:5]}"
f"{'...' if len(users) > 5 else ''}"
)
def warn_skip_user_no_test_data(user):
print(f"[WARN] Skipping user {user} - no test data available")
def warn_skip_user_no_example_data(user):
print(f"[WARN] Skipping user {user} - no example data available")
def log_main_loading_config(config_path, config=None):
if not _enabled(config):
return
print(f"[MAIN] Loading config: {config_path}")
def log_main_config(config, config_enabled=True):
if not config_enabled:
return
print("=" * 60)
print("SELF-CONSISTENCY EXPERIMENT CONFIGURATION")
print("=" * 60)
print(f" Data path: {config.get('data_path', 'N/A')}")
print(f" Log path: {config.get('log_path', 'N/A')}")
print(f" Num ICL examples: {config.get('num_examples', 1)}")
print(f" Num seeds: {config.get('num_seeds', 1)}")
print(f" Num SC samples: {config.get('num_sc_samples', 5)}")
print(f" Temperature: {config.get('temperature', 0.0)}")
print(f" Sample rate: 1/{config.get('sample_rate', 10)}")
print(f" Num models: {len(config.get('models', []))}")
print("=" * 60)
def log_main_start(config=None):
if not _enabled(config):
return
print("[MAIN] Starting experiments...")
def error_task_failed(result):
print(f"[ERROR] Task failed with exception: {result}")
def log_total_results(count, config=None):
if not _enabled(config):
return
print(f"[MAIN] Total results collected: {count}")
def log_experiment_results(stats, config=None):
if not _enabled(config):
return
print("\n" + "=" * 60)
print("EXPERIMENT RESULTS")
print("=" * 60)
print(f" Total samples: {stats.get('total_samples', 0)}")
print(f" Accuracy: {stats.get('accuracy', 0):.4f}")
print(f" Avg Confidence: {stats.get('avg_confidence', 0):.4f}")
print(f" Avg Consistency: {stats.get('avg_consistency', 0):.4f}")
print(
" High Consistency (>=0.8) Accuracy: "
f"{stats.get('high_consistency_accuracy', 0):.4f}"
)
print(f" High Consistency Samples: {stats.get('high_consistency_samples', 0)}")
print("\n Class-wise Accuracy:")
for cls, acc in stats.get("class_accuracy", {}).items():
print(f" {cls}: {acc:.4f}")
def log_temporal_analysis(temporal, config=None):
if not _enabled(config):
return
if not temporal:
return
print("\n" + "-" * 60)
print(" TEMPORAL ANALYSIS (Caching Effect)")
print("-" * 60)
print(f" First Half Accuracy: {temporal.get('first_half_accuracy', 0):.4f}")
print(f" Second Half Accuracy: {temporal.get('second_half_accuracy', 0):.4f}")
improvement = temporal.get("accuracy_improvement", 0)
improvement_sign = "+" if improvement >= 0 else ""
print(f" Improvement: {improvement_sign}{improvement:.4f}")
quartiles = temporal.get("quartile_accuracies", [])
if quartiles:
print(
f" Quartile Accuracies: Q1={quartiles[0]:.4f}"
+ (f", Q2={quartiles[1]:.4f}" if len(quartiles) > 1 else "")
+ (f", Q3={quartiles[2]:.4f}" if len(quartiles) > 2 else "")
+ (f", Q4={quartiles[3]:.4f}" if len(quartiles) > 3 else "")
)
print(f"\n First Half Confidence: {temporal.get('first_half_confidence', 0):.4f}")
print(f" Second Half Confidence: {temporal.get('second_half_confidence', 0):.4f}")
conf_improvement = temporal.get("confidence_improvement", 0)
conf_sign = "+" if conf_improvement >= 0 else ""
print(f" Confidence Change: {conf_sign}{conf_improvement:.4f}")
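# Illustrative sketch (not part of the original module): one plausible way to
# build the `temporal` dict that log_temporal_analysis() prints. The keys match
# those read above; the computation itself (splitting the per-sample correctness
# list into halves and quartiles) is an assumption, and `temporal_summary` is a
# hypothetical name.

```python
from typing import Any, Dict, List


def _accuracy(chunk: List[bool]) -> float:
    return sum(chunk) / len(chunk) if chunk else 0.0


def temporal_summary(correct: List[bool]) -> Dict[str, Any]:
    """Compare accuracy in the first and second half of a sample stream."""
    if not correct:
        return {}
    n = len(correct)
    first = _accuracy(correct[: n // 2])
    second = _accuracy(correct[n // 2:])
    bounds = [n * i // 4 for i in range(5)]  # quartile boundaries
    return {
        "first_half_accuracy": first,
        "second_half_accuracy": second,
        "accuracy_improvement": second - first,
        "quartile_accuracies": [
            _accuracy(correct[bounds[i]: bounds[i + 1]]) for i in range(4)
        ],
    }


stream = [False, False, True, False, True, True, True, True]
temporal = temporal_summary(stream)
```

# A positive accuracy_improvement is what the "Caching Effect" header refers to:
# later samples benefit from the example queue built up on earlier ones.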
def log_queue_stats(queue_stats, config=None):
if not _enabled(config):
return
if not queue_stats:
return
print("\n" + "-" * 60)
print(" QUEUE SURVIVAL STATISTICS")
print("-" * 60)
print(f" Total Evicted Cases: {queue_stats.get('total_evicted', 0)}")
print(f" Avg Survival: {queue_stats.get('avg_survival', 0):.2f} samples")
print(f" Max Survival: {queue_stats.get('max_survival', 0)} samples")
print(f" Min Survival: {queue_stats.get('min_survival', 0)} samples")
print(f" Avg Usage Count: {queue_stats.get('avg_usage', 0):.2f}")
print(f" Max Usage Count: {queue_stats.get('max_usage', 0)}")
print("=" * 60)
def log_save_statistics(stats_path, config=None):
if not _enabled(config):
return
print(f"[SAVE] Statistics saved to: {stats_path}")
def log_save_results(results_path, config=None):
if not _enabled(config):
return
print(f"[SAVE] Results saved to: {results_path}")
def log_save_config(config_path, config=None):
if not _enabled(config):
return
print(f"[SAVE] Config saved to: {config_path}")
def log_results_saved(log_path, config=None):
if not _enabled(config):
return
print(f"[MAIN] Results saved to: {log_path}")
def log_policy_experiment_header(policy_label, user_id, shuffle_seed, total_samples, queue_size, config=None, policy_note=None):
if not _enabled(config):
return
print(f"\n{'='*80}")
print(f"{policy_label} QUEUE POLICY EXPERIMENT")
print(f"User: {user_id} | Shuffle Seed: {shuffle_seed}")
print(f"Total samples: {total_samples} | Queue size: {queue_size}")
if policy_note:
print(policy_note)
print(f"{'='*80}\n")
def log_policy_queue_state(policy_label, processed_count, user_id, ex_queue, config=None):
if not _enabled(config):
return
print(f"\n{'#'*60}")
print(f"[Sample {processed_count}] User {user_id} | {policy_label} Policy")
print("Queue State BEFORE Processing:")
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
print(f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})")
print(f"{'#'*60}\n")
def log_policy_result(policy_label, processed_count, answer, ground_truth, is_correct, avg_confidence, consistency, cumulative_accuracy, cumulative_correct, window_accuracy, recent_results, config=None):
if not _enabled(config):
return
print(f"\n{'='*60}")
print(f"[RESULT] Sample {processed_count} | {policy_label} Policy")
print(f" Answer: {answer} | GT: {ground_truth} | {'✓ CORRECT' if is_correct else '✗ WRONG'}")
print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
print(f" Cumulative Accuracy: {cumulative_accuracy:.4f} ({cumulative_correct}/{processed_count + 1})")
print(f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}")
print(f"{'='*60}\n")
def log_confidence_map(responses, confidence_map, config=None):
if not _enabled(config):
return
print("\n[CONFIDENCE MAP] Per-agent confidence scores:")
for idx, conf in sorted(confidence_map.items()):
ans = responses[idx].get("ANSWER", "?")
print(f" Queue[{idx}]: answer={ans}, confidence={conf:.4f}")
def log_consistency_map(responses, consistency_map, config=None):
if not _enabled(config):
return
print("\n[CONSISTENCY MAP] Per-agent consistency scores:")
for idx, cons in sorted(consistency_map.items()):
ans = responses[idx].get("ANSWER", "?")
print(f" Queue[{idx}]: answer={ans}, consistency={cons:.4f}")
def log_policy_queue_state_after(policy_label, processed_count, ex_queue, config=None):
if not _enabled(config):
return
print(f"\n[Sample {processed_count}] Queue State AFTER {policy_label} Update:")
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
print(f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})")
def log_final_policy_survival(policy_label, user_id, shuffle_seed, survival_summary, config=None):
if not _enabled(config):
return
print(f"\n{'#'*60}")
print(f"[FINAL] Queue Survival Statistics - {policy_label} Policy")
print(f" User: {user_id} | Seed: {shuffle_seed}")
print(f" Total Evicted: {survival_summary['total_evicted']}")
print(f" Avg Survival: {survival_summary['avg_survival']:.2f} samples")
print(f" Avg Usage: {survival_summary['avg_usage']:.2f}")
print(f"{'#'*60}\n")
def log_policy_main_header(policy_label, user_id, shuffle_seed, queue_size, num_models, log_path, config=None, policy_note=None):
if not _enabled(config):
return
print("=" * 80)
print(f"{policy_label} QUEUE POLICY EXPERIMENT")
print("=" * 80)
print(f" User ID: {user_id}")
print(f" Shuffle Seed: {shuffle_seed}")
print(f" Queue Size: {queue_size}")
print(f" SC Samples (Agents): {num_models}")
print(f" Log Path: {log_path}")
if policy_note:
print(policy_note)
print("=" * 80)
def log_policy_loading_data(user_id, shuffle_seed, config=None):
if not _enabled(config):
return
print(f"\n[MAIN] Loading shuffled data for user {user_id}, seed {shuffle_seed}...")
def log_policy_loading_models(config=None):
if not _enabled(config):
return
print("\n[MAIN] Loading models...")
def log_policy_start(config=None, label="experiment"):
if not _enabled(config):
return
print(f"\n[MAIN] Starting {label}...")
def log_policy_complete_summary(policy_label, user_id, shuffle_seed, stats, stage_accuracy, stage_counts, temporal, config=None, expected_note=None):
if not _enabled(config):
return
print("\n" + "=" * 80)
print(f"EXPERIMENT COMPLETE - {policy_label} POLICY")
print("=" * 80)
print(f" User: {user_id} | Seed: {shuffle_seed}")
print(f" Total Samples: {stats.get('total_samples', 0)}")
print(f" Overall Accuracy: {stats.get('accuracy', 0):.4f}")
print(f" Macro F1: {stats.get('macro_f1', 0):.4f}")
print("\n Per-Stage Accuracy:")
for stage, acc in stage_accuracy.items():
count = stage_counts.get(stage, 0)
print(f" {stage}: {acc:.4f} (n={count})")
print("\n Temporal Analysis:")
print(f" First Half: {temporal.get('first_half_accuracy', 0):.4f}")
print(f" Second Half: {temporal.get('second_half_accuracy', 0):.4f}")
print(f" Improvement: {temporal.get('improvement', 0):+.4f}")
if expected_note:
print(expected_note)
print("=" * 80)
def log_queue_random_stats(user_id, shuffle_seed, sampler_stats, config=None):
if not _enabled(config):
return
print(f"\n{'#'*60}")
print("[FINAL] Queue Random Statistics")
print(f" User: {user_id} | Seed: {shuffle_seed}")
print(f" Total Steps: {sampler_stats['total_steps']}")
print(f" Total Refreshed: {sampler_stats['total_refreshed']} example sets")
print(f" Avg Refresh per Step: {sampler_stats['avg_refresh_per_step']} (always full)")
print(f"{'#'*60}\n")
def log_queue_random_sampler_init(example_count, classes, queue_size, config=None):
if not _enabled(config):
return
print(f"[QueueRandomSampler] Initialized with {example_count} examples")
print(f" Classes: {classes}")
print(f" Queue size: {queue_size}")
print(" Policy: ALL elements refreshed every step")
def log_queue_random_queue_state(processed_count, user_id, queue_sampler, config=None):
if not _enabled(config):
return
print(f"\n{'#'*60}")
print(f"[Sample {processed_count}] User {user_id} | QUEUE RANDOM Policy")
print("Queue State (ALL FRESH RANDOM samples):")
for idx, ex_idcs in enumerate(queue_sampler):
print(f" [{idx}] Example indices: {ex_idcs}")
print(f"{'#'*60}\n")
def log_queue_random_result(processed_count, answer, ground_truth, is_correct, avg_confidence, consistency, cumulative_accuracy, cumulative_correct, window_accuracy, recent_results, config=None):
if not _enabled(config):
return
print(f"\n{'='*60}")
print(f"[RESULT] Sample {processed_count} | QUEUE RANDOM Policy")
print(f" Answer: {answer} | GT: {ground_truth} | {'✓ CORRECT' if is_correct else '✗ WRONG'}")
print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
print(f" Cumulative Accuracy: {cumulative_accuracy:.4f} ({cumulative_correct}/{processed_count + 1})")
print(f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}")
print(" [NOTE] Queue will be FULLY REFRESHED for next sample")
print(f"{'='*60}\n")


@@ -1,222 +0,0 @@
"""
HTTP API for HuggingFace causal LM inference (answer + logits).
This module exposes a small FastAPI service that:
- Loads a local HuggingFace CausalLM once at startup
- Accepts chat-style messages via HTTP
- Returns generated text
- Returns *logits-derived* information in a practical size:
- top-k logprobs for each generated token (recommended)
- optional prompt logits top-k for the final prompt position
Why not return full logits?
Full logits are extremely large (seq_len x vocab_size) and will quickly
overwhelm network/memory. This API defaults to returning top-k logprobs.
Run:
MODEL_DIR=/path/to/model \\
uvicorn sc.hf_api:app --host 0.0.0.0 --port 8000
Example:
curl -X POST http://localhost:8000/generate \\
-H 'Content-Type: application/json' \\
-d '{"messages":[{"role":"user","content":"Hello!"}],"max_new_tokens":64,"top_k":10}'
"""
from __future__ import annotations
import os
from typing import Any, Dict, List, Literal, Optional
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
DEFAULT_MODEL_DIR = os.environ.get(
"MODEL_DIR", "/mnt/sting/hjyoon/projects/llm/huggingface/gptoss20b"
)
app = FastAPI(title="HF LLM API", version="0.1.0")
class ChatMessage(BaseModel):
role: Literal["system", "user", "assistant"]
content: str
class GenerateRequest(BaseModel):
messages: List[ChatMessage]
max_new_tokens: int = Field(default=128, ge=1, le=1024)
temperature: float = Field(default=0.0, ge=0.0, le=2.0)
top_p: float = Field(default=1.0, ge=0.0, le=1.0)
do_sample: bool = False
top_k: int = Field(default=20, ge=1, le=200)
# If True, also returns prompt last-position top-k logits (not full matrix)
include_prompt_topk: bool = False
class TokenTopK(BaseModel):
token_id: int
token: str
logprob: float
class GeneratedStep(BaseModel):
token_id: int
token: str
logprob: float
topk: List[TokenTopK]
class PromptTopK(BaseModel):
position: int
topk: List[TokenTopK]
class GenerateResponse(BaseModel):
prompt: str
generated_text: str
generated_token_ids: List[int]
steps: List[GeneratedStep]
prompt_topk: Optional[PromptTopK] = None
def _get_device() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"
def _load_model_and_tokenizer(model_dir: str):
device = _get_device()
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_dir,
torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
device_map="auto" if device == "cuda" else None,
).eval()
return tokenizer, model, device
@app.on_event("startup")
def _startup_load():
# Load once; shared across requests.
global tokenizer, model, device
tokenizer, model, device = _load_model_and_tokenizer(DEFAULT_MODEL_DIR)
app.state.model_dir = DEFAULT_MODEL_DIR
app.state.device = device
@app.get("/health")
def health() -> Dict[str, Any]:
return {
"ok": True,
"model_dir": getattr(app.state, "model_dir", None),
"device": getattr(app.state, "device", None),
}
def _apply_chat_template(messages: List[ChatMessage]) -> str:
# Convert pydantic objects to plain dicts compatible with HF template.
msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
try:
prompt = tokenizer.apply_chat_template(
msg_dicts,
tokenize=False,
add_generation_prompt=True,
)
except Exception:
# Fallback: naive concatenation
prompt = "\n".join([f"{m['role']}: {m['content']}" for m in msg_dicts]) + "\nassistant:"
return prompt
def _topk_from_logits(logits_1d: torch.Tensor, top_k: int) -> List[TokenTopK]:
    # logits_1d: (vocab,)
    # Convert to logprobs for interpretability. log_softmax is monotonic, so the
    # top-k logprobs belong to the same tokens as the top-k raw logits.
    logprobs = torch.log_softmax(logits_1d, dim=-1)
    # Clamp k so that a vocabulary smaller than top_k cannot make torch.topk raise.
    top_vals, top_ids = torch.topk(logprobs, k=min(top_k, logprobs.shape[-1]))
    # torch.topk returns values in descending order, so no extra sort is needed.
    return [
        TokenTopK(token_id=int(tid), token=tokenizer.decode([tid]), logprob=float(val))
        for tid, val in zip(top_ids.tolist(), top_vals.tolist())
    ]
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest) -> GenerateResponse:
if not hasattr(app.state, "model_dir"):
raise HTTPException(status_code=503, detail="Model not loaded yet")
prompt = _apply_chat_template(req.messages)
inputs = tokenizer(prompt, return_tensors="pt")
if device == "cuda":
inputs = {k: v.to("cuda") for k, v in inputs.items()}
# Use generate() so we can get per-step scores (logits).
# output.scores is a list with length = generated_tokens
# each element shape: (batch, vocab)
with torch.no_grad():
out = model.generate(
**inputs,
max_new_tokens=req.max_new_tokens,
do_sample=req.do_sample if req.temperature > 0 else False,
temperature=req.temperature if req.temperature > 0 else 1.0,
top_p=req.top_p,
return_dict_in_generate=True,
output_scores=True,
)
# Generated token ids include prompt + new tokens
seq = out.sequences[0]
prompt_len = int(inputs["input_ids"].shape[1])
gen_token_ids = seq[prompt_len:].tolist()
generated_text = tokenizer.decode(gen_token_ids, skip_special_tokens=True)
# Build per-step top-k + chosen token logprob
steps: List[GeneratedStep] = []
if out.scores is None:
raise HTTPException(status_code=500, detail="Model did not return scores")
    # out.scores has one entry per generated token, so pair them up directly.
    # (The previous None fallback for chosen_id could not be decoded by the
    # tokenizer or validated as GeneratedStep.token_id: int.)
    for chosen_id, step_logits in zip(gen_token_ids, out.scores):
        # step_logits: (1, vocab)
        step_logits_1d = step_logits[0]
        logprobs_1d = torch.log_softmax(step_logits_1d, dim=-1)
        chosen_logprob = float(logprobs_1d[chosen_id].detach().cpu().item())
        steps.append(
            GeneratedStep(
                token_id=int(chosen_id),
                token=tokenizer.decode([chosen_id]),
                logprob=chosen_logprob,
                topk=_topk_from_logits(step_logits_1d, req.top_k),
            )
        )
prompt_topk: Optional[PromptTopK] = None
if req.include_prompt_topk:
with torch.no_grad():
forward = model(**inputs)
# forward.logits: (1, seq_len, vocab)
last_pos = int(forward.logits.shape[1] - 1)
last_logits = forward.logits[0, -1, :]
prompt_topk = PromptTopK(position=last_pos, topk=_topk_from_logits(last_logits, req.top_k))
return GenerateResponse(
prompt=prompt,
generated_text=generated_text,
generated_token_ids=[int(t) for t in gen_token_ids],
steps=steps,
prompt_topk=prompt_topk,
)
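
The module docstring above argues for returning top-k logprobs instead of the full logits matrix. As a rough illustration of that reduction, here is a dependency-free sketch that works on a plain list of floats. The helper names `log_softmax` and `topk_logprobs` and the toy vocabulary are assumptions made for this example; they are not part of the service, which uses `torch.log_softmax` and `torch.topk`.

```python
import math
from typing import List, Tuple


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over a 1-D list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def topk_logprobs(logits: List[float], k: int) -> List[Tuple[int, float]]:
    """Return (token_id, logprob) pairs for the k most likely tokens, best first."""
    logprobs = log_softmax(logits)
    k = min(k, len(logprobs))  # guard against k larger than the vocabulary
    ranked = sorted(range(len(logprobs)), key=lambda i: logprobs[i], reverse=True)
    return [(i, logprobs[i]) for i in ranked[:k]]


if __name__ == "__main__":
    toy_logits = [2.0, 0.5, -1.0, 3.0]  # hypothetical 4-token vocabulary
    for token_id, lp in topk_logprobs(toy_logits, k=2):
        print(token_id, round(lp, 4))
```

Keeping only k entries per step makes the payload grow with `generated_tokens x k` instead of `seq_len x vocab_size`, which is the trade-off the docstring describes.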

@@ -1,20 +0,0 @@
import os
import yaml
from datetime import datetime
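class Logger:
    def __init__(self, log_path: str):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # log_path is a timestamped directory created below.
        self.log_path = f"{log_path}_{timestamp}"
        os.makedirs(self.log_path, exist_ok=True)
        # Messages are appended to a file inside that directory; opening the
        # directory itself for append would raise IsADirectoryError.
        # The file name "log.txt" is an assumption, not taken from the original code.
        self.log_file = os.path.join(self.log_path, "log.txt")

    def log(self, message: str):
        print(message)
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(message + "\n")

    def log_config(self, config: dict):
        dumped = yaml.dump(config, default_flow_style=False, sort_keys=False)
        msg = "Config:\n" + dumped
        self.log(msg)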

@@ -1,149 +0,0 @@
"""
HuggingFace Causal LM wrapper for inference with logits.
Loads a model via transformers, generates text, and returns both
the generated text and full logits via a forward pass.
Message format: [{"role": "user"|"assistant"|"system", "content": str}, ...]
"""
from __future__ import annotations
import asyncio
from typing import Dict, List, Optional, Tuple
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# {"role": "user"|"assistant"|"system", "content": str}
ChatMessage = Dict[str, str]
# -----------------------------------------------------------------------------
# Model
# -----------------------------------------------------------------------------
class Model:
"""HuggingFace causal LM. Returns generated text + full logits."""
def __init__(self, model_path: str, temperature: float = 0.0, max_new_tokens: int = 1024) -> None:
self._tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
self._model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="auto",
).eval()
self.temperature = temperature
self.max_new_tokens = max_new_tokens
def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
"""
Generate a reply and return (text, logits).
Returns:
text: Generated text string.
logits: Full logits tensor (1, seq_len, vocab_size) on CPU.
seq_len covers prompt + generated tokens.
"""
# Build prompt from chat messages
prompt = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True,
)
inputs = self._tokenizer(prompt, return_tensors="pt").to(self._model.device)
prompt_len = inputs["input_ids"].shape[1]
# Generate
with torch.no_grad():
gen_out = self._model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
do_sample=self.temperature > 0,
temperature=self.temperature if self.temperature > 0 else 1.0,
)
# Decode generated tokens
gen_ids = gen_out[0][prompt_len:]
text = self._tokenizer.decode(gen_ids, skip_special_tokens=True)
# Forward pass on full sequence to get logits
with torch.no_grad():
logits = self._model(gen_out).logits # (1, seq_len, vocab_size)
return text, logits.cpu()
# -----------------------------------------------------------------------------
# Async wrappers
# -----------------------------------------------------------------------------
class AsyncModel:
"""Runs Model.invoke in a thread executor for async usage."""
def __init__(self, model: Model) -> None:
self.model = model
async def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None, lambda: self.model.invoke(messages),
)
class AsyncModelPool:
"""Round-robin pool of async models."""
def __init__(self) -> None:
self.models: List[Model] = []
self._queue: Optional[asyncio.Queue[AsyncModel]] = None
def add_model(self, model: Model) -> None:
self.models.append(model)
def init_models(self) -> None:
print(f"Initializing {len(self.models)} model(s)...")
self._queue = asyncio.Queue()
for model in self.models:
self._queue.put_nowait(AsyncModel(model))
async def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
"""Acquire a model, invoke, release. Returns (text, logits)."""
if self._queue is None:
raise RuntimeError("Call init_models() first.")
async_model = await self._queue.get()
try:
return await async_model.invoke(messages)
finally:
self._queue.put_nowait(async_model)
# -----------------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------------
def load_models(
models: List[str],
temperature: float = 0.0,
max_new_tokens: int = 1024,
) -> AsyncModelPool:
"""Create and initialize a model pool from a list of model paths."""
pool = AsyncModelPool()
for path in models:
pool.add_model(
Model(path, temperature=temperature, max_new_tokens=max_new_tokens)
)
pool.init_models()
return pool
# Manual smoke test. Guarded so that importing this module does not load a model.
if __name__ == "__main__":
    model = Model("/mnt/sting/hjyoon/projects/llm/huggingface/gptoss20b")
    messages = [
        {"role": "user", "content": "Explain quantum mechanics clearly and concisely."},
    ]
    text, logits = model.invoke(messages)
    print(text)
    print(logits.shape)
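
The `AsyncModelPool` above hands models out through an `asyncio.Queue` and always returns them in a `finally` block. The sketch below reproduces that acquire/invoke/release pattern with a stub model so it runs without any weights. `EchoModel`, `Pool`, and `demo` are hypothetical names used only for this illustration.

```python
import asyncio
from typing import List


class EchoModel:
    """Hypothetical stand-in for a real model; it just tags the prompt with its name."""

    def __init__(self, name: str) -> None:
        self.name = name

    def invoke(self, prompt: str) -> str:
        return f"{self.name}:{prompt}"


class Pool:
    """Hands out models through an asyncio.Queue so each one serves a single call at a time."""

    def __init__(self, models: List[EchoModel]) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()
        for m in models:
            self._queue.put_nowait(m)

    async def invoke(self, prompt: str) -> str:
        model = await self._queue.get()  # waits while every model is busy
        try:
            loop = asyncio.get_running_loop()
            # Run the blocking call in the default thread executor.
            return await loop.run_in_executor(None, model.invoke, prompt)
        finally:
            self._queue.put_nowait(model)  # always return the model to the pool


async def demo() -> List[str]:
    pool = Pool([EchoModel("m0"), EchoModel("m1")])
    # Four requests share two models; gather keeps results in request order.
    return await asyncio.gather(*(pool.invoke(f"q{i}") for i in range(4)))


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Because a model is only available again after `put_nowait`, the number of concurrent generations never exceeds the number of loaded models, whichever model ends up serving a given request.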

@@ -1 +0,0 @@
# SC Preprocess Module

@@ -1,99 +0,0 @@
"""
Build unified feature index from PPGBP per-subject .npy embeddings.
Output: <output_dir>/index.json + <output_dir>/embeddings.npy
Usage:
python -m sc.preprocess.build_feature_index --embedding_root $ROOT --output_dir ./features/sbert_metadata
"""
from __future__ import annotations
import argparse
import json
import os
from typing import Any, Dict, List, Optional
import numpy as np
def _load_labels(base_dir: Optional[str]) -> Dict[int, str]:
"""Load subject_id -> label from xlsx. Prefers Hypertension column (PPGBP), else class/Class."""
if not base_dir:
return {}
try:
import pandas as pd
except ImportError:
return {}
xlsx = None
for f in os.listdir(base_dir):
if f.endswith(".xlsx"):
xlsx = os.path.join(base_dir, f)
break
if not xlsx or not os.path.isfile(xlsx):
return {}
df = pd.read_excel(xlsx, header=1)
if "subject_ID" not in df.columns:
return {}
label_col = None
for c in ("Hypertension", "class", "Class"):
if c in df.columns:
label_col = c
break
out = {}
for _, row in df.iterrows():
sid = int(row["subject_ID"])
out[sid] = str(row.get(label_col, "unknown")).strip() if label_col else "unknown"
return out
def build_index(
embedding_root: str,
output_dir: str,
embedding_type: str = "sbert_metadata",
base_dir: Optional[str] = None,
) -> None:
if not os.path.isdir(embedding_root):
raise FileNotFoundError(embedding_root)
subject_ids = []
for f in sorted(os.listdir(embedding_root)):
if not f.endswith(".npy"):
continue
try:
subject_ids.append(int(f.replace(".npy", "")))
except ValueError:
continue
subject_ids = sorted(subject_ids)
if not subject_ids:
raise ValueError("No .npy files in " + embedding_root)
labels = _load_labels(base_dir)
embeddings_list = []
samples: List[Dict[str, Any]] = []
for row, sid in enumerate(subject_ids):
emb = np.load(os.path.join(embedding_root, f"{sid}.npy")).astype(np.float32)
embeddings_list.append(emb)
samples.append({
"row": row,
"user_id": str(sid),
"sample_id": sid,
"label": labels.get(sid, "unknown"),
"session": "1",
})
embeddings = np.stack(embeddings_list, axis=0)
dim = int(embeddings.shape[1])
meta = {"version": "1.0", "embedding_type": embedding_type, "embedding_dim": dim, "samples": samples}
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, "index.json"), "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
np.save(os.path.join(output_dir, "embeddings.npy"), embeddings)
print(f"[build_feature_index] Wrote {len(samples)} samples, dim={dim}")
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--embedding_root", required=True, help="Directory of per-subject .npy files")
ap.add_argument("--output_dir", required=True, help="Output dir for index.json and embeddings.npy")
ap.add_argument("--embedding_type", default="sbert_metadata")
ap.add_argument("--base_dir", default=None, help="Optional PPGBP data dir (xlsx) for labels")
a = ap.parse_args()
build_index(a.embedding_root, a.output_dir, a.embedding_type, a.base_dir)
if __name__ == "__main__":
main()

@@ -1,134 +0,0 @@
"""
Build PPGBP embeddings under PPGBP_EMBEDDINGS_ROOT with two top-level dirs:
metadata_embs/<emb_type>/<subject_id>.npy
sig_embs/<emb_type>/<subject_id>.npy
For now only metadata embeddings are produced: for each subject, take the xlsx row,
vectorize (one-hot for Sex and Hypertension, numeric for the rest) and save.
Usage:
python -m sc.preprocess.build_ppgbp_embeddings \\
--embeddings_root /path/to/ppgbp_embeddings \\
--data_dir /path/to/ppgbp_data \\
--metadata_emb_type onehot
Then build the feature index (for downstream RecruiterAgent etc.):
python -m sc.preprocess.build_feature_index \\
--embedding_root $PPGBP_EMBEDDINGS_ROOT/metadata_embs/onehot \\
--output_dir $DATA_INDEX_ROOT/metadata_onehot \\
--embedding_type metadata_onehot \\
--base_dir $PPGBP_DATA_DIR
"""
from __future__ import annotations
import argparse
import os
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
# Xlsx column names (header=1)
SUBJECT_ID_COL = "subject_ID"
SEX_COL = "Sex(M/F)"
AGE_COL = "Age(year)"
HEIGHT_COL = "Height(cm)"
WEIGHT_COL = "Weight(kg)"
SBP_COL = "Systolic Blood Pressure(mmHg)"
DBP_COL = "Diastolic Blood Pressure(mmHg)"
HR_COL = "Heart Rate(b/m)"
BMI_COL = "BMI(kg/m^2)"
HYP_COL = "Hypertension"
NUMERIC_COLS = [AGE_COL, HEIGHT_COL, WEIGHT_COL, SBP_COL, DBP_COL, HR_COL, BMI_COL]
CATEGORICAL_COLS = [SEX_COL, HYP_COL]
def _find_xlsx(data_dir: str) -> str:
for f in os.listdir(data_dir):
if f.endswith(".xlsx"):
return os.path.join(data_dir, f)
raise FileNotFoundError(f"No .xlsx in {data_dir}")
def _onehot_and_numeric(df: pd.DataFrame) -> Tuple[np.ndarray, List[str], Dict[str, List[str]]]:
"""
One-hot encode categoricals (Sex, Hypertension) and stack with numeric columns.
Returns (matrix, column_names, categories_used).
"""
out_cols: List[str] = []
pieces: List[np.ndarray] = []
categories_used: Dict[str, List[str]] = {}
for col in CATEGORICAL_COLS:
if col not in df.columns:
continue
vals = df[col].fillna("unknown").astype(str).str.strip()
uniq = sorted(vals.unique())
categories_used[col] = uniq
for u in uniq:
out_cols.append(f"{col}__{u}")
onehot = (vals.values.reshape(-1, 1) == np.array(uniq)).astype(np.float32)
pieces.append(onehot)
for col in NUMERIC_COLS:
if col not in df.columns:
continue
out_cols.append(col)
vec = df[col].fillna(0).astype(np.float32).values.reshape(-1, 1)
pieces.append(vec)
if not pieces:
raise ValueError("No columns found; check xlsx has expected column names")
X = np.hstack(pieces)
return X, out_cols, categories_used
def create_embedding_dirs(embeddings_root: str, metadata_emb_type: str, sig_emb_type: str = "placeholder") -> None:
"""Create metadata_embs/<emb_type>/ and sig_embs/<emb_type>/."""
os.makedirs(embeddings_root, exist_ok=True)
meta_dir = os.path.join(embeddings_root, "metadata_embs", metadata_emb_type)
sig_dir = os.path.join(embeddings_root, "sig_embs", sig_emb_type)
os.makedirs(meta_dir, exist_ok=True)
os.makedirs(sig_dir, exist_ok=True)
def build_metadata_embeddings_onehot(
data_dir: str,
embeddings_root: str,
metadata_emb_type: str = "onehot",
) -> None:
"""
Read xlsx from data_dir, vectorize each subject row (one-hot + numeric), save as
metadata_embs/<metadata_emb_type>/<subject_id>.npy.
"""
xlsx_path = _find_xlsx(data_dir)
df = pd.read_excel(xlsx_path, header=1)
if SUBJECT_ID_COL not in df.columns:
raise ValueError(f"xlsx must have column {SUBJECT_ID_COL}")
X, _cols, _cats = _onehot_and_numeric(df)
subject_ids = df[SUBJECT_ID_COL].astype(int).values
out_dir = os.path.join(embeddings_root, "metadata_embs", metadata_emb_type)
os.makedirs(out_dir, exist_ok=True)
for i, sid in enumerate(subject_ids):
path = os.path.join(out_dir, f"{int(sid)}.npy")
np.save(path, X[i].astype(np.float32))
print(f"[build_ppgbp_embeddings] Wrote {len(subject_ids)} metadata embeddings to {out_dir} (dim={X.shape[1]})")
def main() -> None:
ap = argparse.ArgumentParser(description="Build PPGBP metadata (and optionally signal) embeddings under embeddings_root.")
ap.add_argument("--embeddings_root", required=True, help="Root for metadata_embs/ and sig_embs/ (e.g. PPGBP_EMBEDDINGS_ROOT)")
ap.add_argument("--data_dir", required=True, help="PPGBP data dir containing the xlsx file")
ap.add_argument("--metadata_emb_type", default="onehot", help="Subdir name under metadata_embs/ (e.g. onehot)")
ap.add_argument("--sig_emb_type", default="placeholder", help="Subdir under sig_embs/ (created empty for now)")
args = ap.parse_args()
create_embedding_dirs(args.embeddings_root, args.metadata_emb_type, args.sig_emb_type)
build_metadata_embeddings_onehot(args.data_dir, args.embeddings_root, args.metadata_emb_type)
if __name__ == "__main__":
main()
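
The metadata embedding built above is a one-hot block per categorical column followed by the raw numeric columns. A minimal pure-Python sketch of that layout follows. It assumes records are plain dicts, and the names `onehot_column` and `vectorize` and the toy columns `Sex` and `Age` are illustrative only; the real script works on a pandas DataFrame with the xlsx column names.

```python
from typing import Dict, List, Tuple


def onehot_column(values: List[str], name: str) -> Tuple[List[List[float]], List[str]]:
    """One-hot encode one categorical column; categories are sorted for a stable column order."""
    categories = sorted(set(values))
    columns = [f"{name}__{c}" for c in categories]
    rows = [[1.0 if v == c else 0.0 for c in categories] for v in values]
    return rows, columns


def vectorize(
    records: List[Dict[str, object]],
    categorical: List[str],
    numeric: List[str],
) -> Tuple[List[List[float]], List[str]]:
    """Concatenate one-hot blocks and numeric columns into one vector per record."""
    vectors: List[List[float]] = [[] for _ in records]
    names: List[str] = []
    for col in categorical:
        vals = [str(r.get(col, "unknown")).strip() for r in records]
        rows, cols = onehot_column(vals, col)
        names.extend(cols)
        for vec, row in zip(vectors, rows):
            vec.extend(row)
    for col in numeric:
        names.append(col)
        for vec, r in zip(vectors, records):
            vec.append(float(r.get(col) or 0.0))  # missing values become 0, like fillna(0)
    return vectors, names


if __name__ == "__main__":
    toy = [{"Sex": "F", "Age": 45}, {"Sex": "M", "Age": 31}]  # hypothetical records
    vectors, names = vectorize(toy, categorical=["Sex"], numeric=["Age"])
    print(names)
    print(vectors)
```

Sorting the categories matters: it fixes the column order, so the same subject table always produces vectors with the same dimension layout.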

@@ -1,76 +0,0 @@
"""
One-shot: build PPGBP one-hot metadata embeddings and then package them into
the feature index at DATA_INDEX_ROOT for downstream use (RecruiterAgent, etc.).
Step 1: Create metadata_embs/onehot/ and sig_embs/ under PPGBP_EMBEDDINGS_ROOT,
write one-hot metadata vectors per subject.
Step 2: Run build_feature_index to produce index.json + embeddings.npy at
DATA_INDEX_ROOT (compatible with FeatureIndex and RecruiterAgent).
Usage:
export PPGBP_EMBEDDINGS_ROOT=/scratch/.../embeddings_full/ppgbp
export PPGBP_DATA_DIR=/path/to/ppgbp_data # dir with xlsx and 0_subject/
# or: export PPGBP_DATA_ROOT=/path/to/data/tsllm (same as data dir if xlsx lives there)
export DATA_INDEX_ROOT=/path/to/feature_index_output
python -m sc.preprocess.package_ppgbp_embeddings
Or with explicit args:
python -m sc.preprocess.package_ppgbp_embeddings \\
--embeddings_root $PPGBP_EMBEDDINGS_ROOT \\
--data_dir $PPGBP_DATA_DIR \\
--index_output_dir $DATA_INDEX_ROOT \\
--metadata_emb_type onehot \\
--embedding_type metadata_onehot
"""
from __future__ import annotations
import argparse
import os
import sys
def main() -> None:
ap = argparse.ArgumentParser(description="Build PPGBP metadata embeddings and package feature index.")
ap.add_argument("--embeddings_root", default=None, help="Default: PPGBP_EMBEDDINGS_ROOT env")
ap.add_argument("--data_dir", default=None, help="PPGBP data dir (xlsx + 0_subject/); default: PPGBP_DATA_DIR or PPGBP_DATA_ROOT env")
ap.add_argument("--index_output_dir", default=None, help="Where to write index.json + embeddings.npy; default: DATA_INDEX_ROOT env")
ap.add_argument("--metadata_emb_type", default="onehot")
ap.add_argument("--embedding_type", default="metadata_onehot", help="embedding_type string in index.json")
args = ap.parse_args()
embeddings_root = args.embeddings_root or os.environ.get("PPGBP_EMBEDDINGS_ROOT")
data_dir = args.data_dir or os.environ.get("PPGBP_DATA_DIR") or os.environ.get("PPGBP_DATA_ROOT") or embeddings_root
index_output_dir = args.index_output_dir or os.environ.get("DATA_INDEX_ROOT")
if not embeddings_root:
print("Set PPGBP_EMBEDDINGS_ROOT or pass --embeddings_root", file=sys.stderr)
sys.exit(1)
if not os.path.isdir(data_dir):
print(f"Data dir not found: {data_dir}", file=sys.stderr)
sys.exit(1)
if not index_output_dir:
print("Set DATA_INDEX_ROOT or pass --index_output_dir", file=sys.stderr)
sys.exit(1)
from sc.preprocess.build_ppgbp_embeddings import (
create_embedding_dirs,
build_metadata_embeddings_onehot,
)
create_embedding_dirs(embeddings_root, args.metadata_emb_type)
build_metadata_embeddings_onehot(data_dir, embeddings_root, args.metadata_emb_type)
embedding_root_for_index = os.path.join(embeddings_root, "metadata_embs", args.metadata_emb_type)
from sc.preprocess.build_feature_index import build_index
build_index(
embedding_root=embedding_root_for_index,
output_dir=index_output_dir,
embedding_type=args.embedding_type,
base_dir=data_dir,
)
print(f"[package_ppgbp_embeddings] Feature index written to {index_output_dir}")
if __name__ == "__main__":
main()

@@ -1,409 +0,0 @@
"""
Data Shuffling Module for Sleep Stage Classification Experiment
This module provides functionality to shuffle user data for fair comparison
across different Queue policies (Confidence, Consistency, Random).
Features:
- Shuffle user data with fixed random seed for reproducibility
- Preserve original indices for tracking
- Save shuffled data to JSON for reuse
- Ensure all 3 experiments use identical shuffled order
Usage:
from sc.preprocess.shuffle_data import shuffle_user_data, load_shuffled_data
# Shuffle and save
shuffled_data = shuffle_user_data(user_id=5, seed=42, data_path="...")
# Load existing shuffled data
shuffled_data = load_shuffled_data(user_id=5, seed=42, output_dir="...")
"""
import os
import sys
import json
import random
import numpy as np
from typing import List, Dict, Any, Optional
from datetime import datetime
import datasets
from glob import glob
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
def shuffle_user_data(
user_id: int,
seed: int,
data_path: str,
output_dir: Optional[str] = None,
save_to_file: bool = True,
) -> List[Dict[str, Any]]:
"""
Shuffle user data and optionally save to JSON file.
Args:
user_id: User ID to process (e.g., 5 or 10)
seed: Random seed for reproducibility
data_path: Path to SleepEDF data directory
output_dir: Directory to save shuffled data (default: data_path/shuffled)
save_to_file: Whether to save shuffled data to JSON
Returns:
List of shuffled samples with original_idx preserved
"""
# Set random seeds for reproducibility
random.seed(seed)
np.random.seed(seed)
# Format user_id with leading zeros (e.g., "05", "10")
user_str = f"{user_id:02d}" if isinstance(user_id, int) else str(user_id)
# Load user test data
test_path = os.path.join(data_path, user_str, "2")
if not os.path.exists(test_path):
raise FileNotFoundError(f"Test data not found: {test_path}")
test_dataset = datasets.load_from_disk(test_path)
# Create list with original indices
samples_with_idx = []
for idx, sample in enumerate(test_dataset):
sample_dict = {
"original_idx": idx,
"label": sample["label"],
"features": sample["features"],
}
samples_with_idx.append(sample_dict)
# Shuffle the samples
shuffled_samples = samples_with_idx.copy()
random.shuffle(shuffled_samples)
# Add shuffled index
for shuffled_idx, sample in enumerate(shuffled_samples):
sample["shuffled_idx"] = shuffled_idx
# Save to file if requested
if save_to_file:
if output_dir is None:
output_dir = os.path.join(data_path, "shuffled")
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f"user{user_str}_seed{seed}.json")
metadata = {
"user_id": user_str,
"seed": seed,
"total_samples": len(shuffled_samples),
"created_at": datetime.now().isoformat(),
"data_path": data_path,
}
# Count samples per stage
stage_counts = {}
for sample in shuffled_samples:
label = sample["label"]
stage_counts[label] = stage_counts.get(label, 0) + 1
metadata["stage_distribution"] = stage_counts
output_data = {
"metadata": metadata,
"samples": shuffled_samples,
}
with open(output_path, "w", encoding="utf-8") as f:
json.dump(output_data, f, indent=2, ensure_ascii=False)
print(f"[SHUFFLE] Saved shuffled data to: {output_path}")
print(f" User: {user_str}, Seed: {seed}")
print(f" Total samples: {len(shuffled_samples)}")
print(f" Stage distribution: {stage_counts}")
return shuffled_samples
def load_shuffled_data(
user_id: int,
seed: int,
output_dir: Optional[str] = None,
data_path: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""
Load existing shuffled data from JSON file.
If file doesn't exist, create it by shuffling the data.
Args:
user_id: User ID to load
seed: Random seed used for shuffling
output_dir: Directory containing shuffled data files
data_path: Path to original data (used if shuffled file doesn't exist)
Returns:
List of shuffled samples
"""
user_str = f"{user_id:02d}" if isinstance(user_id, int) else str(user_id)
if output_dir is None and data_path is not None:
output_dir = os.path.join(data_path, "shuffled")
if output_dir is None:
raise ValueError("Either output_dir or data_path must be provided")
file_path = os.path.join(output_dir, f"user{user_str}_seed{seed}.json")
if os.path.exists(file_path):
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
print(f"[SHUFFLE] Loaded existing shuffled data: {file_path}")
print(f" Total samples: {data['metadata']['total_samples']}")
return data["samples"]
else:
if data_path is None:
raise FileNotFoundError(f"Shuffled data not found: {file_path}")
print("[SHUFFLE] Shuffled data not found; creating it now...")
return shuffle_user_data(
user_id=user_id,
seed=seed,
data_path=data_path,
output_dir=output_dir,
save_to_file=True,
)
def get_shuffled_file_path(
user_id: int,
seed: int,
data_path: str,
) -> str:
"""
Get the path to the shuffled data file.
Args:
user_id: User ID
seed: Random seed
data_path: Base data path
Returns:
Path to the shuffled data JSON file
"""
user_str = f"{user_id:02d}" if isinstance(user_id, int) else str(user_id)
output_dir = os.path.join(data_path, "shuffled")
return os.path.join(output_dir, f"user{user_str}_seed{seed}.json")
def verify_shuffle_consistency(
user_id: int,
seeds: List[int],
data_path: str,
) -> bool:
"""
Verify that shuffled data files exist and are consistent.
Args:
user_id: User ID to verify
seeds: List of seeds to verify
data_path: Base data path
Returns:
True if all files exist and have same sample count
"""
user_str = f"{user_id:02d}" if isinstance(user_id, int) else str(user_id)
output_dir = os.path.join(data_path, "shuffled")
sample_counts = []
for seed in seeds:
file_path = os.path.join(output_dir, f"user{user_str}_seed{seed}.json")
if not os.path.exists(file_path):
print(f"[VERIFY] Missing: {file_path}")
return False
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
sample_counts.append(data["metadata"]["total_samples"])
# All should have same count
if len(set(sample_counts)) != 1:
print(f"[VERIFY] Inconsistent sample counts: {sample_counts}")
return False
print(f"[VERIFY] User {user_str}: All {len(seeds)} seed files verified ({sample_counts[0]} samples each)")
return True
class ShuffledDataLoader:
"""
DataLoader that uses pre-shuffled data for consistent experiment comparison.
"""
def __init__(
self,
data_path: str,
user_id: int,
seed: int,
example_pool: str = "out",
):
"""
Initialize shuffled data loader.
Args:
data_path: Path to SleepEDF data
user_id: User ID (e.g., 5 or 10)
seed: Shuffle seed
example_pool: "out" (different users) or "in" (same user)
"""
self.data_path = data_path
self.user_id = user_id
self.seed = seed
self.example_pool = example_pool
user_str = f"{user_id:02d}" if isinstance(user_id, int) else str(user_id)
self.user_str = user_str
# Load metadata
info_path = os.path.join(data_path, "info.json")
if not os.path.exists(info_path):
raise FileNotFoundError(f"Info file not found: {info_path}")
with open(info_path, "r", encoding="utf-8") as f:
self.metadata = json.load(f)
# Load shuffled test data
self.shuffled_samples = load_shuffled_data(
user_id=user_id,
seed=seed,
data_path=data_path,
)
# Load example dataset (same as original DataLoader)
self.example_dataset = datasets.Dataset.from_list([])
users = glob(os.path.join(data_path, "*"))
users = [os.path.basename(p) for p in users if os.path.isdir(p)]
users = [u for u in users if u not in ["info.json", "shuffled"]]
for user in users:
if example_pool == "out" and user == user_str:
continue
if example_pool == "in" and user != user_str:
continue
example_path = os.path.join(data_path, user, "1")
if os.path.exists(example_path):
user_dataset = datasets.load_from_disk(example_path)
self.example_dataset = datasets.concatenate_datasets([self.example_dataset, user_dataset])
# Shuffle example dataset with fixed seed
self.example_dataset = self.example_dataset.shuffle(seed=0)
print(f"[ShuffledDataLoader] User: {user_str}, Seed: {seed}")
print(f" Test samples: {len(self.shuffled_samples)}")
print(f" Example samples: {len(self.example_dataset)}")
def __len__(self):
return len(self.shuffled_samples)
def __getitem__(self, idx):
return self.shuffled_samples[idx]
def __iter__(self):
for sample in self.shuffled_samples:
yield sample
def get_examples(self):
return self.example_dataset
def get_metadata(self):
return self.metadata
def get_sensor_info(self):
return self.metadata["feature"]
def get_task_info(self):
task_info = f"**Task**:\n{self.metadata['task']}\n\n"
classes_info = [f" - {k}: {v}" for k, v in self.metadata["class"].items()]
classes_info = "\n".join(classes_info)
task_info += f"**Classes**:\n{classes_info}"
return task_info
def get_classes_info(self):
return list(self.metadata["class"].keys())
def prepare_all_shuffled_data(
    data_path: str,
    users: Optional[List[int]] = None,
    seeds: Optional[List[int]] = None,
) -> None:
    """
    Prepare all shuffled data files for the experiment.

    Args:
        data_path: Path to SleepEDF data
        users: List of user IDs (default: [5, 10])
        seeds: List of shuffle seeds (default: [42, 123, 456])
    """
    # Resolve defaults here to avoid mutable default arguments.
    users = [5, 10] if users is None else users
    seeds = [42, 123, 456] if seeds is None else seeds

    print("=" * 60)
    print("Preparing Shuffled Data for Experiment")
    print("=" * 60)
    for user_id in users:
        for seed in seeds:
            print(f"\nProcessing User {user_id}, Seed {seed}...")
            shuffle_user_data(
                user_id=user_id,
                seed=seed,
                data_path=data_path,
                save_to_file=True,
            )

    print("\n" + "=" * 60)
    print("Verification")
    print("=" * 60)
    for user_id in users:
        verify_shuffle_consistency(user_id, seeds, data_path)
    print("\n[DONE] All shuffled data prepared.")
if __name__ == "__main__":
import fire
def main(
data_path: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF",
users = "5,10",
seeds = "42,123,456",
):
"""
Prepare shuffled data for experiments.
Args:
data_path: Path to SleepEDF data
users: Comma-separated user IDs or list/tuple
seeds: Comma-separated shuffle seeds or list/tuple
"""
# Handle both string and tuple/list inputs
if isinstance(users, str):
user_list = [int(u.strip()) for u in users.split(",")]
elif isinstance(users, (list, tuple)):
user_list = [int(u) for u in users]
else:
user_list = [int(users)]
if isinstance(seeds, str):
seed_list = [int(s.strip()) for s in seeds.split(",")]
elif isinstance(seeds, (list, tuple)):
seed_list = [int(s) for s in seeds]
else:
seed_list = [int(seeds)]
prepare_all_shuffled_data(
data_path=data_path,
users=user_list,
seeds=seed_list,
)
fire.Fire(main)
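
The shuffling module above relies on a fixed seed so that every queue policy sees the same sample order, while `original_idx` keeps each sample traceable. A small sketch of that idea is shown below. It deliberately uses a local `random.Random(seed)` instead of the module's global `random.seed(seed)`; that swap, and the function name `shuffle_with_indices`, are choices made for this example only.

```python
import random
from typing import Any, Dict, List


def shuffle_with_indices(labels: List[Any], seed: int) -> List[Dict[str, Any]]:
    """Shuffle samples reproducibly, recording both original and shuffled positions."""
    samples = [{"original_idx": i, "label": lab} for i, lab in enumerate(labels)]
    rng = random.Random(seed)  # local RNG: leaves the global random state untouched
    rng.shuffle(samples)
    for shuffled_idx, sample in enumerate(samples):
        sample["shuffled_idx"] = shuffled_idx
    return samples


if __name__ == "__main__":
    stages = ["W", "N1", "N2", "N3", "REM"]  # hypothetical sleep-stage labels
    first = shuffle_with_indices(stages, seed=42)
    second = shuffle_with_indices(stages, seed=42)
    print(first == second)  # same seed, same order
    print([s["original_idx"] for s in first])
```

A local generator gives the same reproducibility guarantee without reseeding other code that shares the global `random` module, which can matter when the experiment runner seeds its own randomness afterwards.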

@@ -1,463 +0,0 @@
"""
Confidence-based Queue Policy Experiment Runner
This module implements the CONFIDENCE-based queue update policy:
- Queue is updated based on model confidence scores
- Higher confidence examples are retained in the queue
- Lower confidence examples are evicted
Usage:
python -m sc.run_confidence_based --user_id=5 --shuffle_seed=42
python -m sc.run_confidence_based --user_id=10 --shuffle_seed=123
"""
import os
import sys
import asyncio
import yaml
import json
import random
import numpy as np
from datetime import datetime
from typing import List, Dict, Any, Optional
from sklearn.metrics import f1_score, precision_score, recall_score
from fire import Fire
# Add project root to path for relative imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
from sc.core.scagent import SCAgent
from sc.core.model import load_models
from sc.core.example_queue import Queue
from sc.core.agent_pool import AgentPool
from sc import debug_log
log = debug_log.log
async def run_confidence_experiment(
dataloader: ShuffledDataLoader,
model_pool,
config: Dict[str, Any],
user_id: int,
shuffle_seed: int,
) -> List[Dict[str, Any]]:
"""
Run confidence-based queue policy experiment.
Queue Update Policy:
- After each inference, rank queue elements by confidence score
- Evict the lowest confidence element
- Add a new random ICL example set
Args:
dataloader: ShuffledDataLoader with pre-shuffled test data
model_pool: Async model pool for LLM inference
config: Experiment configuration
user_id: User ID being processed
shuffle_seed: Shuffle seed for reproducibility
Returns:
List of result dictionaries
"""
# Set seeds for reproducibility within experiment
random.seed(shuffle_seed)
np.random.seed(shuffle_seed)
example_dataset = dataloader.get_examples()
if len(example_dataset) == 0:
log(f"[ERROR] No examples found for user {user_id}")
return []
# Build class_indices for Queue initialization
class_indices = {}
for idx, example in enumerate(example_dataset):
label = example["label"]
if label not in class_indices:
class_indices[label] = []
class_indices[label].append(idx)
# Initialize Queue
queue_size = config.get("queue_size", 5)
ex_queue = Queue(class_indices, queue_size)
# Tracking variables
results = []
cumulative_correct = 0
window_size = config.get("tracking_window", 20)
recent_results = []
confidence_history = []
consistency_history = []
all_predictions = []
all_ground_truths = []
debug_log.log_policy_experiment_header(
"CONFIDENCE-BASED",
user_id,
shuffle_seed,
len(dataloader),
queue_size,
)
for processed_count, sample in enumerate(dataloader):
ex_queue.set_current_time(processed_count)
# Log queue state before processing
debug_log.log_policy_queue_state("CONFIDENCE", processed_count, user_id, ex_queue)
# Create agent pool
agent_pool = AgentPool(log_path=config["log_path"])
try:
for queue_idx, ex_idcs in enumerate(ex_queue):
examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]
agent = SCAgent(
name="EEG sensing",
index=queue_idx,
model_pool=model_pool,
task_info=dataloader.get_task_info(),
classes_info=dataloader.get_classes_info(),
sensor_info=dataloader.get_sensor_info(),
sample=sample,
examples=examples,
log_path=config["log_path"],
)
agent_pool.add_agent(agent)
except Exception as e:
log(f"[ERROR] Failed to create agents: {e}")
import traceback
traceback.print_exc()
continue
if len(agent_pool.agents) == 0:
log(f"[WARN] No agents created for sample {processed_count}")
continue
# Run parallel interpretation
try:
interpretation_result = await agent_pool.run_parallel_interpretation()
except Exception as e:
log(f"[ERROR] Interpretation failed: {e}")
import traceback
traceback.print_exc()
continue
if interpretation_result is None:
log(f"[WARN] Interpretation failed for sample {processed_count}")
continue
answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
# Evaluate result
ground_truth = sample["label"]
is_correct = answer == ground_truth
cumulative_correct += 1 if is_correct else 0
# Divide by the number of evaluated samples; samples skipped via `continue` above are excluded
cumulative_accuracy = cumulative_correct / (len(results) + 1)
recent_results.append(1 if is_correct else 0)
if len(recent_results) > window_size:
recent_results.pop(0)
window_accuracy = sum(recent_results) / len(recent_results)
confidence_history.append(avg_confidence)
consistency_history.append(consistency)
all_predictions.append(answer)
all_ground_truths.append(ground_truth)
# Performance logging
debug_log.log_policy_result(
"CONFIDENCE",
processed_count,
answer,
ground_truth,
is_correct,
avg_confidence,
consistency,
cumulative_accuracy,
cumulative_correct,
window_accuracy,
recent_results,
)
# CONFIDENCE-BASED Queue Update
if responses:
# Build confidence map from agent responses
confidence_map = {}
for idx, response in responses.items():
confidence_map[idx] = response.get("CONFIDENCE", 0.0)
debug_log.log_confidence_map(responses, confidence_map)
# Update queue by confidence (evict lowest, add new random)
ex_queue.update_by_confidence(confidence_map)
ex_queue.increment_usage(list(responses.keys()))
debug_log.log_policy_queue_state_after("Confidence", processed_count, ex_queue)
# Store result
result = {
"sample_idx": processed_count,
"original_idx": sample.get("original_idx", processed_count),
"shuffled_idx": sample.get("shuffled_idx", processed_count),
"answer": answer,
"ground_truth": ground_truth,
"is_correct": is_correct,
"confidence": avg_confidence,
"consistency": consistency,
"cumulative_accuracy": cumulative_accuracy,
"window_accuracy": window_accuracy,
"experiment_type": "confidence",
"user_id": user_id,
"shuffle_seed": shuffle_seed,
}
results.append(result)
# Final statistics
survival_summary = ex_queue.get_survival_summary()
debug_log.log_final_policy_survival("CONFIDENCE", user_id, shuffle_seed, survival_summary)
if results:
results[-1]["queue_survival_stats"] = survival_summary
results[-1]["queue_survival_details"] = ex_queue.get_survival_stats()
return results
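# A minimal sketch of the eviction choice described in the docstring above ("evict the
# lowest confidence element"). Queue.update_by_confidence is not shown in this file, so
# the helper name and the tie-breaking rule (the lowest slot index wins) are assumptions
# made for illustration; this is not the Queue implementation.
def _sketch_lowest_confidence_slot(confidence_map):
    """Return the queue slot with the smallest confidence, or None for an empty map."""
    if not confidence_map:
        return None
    # Sort key is (confidence, slot) so that ties resolve deterministically.
    return min(confidence_map, key=lambda slot: (confidence_map[slot], slot))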
def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict[str, Any]:
"""Compute comprehensive experiment statistics."""
if not results:
return {}
# Overall metrics
correct = sum(1 for r in results if r.get("is_correct", False))
total = len(results)
accuracy = correct / total if total > 0 else 0
# Confidence and consistency averages
confidences = [r.get("confidence", 0) for r in results]
consistencies = [r.get("consistency", 0) for r in results]
avg_confidence = np.mean(confidences) if confidences else 0
avg_consistency = np.mean(consistencies) if consistencies else 0
# Per-stage accuracy
stage_correct = {}
stage_total = {}
for r in results:
gt = r.get("ground_truth", "UNKNOWN")
stage_total[gt] = stage_total.get(gt, 0) + 1
if r.get("is_correct", False):
stage_correct[gt] = stage_correct.get(gt, 0) + 1
stage_accuracy = {}
for stage in stages:
if stage in stage_total:
stage_accuracy[stage] = stage_correct.get(stage, 0) / stage_total[stage]
else:
stage_accuracy[stage] = 0.0
# F1 Score and Macro metrics
predictions = [r.get("answer", "") for r in results]
ground_truths = [r.get("ground_truth", "") for r in results]
try:
macro_f1 = f1_score(ground_truths, predictions, average='macro', zero_division=0)
macro_precision = precision_score(ground_truths, predictions, average='macro', zero_division=0)
macro_recall = recall_score(ground_truths, predictions, average='macro', zero_division=0)
except Exception:
macro_f1 = macro_precision = macro_recall = 0.0
# Temporal analysis
mid_point = len(results) // 2
if mid_point > 0:
first_half = results[:mid_point]
second_half = results[mid_point:]
first_half_acc = sum(1 for r in first_half if r.get("is_correct", False)) / len(first_half)
second_half_acc = sum(1 for r in second_half if r.get("is_correct", False)) / len(second_half)
improvement = second_half_acc - first_half_acc
else:
first_half_acc = second_half_acc = accuracy
improvement = 0
# Learning curve (every 10 samples)
learning_curve = []
for i in range(0, len(results), 10):
chunk = results[:i+10]
chunk_acc = sum(1 for r in chunk if r.get("is_correct", False)) / len(chunk)
learning_curve.append({
"sample_idx": min(i+10, len(results)),
"cumulative_accuracy": chunk_acc,
})
# Convergence speed (90% of final accuracy)
final_accuracy = accuracy
convergence_threshold = 0.9 * final_accuracy
convergence_idx = None
running_correct = 0
for i, r in enumerate(results):
running_correct += 1 if r.get("is_correct", False) else 0
running_acc = running_correct / (i + 1)
if running_acc >= convergence_threshold:
convergence_idx = i
break
# Queue statistics
queue_stats = {}
for r in results:
if "queue_survival_stats" in r:
queue_stats = r["queue_survival_stats"]
break
return {
"experiment_type": "confidence",
"user_id": results[0].get("user_id") if results else None,
"shuffle_seed": results[0].get("shuffle_seed") if results else None,
"total_samples": total,
"correct": correct,
"accuracy": accuracy,
"avg_confidence": float(avg_confidence),
"avg_consistency": float(avg_consistency),
"stage_accuracy": stage_accuracy,
"stage_sample_counts": stage_total,
"macro_f1": float(macro_f1),
"macro_precision": float(macro_precision),
"macro_recall": float(macro_recall),
"temporal_analysis": {
"first_half_accuracy": first_half_acc,
"second_half_accuracy": second_half_acc,
"improvement": improvement,
},
"learning_curve": learning_curve,
"convergence_idx": convergence_idx,
"queue_stats": queue_stats,
}
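# A minimal sketch of the "convergence speed" metric computed in compute_statistics above:
# the first sample index at which the running accuracy reaches a fraction (0.9 there) of
# the final accuracy. The function name and the list-of-booleans input are assumptions.
# One property worth knowing: if the very first prediction is correct, the running
# accuracy is 1.0 and the index is 0, so the metric is sensitive to early samples.
def _sketch_convergence_index(correct_flags, fraction=0.9):
    """Return the first index where running accuracy >= fraction * final accuracy."""
    if not correct_flags:
        return None
    final_accuracy = sum(1 for flag in correct_flags if flag) / len(correct_flags)
    threshold = fraction * final_accuracy
    running_correct = 0
    for i, flag in enumerate(correct_flags):
        running_correct += 1 if flag else 0
        if running_correct / (i + 1) >= threshold:
            return i
    return None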
def save_results(
results: List[Dict[str, Any]],
stats: Dict[str, Any],
config: Dict[str, Any],
user_id: int,
shuffle_seed: int,
) -> None:
"""Save experiment results and statistics."""
log_path = config["log_path"]
# Create user/seed specific directory
output_dir = os.path.join(log_path, f"user{user_id:02d}_seed{shuffle_seed}")
os.makedirs(output_dir, exist_ok=True)
# Save statistics
stats_path = os.path.join(output_dir, "statistics.json")
with open(stats_path, "w", encoding="utf-8") as f:
json.dump(stats, f, indent=2, ensure_ascii=False)
log(f"[SAVE] Statistics: {stats_path}")
# Save results
results_path = os.path.join(output_dir, "results.json")
with open(results_path, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
log(f"[SAVE] Results: {results_path}")
# Save config
config_path = os.path.join(output_dir, "config.yaml")
with open(config_path, "w", encoding="utf-8") as f:
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
log(f"[SAVE] Config: {config_path}")
def main(
config_path: str = "sc/config/experiment_confidence.yaml",
user_id: int = 5,
shuffle_seed: int = 42,
) -> None:
"""
Run Confidence-based Queue Policy experiment.
Args:
config_path: Path to YAML configuration file
user_id: User ID to process (5 or 10)
shuffle_seed: Shuffle seed for data order (42, 123, or 456)
Example:
python -m sc.run_confidence --user_id=5 --shuffle_seed=42
python -m sc.run_confidence --user_id=10 --shuffle_seed=123
"""
log(f"[MAIN] Loading config: {config_path}")
with open(config_path, "r", encoding="utf-8") as f:
config = yaml.safe_load(f)
# Override with CLI arguments
config["user_id"] = user_id
config["shuffle_seed"] = shuffle_seed
# Create unique log path
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
config["log_path"] = f"{config['log_path']}/user{user_id:02d}_seed{shuffle_seed}_{timestamp}"
os.makedirs(config["log_path"], exist_ok=True)
# Print experiment info
debug_log.log_policy_main_header(
"CONFIDENCE-BASED",
user_id,
shuffle_seed,
config.get("queue_size", 5),
len(config.get("models", [])),
config["log_path"],
)
# Load shuffled data
debug_log.log_policy_loading_data(user_id, shuffle_seed)
dataloader = ShuffledDataLoader(
data_path=config["data_path"],
user_id=user_id,
seed=shuffle_seed,
example_pool=config.get("example_pool", "out"),
)
# Load models
debug_log.log_policy_loading_models()
model_pool = load_models(
config["models"],
temperature=config.get("temperature", 0.0),
max_new_tokens=config.get("max_new_tokens", 1024),
)
# Run experiment
debug_log.log_policy_start(label="experiment")
results = asyncio.run(run_confidence_experiment(
dataloader=dataloader,
model_pool=model_pool,
config=config,
user_id=user_id,
shuffle_seed=shuffle_seed,
))
# Compute statistics
stages = config.get("stages", ["W", "N1", "N2", "N3", "REM"])
stats = compute_statistics(results, stages)
# Print final summary
temporal = stats.get("temporal_analysis", {})
debug_log.log_policy_complete_summary(
"CONFIDENCE",
user_id,
shuffle_seed,
stats,
stats.get("stage_accuracy", {}),
stats.get("stage_sample_counts", {}),
temporal,
)
# Save results
save_results(results, stats, config, user_id, shuffle_seed)
log(f"\n[MAIN] Results saved to: {config['log_path']}")
if __name__ == "__main__":
Fire(main)


@@ -1,473 +0,0 @@
"""
Consistency-based Queue Policy Experiment Runner
This module implements the CONSISTENCY-based queue update policy:
- Queue is updated based on SC consensus/agreement scores
- Higher consistency (more agents agree) examples are retained
- Lower consistency examples are evicted
Consistency Score = (Number of agents agreeing with majority) / (Total agents)
Usage:
python -m sc.run_consistency --user_id=5 --shuffle_seed=42
python -m sc.run_consistency --user_id=10 --shuffle_seed=123
"""
import os
import sys
import asyncio
import yaml
import json
import random
import numpy as np
from datetime import datetime
from typing import List, Dict, Any, Optional
from sklearn.metrics import f1_score, precision_score, recall_score
from fire import Fire
# Add project root to path for relative imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
from sc.core.scagent import SCAgent
from sc.core.model import load_models
from sc.core.example_queue import Queue
from sc.core.agent_pool import AgentPool
from sc import debug_log
log = debug_log.log
async def run_consistency_experiment(
dataloader: ShuffledDataLoader,
model_pool,
config: Dict[str, Any],
user_id: int,
shuffle_seed: int,
) -> List[Dict[str, Any]]:
"""
Run consistency-based queue policy experiment.
Queue Update Policy:
- After each inference, calculate per-agent consistency
- Consistency = fraction of agents (including this one) that gave the same answer as this agent
- Rank queue elements by consistency score
- Evict the lowest consistency element
- Add a new random ICL example set
Args:
dataloader: ShuffledDataLoader with pre-shuffled test data
model_pool: Async model pool for LLM inference
config: Experiment configuration
user_id: User ID being processed
shuffle_seed: Shuffle seed for reproducibility
Returns:
List of result dictionaries
"""
# Set seeds for reproducibility within experiment
random.seed(shuffle_seed)
np.random.seed(shuffle_seed)
example_dataset = dataloader.get_examples()
if len(example_dataset) == 0:
log(f"[ERROR] No examples found for user {user_id}")
return []
# Build class_indices for Queue initialization
class_indices = {}
for idx, example in enumerate(example_dataset):
label = example["label"]
if label not in class_indices:
class_indices[label] = []
class_indices[label].append(idx)
# Initialize Queue
queue_size = config.get("queue_size", 5)
ex_queue = Queue(class_indices, queue_size)
# Tracking variables
results = []
cumulative_correct = 0
window_size = config.get("tracking_window", 20)
recent_results = []
confidence_history = []
consistency_history = []
all_predictions = []
all_ground_truths = []
debug_log.log_policy_experiment_header(
"CONSISTENCY-BASED",
user_id,
shuffle_seed,
len(dataloader),
queue_size,
)
for processed_count, sample in enumerate(dataloader):
ex_queue.set_current_time(processed_count)
# Log queue state before processing
debug_log.log_policy_queue_state("CONSISTENCY", processed_count, user_id, ex_queue)
# Create agent pool
agent_pool = AgentPool(log_path=config["log_path"])
try:
for queue_idx, ex_idcs in enumerate(ex_queue):
examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]
agent = SCAgent(
name="EEG sensing",
index=queue_idx,
model_pool=model_pool,
task_info=dataloader.get_task_info(),
classes_info=dataloader.get_classes_info(),
sensor_info=dataloader.get_sensor_info(),
sample=sample,
examples=examples,
log_path=config["log_path"],
)
agent_pool.add_agent(agent)
except Exception as e:
log(f"[ERROR] Failed to create agents: {e}")
import traceback
traceback.print_exc()
continue
if len(agent_pool.agents) == 0:
log(f"[WARN] No agents created for sample {processed_count}")
continue
# Run parallel interpretation
try:
interpretation_result = await agent_pool.run_parallel_interpretation()
except Exception as e:
log(f"[ERROR] Interpretation failed: {e}")
import traceback
traceback.print_exc()
continue
if interpretation_result is None:
log(f"[WARN] Interpretation failed for sample {processed_count}")
continue
answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
# Evaluate result
ground_truth = sample["label"]
is_correct = answer == ground_truth
cumulative_correct += 1 if is_correct else 0
# Divide by the number of evaluated samples; samples skipped via `continue` above are excluded
cumulative_accuracy = cumulative_correct / (len(results) + 1)
recent_results.append(1 if is_correct else 0)
if len(recent_results) > window_size:
recent_results.pop(0)
window_accuracy = sum(recent_results) / len(recent_results)
confidence_history.append(avg_confidence)
consistency_history.append(consistency)
all_predictions.append(answer)
all_ground_truths.append(ground_truth)
# Performance logging
debug_log.log_policy_result(
"CONSISTENCY",
processed_count,
answer,
ground_truth,
is_correct,
avg_confidence,
consistency,
cumulative_accuracy,
cumulative_correct,
window_accuracy,
recent_results,
)
# CONSISTENCY-BASED Queue Update
if responses:
# Calculate per-agent consistency: the share of agents that gave the same answer
all_answers = [r.get("ANSWER") for r in responses.values()]
consistency_map = {}
for idx, response in responses.items():
agent_answer = response.get("ANSWER")
# Consistency = ratio of agents (including self) that agree
agreement_count = all_answers.count(agent_answer)
agent_consistency = agreement_count / len(all_answers) if all_answers else 0
consistency_map[idx] = agent_consistency
debug_log.log_consistency_map(responses, consistency_map)
# Update queue by consistency (reuses confidence method with consistency scores)
ex_queue.update_by_confidence(consistency_map)
ex_queue.increment_usage(list(responses.keys()))
debug_log.log_policy_queue_state_after("Consistency", processed_count, ex_queue)
# Store result
result = {
"sample_idx": processed_count,
"original_idx": sample.get("original_idx", processed_count),
"shuffled_idx": sample.get("shuffled_idx", processed_count),
"answer": answer,
"ground_truth": ground_truth,
"is_correct": is_correct,
"confidence": avg_confidence,
"consistency": consistency,
"cumulative_accuracy": cumulative_accuracy,
"window_accuracy": window_accuracy,
"experiment_type": "consistency",
"user_id": user_id,
"shuffle_seed": shuffle_seed,
}
results.append(result)
# Final statistics
survival_summary = ex_queue.get_survival_summary()
debug_log.log_final_policy_survival("CONSISTENCY", user_id, shuffle_seed, survival_summary)
if results:
results[-1]["queue_survival_stats"] = survival_summary
results[-1]["queue_survival_details"] = ex_queue.get_survival_stats()
return results
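# A minimal sketch of the per-agent consistency score used in the queue update above:
# for each queue slot, the fraction of agents (including the agent itself) whose ANSWER
# matches that slot's answer. The helper name and the slot -> answer input shape are
# assumptions; the original loop calls list.count once per agent, while a Counter
# tallies all answers in a single pass.
from collections import Counter


def _sketch_per_agent_consistency(answers_by_slot):
    """Map each slot to the share of agents that gave the same answer."""
    counts = Counter(answers_by_slot.values())
    total = len(answers_by_slot)
    return {slot: counts[answer] / total for slot, answer in answers_by_slot.items()}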
def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict[str, Any]:
"""Compute comprehensive experiment statistics."""
if not results:
return {}
# Overall metrics
correct = sum(1 for r in results if r.get("is_correct", False))
total = len(results)
accuracy = correct / total if total > 0 else 0
# Confidence and consistency averages
confidences = [r.get("confidence", 0) for r in results]
consistencies = [r.get("consistency", 0) for r in results]
avg_confidence = np.mean(confidences) if confidences else 0
avg_consistency = np.mean(consistencies) if consistencies else 0
# Per-stage accuracy
stage_correct = {}
stage_total = {}
for r in results:
gt = r.get("ground_truth", "UNKNOWN")
stage_total[gt] = stage_total.get(gt, 0) + 1
if r.get("is_correct", False):
stage_correct[gt] = stage_correct.get(gt, 0) + 1
stage_accuracy = {}
for stage in stages:
if stage in stage_total:
stage_accuracy[stage] = stage_correct.get(stage, 0) / stage_total[stage]
else:
stage_accuracy[stage] = 0.0
# F1 Score and Macro metrics
predictions = [r.get("answer", "") for r in results]
ground_truths = [r.get("ground_truth", "") for r in results]
try:
macro_f1 = f1_score(ground_truths, predictions, average='macro', zero_division=0)
macro_precision = precision_score(ground_truths, predictions, average='macro', zero_division=0)
macro_recall = recall_score(ground_truths, predictions, average='macro', zero_division=0)
except Exception:
macro_f1 = macro_precision = macro_recall = 0.0
# Temporal analysis
mid_point = len(results) // 2
if mid_point > 0:
first_half = results[:mid_point]
second_half = results[mid_point:]
first_half_acc = sum(1 for r in first_half if r.get("is_correct", False)) / len(first_half)
second_half_acc = sum(1 for r in second_half if r.get("is_correct", False)) / len(second_half)
improvement = second_half_acc - first_half_acc
else:
first_half_acc = second_half_acc = accuracy
improvement = 0
# Learning curve (every 10 samples)
learning_curve = []
for i in range(0, len(results), 10):
chunk = results[:i+10]
chunk_acc = sum(1 for r in chunk if r.get("is_correct", False)) / len(chunk)
learning_curve.append({
"sample_idx": min(i+10, len(results)),
"cumulative_accuracy": chunk_acc,
})
# Convergence speed (90% of final accuracy)
final_accuracy = accuracy
convergence_threshold = 0.9 * final_accuracy
convergence_idx = None
running_correct = 0
for i, r in enumerate(results):
running_correct += 1 if r.get("is_correct", False) else 0
running_acc = running_correct / (i + 1)
if running_acc >= convergence_threshold:
convergence_idx = i
break
# Queue statistics
queue_stats = {}
for r in results:
if "queue_survival_stats" in r:
queue_stats = r["queue_survival_stats"]
break
return {
"experiment_type": "consistency",
"user_id": results[0].get("user_id") if results else None,
"shuffle_seed": results[0].get("shuffle_seed") if results else None,
"total_samples": total,
"correct": correct,
"accuracy": accuracy,
"avg_confidence": float(avg_confidence),
"avg_consistency": float(avg_consistency),
"stage_accuracy": stage_accuracy,
"stage_sample_counts": stage_total,
"macro_f1": float(macro_f1),
"macro_precision": float(macro_precision),
"macro_recall": float(macro_recall),
"temporal_analysis": {
"first_half_accuracy": first_half_acc,
"second_half_accuracy": second_half_acc,
"improvement": improvement,
},
"learning_curve": learning_curve,
"convergence_idx": convergence_idx,
"queue_stats": queue_stats,
}
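# A minimal sketch of the learning curve built in compute_statistics above: cumulative
# accuracy reported every `step` samples, plus one final point when the sample count is
# not a multiple of `step`. The original re-slices results[:i+10] for every point; this
# hypothetical variant keeps a running count instead and yields the same points.
def _sketch_learning_curve(correct_flags, step=10):
    """Return [{'sample_idx': n, 'cumulative_accuracy': acc}, ...]."""
    curve = []
    running_correct = 0
    total = len(correct_flags)
    for n, flag in enumerate(correct_flags, start=1):
        running_correct += 1 if flag else 0
        if n % step == 0 or n == total:
            curve.append({"sample_idx": n, "cumulative_accuracy": running_correct / n})
    return curve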
def save_results(
results: List[Dict[str, Any]],
stats: Dict[str, Any],
config: Dict[str, Any],
user_id: int,
shuffle_seed: int,
) -> None:
"""Save experiment results and statistics."""
log_path = config["log_path"]
# Create user/seed specific directory
output_dir = os.path.join(log_path, f"user{user_id:02d}_seed{shuffle_seed}")
os.makedirs(output_dir, exist_ok=True)
# Save statistics
stats_path = os.path.join(output_dir, "statistics.json")
with open(stats_path, "w", encoding="utf-8") as f:
json.dump(stats, f, indent=2, ensure_ascii=False)
log(f"[SAVE] Statistics: {stats_path}")
# Save results
results_path = os.path.join(output_dir, "results.json")
with open(results_path, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
log(f"[SAVE] Results: {results_path}")
# Save config
config_path = os.path.join(output_dir, "config.yaml")
with open(config_path, "w", encoding="utf-8") as f:
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
log(f"[SAVE] Config: {config_path}")
def main(
config_path: str = "sc/config/experiment_consistency.yaml",
user_id: int = 5,
shuffle_seed: int = 42,
) -> None:
"""
Run Consistency-based Queue Policy experiment.
Args:
config_path: Path to YAML configuration file
user_id: User ID to process (5 or 10)
shuffle_seed: Shuffle seed for data order (42, 123, or 456)
Example:
python -m sc.run_consistency --user_id=5 --shuffle_seed=42
python -m sc.run_consistency --user_id=10 --shuffle_seed=123
"""
log(f"[MAIN] Loading config: {config_path}")
with open(config_path, "r", encoding="utf-8") as f:
config = yaml.safe_load(f)
# Override with CLI arguments
config["user_id"] = user_id
config["shuffle_seed"] = shuffle_seed
# Create unique log path
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
config["log_path"] = f"{config['log_path']}/user{user_id:02d}_seed{shuffle_seed}_{timestamp}"
os.makedirs(config["log_path"], exist_ok=True)
# Print experiment info
debug_log.log_policy_main_header(
"CONSISTENCY-BASED",
user_id,
shuffle_seed,
config.get("queue_size", 5),
len(config.get("models", [])),
config["log_path"],
)
# Load shuffled data
debug_log.log_policy_loading_data(user_id, shuffle_seed)
dataloader = ShuffledDataLoader(
data_path=config["data_path"],
user_id=user_id,
seed=shuffle_seed,
example_pool=config.get("example_pool", "out"),
)
# Load models
debug_log.log_policy_loading_models()
model_pool = load_models(
config["models"],
temperature=config.get("temperature", 0.0),
max_new_tokens=config.get("max_new_tokens", 1024),
)
# Run experiment
debug_log.log_policy_start(label="experiment")
results = asyncio.run(run_consistency_experiment(
dataloader=dataloader,
model_pool=model_pool,
config=config,
user_id=user_id,
shuffle_seed=shuffle_seed,
))
# Compute statistics
stages = config.get("stages", ["W", "N1", "N2", "N3", "REM"])
stats = compute_statistics(results, stages)
# Print final summary
temporal = stats.get("temporal_analysis", {})
debug_log.log_policy_complete_summary(
"CONSISTENCY",
user_id,
shuffle_seed,
stats,
stats.get("stage_accuracy", {}),
stats.get("stage_sample_counts", {}),
temporal,
)
# Save results
save_results(results, stats, config, user_id, shuffle_seed)
log(f"\n[MAIN] Results saved to: {config['log_path']}")
if __name__ == "__main__":
Fire(main)


@@ -1,390 +0,0 @@
"""
Recruiter-based Self-Consistency runner for PPGBP.
Uses RecruiterAgent (MAB) to select example sets, self_certainty(logits) as reward,
and borda_vote for aggregation. Requires PPGBPLoader and optional feature index.
Usage:
python -m sc.run_recruiter sc/config/ppgbp_recruiter.yaml
"""
from __future__ import annotations
import asyncio
import json
import os
import sys
from collections import deque
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import yaml
from fire import Fire
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.core.data_loader import PPGBPLoader, get_loaders
from sc.core.example_queue import Queue
from sc.core.model import Model, load_models
from sc.core.recruiter_agent import RecruiterAgent, borda_vote, self_certainty
from sc.core.recruiter_agent import FeatureIndex
def _ppgbp_sample_to_dict(metadata: Dict, signal, subject_id: int, label_key: str = "label") -> Dict[str, Any]:
    """Turn (metadata, signal) into a sample dict with user_id, sample_id, label, features."""
    # Look up the label under label_key first, then fall back to common alternative keys.
    label_source = label_key
    label = str(metadata.get(label_key))
    if label == "None" or label == "unknown":
        for key in ["hypertension", "Hypertension", "class", "Class"]:
            val = metadata.get(key)
            if val is not None:
                label = str(val)
                label_source = key
                break
    if label == "None":
        label = "unknown"
    # Exclude the column the label came from (including a fallback column) so the
    # label does not leak into the features shown to the model.
    features = {
        k: str(v)
        for k, v in metadata.items()
        if k not in (label_key, label_source) and k.lower() != "label"
    }
    return {
        "user_id": str(subject_id),
        "sample_id": subject_id,
        "idx": subject_id,
        "label": label,
        "features": features,
    }
def _build_prompt(system_msg: str, task_info: str, classes_info: List[str], test_sample: Dict, examples: List[Dict]) -> List[Dict[str, str]]:
    """Build chat messages for the model."""
    parts = [f"**Task**: {task_info}\n\n**Classes**: {', '.join(classes_info)}\n\n"]
    parts.append("**Examples**\n")
    for ex in examples:
        parts.append(f"*Example of {ex['label']}*:\n")
        for k, v in (ex.get("features") or {}).items():
            parts.append(f" - {k}: {v}\n")
    parts.append("\n**Current sample**\n")
    for k, v in (test_sample.get("features") or {}).items():
        parts.append(f" - {k}: {v}\n")
    parts.append(f"\nRespond in JSON: {{\"REASON\": \"...\", \"CONFIDENCE\": 0.0-1.0, \"ANSWER\": \"<class>\"}}")
    content = "".join(parts)
    return [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": content},
    ]
def _parse_answer(text: str, classes: List[str]) -> Optional[str]:
    """Extract ANSWER from JSON in text."""
    try:
        start = text.find("{")
        end = text.rfind("}") + 1
        if start >= 0 and end > start:
            obj = json.loads(text[start:end])
            ans = (obj.get("ANSWER") or "").strip()
            if ans in classes:
                return ans
    except Exception:
        pass
    return None
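# A minimal sketch of the brace-slicing step used by _parse_answer above: take the text
# from the first "{" to the last "}" and try to decode it as JSON. The helper name is
# hypothetical. The last case in the usage lines shows a limitation of this approach:
# when a reply contains two separate JSON objects, the slice spans both and decoding fails.
import json


def _sketch_extract_json_object(text):
    """Return the dict decoded from the outermost {...} span, or None."""
    start = text.find("{")
    end = text.rfind("}") + 1
    if start < 0 or end <= start:
        return None
    try:
        obj = json.loads(text[start:end])
    except json.JSONDecodeError:
        return None
    return obj if isinstance(obj, dict) else None


# Usage: surrounding chatter is tolerated, but two objects in one reply are not.
_sketch_reply = 'Sure! {"REASON": "x", "CONFIDENCE": 0.8, "ANSWER": "Normal"} Done.'
_sketch_parsed = _sketch_extract_json_object(_sketch_reply)
_sketch_two_objects = _sketch_extract_json_object('{"a": 1} and {"b": 2}')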
def make_gt_scorer(
    num_arms: int,
    best_arm_id: int = 20,
    sigma: float = 5.0,
) -> Callable[[int], float]:
    """
    Deterministic ground-truth scorer for checking MAB convergence.

    The reward is a Gaussian in the arm index, peaking at best_arm_id:
    reward(i) = exp(-(i - best_arm_id)^2 / (2 * sigma^2)).
    The returned score is -reward, so lower is better. No randomness is involved.
    """
    best_arm_id = max(0, min(best_arm_id, num_arms - 1))
    sigma = max(1e-6, float(sigma))

    def scorer(arm_id: int) -> float:
        if arm_id < 0 or arm_id >= num_arms:
            # Out-of-range arms get 0.0, which is the worst possible score.
            return 0.0
        reward = np.exp(-((arm_id - best_arm_id) ** 2) / (2 * sigma**2))
        return -float(reward)

    return scorer
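# A worked example of the Gaussian score from make_gt_scorer above, using the defaults
# best_arm_id=20 and sigma=5.0. The standalone function below is a hypothetical
# re-statement of the same formula with math.exp; it omits the range check and the
# clamping that make_gt_scorer performs.
import math


def _sketch_gt_score(arm_id, best_arm_id=20, sigma=5.0):
    """Return -exp(-(arm_id - best_arm_id)^2 / (2 * sigma^2)); lower is better."""
    return -math.exp(-((arm_id - best_arm_id) ** 2) / (2 * sigma ** 2))


# The best arm scores exactly -1.0; an arm one sigma away scores -exp(-0.5);
# arms far from the peak approach 0.0, the worst score.
_sketch_scores = {arm: _sketch_gt_score(arm) for arm in (15, 20, 25, 40)}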
async def run_single_user(
config: Dict[str, Any],
model_pool: Optional[Any],
test_loader: Optional[PPGBPLoader],
train_loader: PPGBPLoader,
seed: int,
) -> List[Dict[str, Any]]:
"""Run recruiter pipeline. If model_pool is None, use ground-truth scorer (arm_id -> score) for MAB convergence."""
queue_size = config.get("queue_size", 5)
recruit_size = config.get("recruit_size", 2)
n_way = config.get("n_way", 2)
k_shot = config.get("k_shot", 1)
task_info = config.get("task_info", "Classify the subject.")
classes_info = config.get("classes_info", ["Normal", "Hypertension"])
system_msg = config.get("system_message", "You are a helpful assistant. Respond in JSON with REASON, CONFIDENCE, ANSWER.")
use_gt_only = model_pool is None
np.random.seed(seed)
train_loader._indices = np.arange(len(train_loader.subject_ids))
if train_loader.shuffle:
train_loader.rng.shuffle(train_loader._indices)
example_dataset = []
for idx in train_loader._indices:
sid = train_loader.subject_ids[idx]
meta = train_loader._get_metadata_dict(sid)
signal = train_loader._load_signal(sid)
example_dataset.append(_ppgbp_sample_to_dict(meta, signal, int(sid)))
if len(example_dataset) < n_way * k_shot:
print("[run_recruiter] Not enough train samples")
return []
classes = list({s["label"] for s in example_dataset})
if not classes:
classes = classes_info
class_indices = {c: [i for i, s in enumerate(example_dataset) if s["label"] == c] for c in classes}
for c in classes:
if not class_indices[c]:
class_indices[c] = [0]
dataset_idx = {(str(s["user_id"]), s["sample_id"]): i for i, s in enumerate(example_dataset)}
feature_index = None
if config.get("feature_root") and os.path.isdir(config["feature_root"]):
try:
feature_index = FeatureIndex(config["feature_root"])
except Exception as e:
print(f"[run_recruiter] Feature index load failed: {e}")
recruiter = RecruiterAgent(
source_samples=example_dataset,
classes=classes,
feature_index=feature_index,
target_user_id=None,
k_shot=k_shot,
queue_size=queue_size,
num_arms=config.get("num_arms", 50),
n_das=config.get("n_das", 5),
delta=config.get("delta", 0.05),
epsilon=config.get("epsilon", 0.11),
seed=seed,
)
num_arms = len(recruiter.pool.arms)
if use_gt_only:
gt_steps = config.get("gt_steps", 200)
gt_best_arm_id = config.get("gt_best_arm_id", 20)
gt_sigma = config.get("gt_sigma", 5.0)
gt_scorer = make_gt_scorer(num_arms, best_arm_id=gt_best_arm_id, sigma=gt_sigma)
print(f"[run_recruiter] GT-only mode: {gt_steps} steps, best_arm_id={gt_best_arm_id}, sigma={gt_sigma}")
initial_sets = recruiter.recruit(queue_size)
initial_cases = []
slot_arms = []
for arm_idx, ex_set in initial_sets:
case = []
for d in ex_set:
key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
if key in dataset_idx:
case.append(dataset_idx[key])
else:
for i, s in enumerate(example_dataset):
if s.get("user_id") == str(d.get("user_id")) and s.get("sample_id") == d.get("sample_id"):
case.append(i)
break
if len(case) >= n_way:
initial_cases.append(case[:n_way])
slot_arms.append(arm_idx)
while len(initial_cases) < queue_size and len(initial_sets) > len(initial_cases):
arm_idx, ex_set = initial_sets[len(initial_cases)]
case = []
for d in ex_set:
key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
case.append(dataset_idx.get(key, 0))
if len(case) < n_way:
case.extend([0] * (n_way - len(case)))
initial_cases.append(case[:n_way])
slot_arms.append(arm_idx)
ex_queue = Queue(class_indices, queue_size)
ex_queue._queue = deque(initial_cases[:queue_size], maxlen=queue_size)
for case in ex_queue._queue:
ex_queue._register_stats(case)
slot_arms = slot_arms[:queue_size]
results = []
processed = 0
cumulative_correct = 0
if use_gt_only:
# Loop: query gt_scorer(arm_id) for each slot, update MAB, recruit, update queue
for step in range(gt_steps):
ex_queue.set_current_time(step)
queue_cases = list(ex_queue._queue)
run_results = []
for slot_i in range(len(slot_arms)):
arm_idx = slot_arms[slot_i]
score = gt_scorer(arm_idx)
response_dict = {"ANSWER": None, "REASON": "gt", "CONFIDENCE": 0.5}
run_results.append((arm_idx, response_dict, score))
recruiter.update(run_results)
new_sets = recruiter.recruit(recruit_size)
new_cases = []
new_arm_ids = []
for arm_idx, ex_set in new_sets:
case = []
for d in ex_set:
key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
case.append(dataset_idx.get(key, 0))
if len(case) < n_way:
case.extend([0] * (n_way - len(case)))
new_cases.append(case[:n_way])
new_arm_ids.append(arm_idx)
ex_queue.update_with_recruiter(run_results, new_cases, recruit_size)
scores_with_idx = [(i, run_results[i][2]) for i in range(len(run_results))]
scores_with_idx.sort(key=lambda x: x[1], reverse=True)
to_evict = {scores_with_idx[j][0] for j in range(min(recruit_size, len(scores_with_idx)))}
kept = [i for i in range(len(slot_arms)) if i not in to_evict]
slot_arms = [slot_arms[i] for i in kept] + new_arm_ids[:recruit_size]
results.append({
"sample_idx": step,
"answer": None,
"ground_truth": None,
"is_correct": None,
"gt_step": step,
"queue_arms": list(slot_arms),
"step_rewards": [r[2] for r in run_results],
})
print(f"[GT Summary] Final queue arms: {slot_arms}")
return results
# Model mode: iterate test loader, call LLM, self_certainty, borda_vote
for step, (metadata, signal) in enumerate(test_loader):
test_sample = _ppgbp_sample_to_dict(metadata, signal, int(metadata.get("subject_ID", step)))
ex_queue.set_current_time(processed)
queue_cases = list(ex_queue._queue)
run_results = []
for slot_i, case in enumerate(queue_cases):
examples = [example_dataset[i] for i in case if i < len(example_dataset)]
if len(examples) < n_way:
continue
messages = _build_prompt(system_msg, task_info, classes_info, test_sample, examples)
text, logits = await model_pool.invoke(messages)
score = self_certainty(logits)
parsed = _parse_answer(text, classes_info)
response_dict = {"ANSWER": parsed, "REASON": text[:200], "CONFIDENCE": 0.5}
arm_idx = slot_arms[slot_i] if slot_i < len(slot_arms) else slot_i
run_results.append((arm_idx, response_dict, score))
if not run_results:
processed += 1
continue
answer = borda_vote(run_results, valid_classes=classes_info)
gt = test_sample.get("label", "")
is_correct = answer == gt
cumulative_correct += 1 if is_correct else 0
acc = cumulative_correct / (processed + 1)
recruiter.update(run_results)
new_sets = recruiter.recruit(recruit_size)
new_cases = []
new_arm_ids = []
for arm_idx, ex_set in new_sets:
case = []
for d in ex_set:
key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
case.append(dataset_idx.get(key, 0))
if len(case) < n_way:
case.extend([0] * (n_way - len(case)))
new_cases.append(case[:n_way])
new_arm_ids.append(arm_idx)
ex_queue.update_with_recruiter(run_results, new_cases, recruit_size)
scores_with_idx = [(i, run_results[i][2]) for i in range(len(run_results))]
scores_with_idx.sort(key=lambda x: x[1], reverse=True)
to_evict = {scores_with_idx[j][0] for j in range(min(recruit_size, len(scores_with_idx)))}
kept = [i for i in range(len(slot_arms)) if i not in to_evict]
slot_arms = [slot_arms[i] for i in kept] + new_arm_ids[:recruit_size]
results.append({
"sample_idx": processed,
"answer": answer,
"ground_truth": gt,
"is_correct": is_correct,
"cumulative_accuracy": acc,
})
processed += 1
return results
def _expand_env_in_config(obj: Any) -> Any:
    """Recursively expand $VAR and ${VAR} in config strings."""
    if isinstance(obj, dict):
        return {k: _expand_env_in_config(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_expand_env_in_config(v) for v in obj]
    if isinstance(obj, str):
        return os.path.expandvars(obj)
    return obj


def main(config_path: str) -> None:
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    config = _expand_env_in_config(config)
    if not config.get("model_path") and config.get("models"):
        config["model_path"] = config["models"][0] if isinstance(config["models"], list) else config["models"]
    model_path = config.get("model_path") or ""
    if isinstance(model_path, str):
        model_path = model_path.strip()
    use_gt_only = not model_path
    if use_gt_only:
        config["use_gt_only"] = True
        print("[run_recruiter] model_path is null: using ground-truth scorer (arm_id -> score) for MAB convergence")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if "log_path" in config:
        config["log_path"] = f"{config['log_path']}_{timestamp}"
    pool = None
    if not use_gt_only:
        pool = load_models(
            [model_path],
            temperature=config.get("temperature", 0.7),
            max_new_tokens=config.get("max_new_tokens", 256),
        )
    data_path = config["data_path"]
    train_loader, test_loader = get_loaders(data_path, seed=config.get("seed", 42))
    if use_gt_only:
        test_loader = None  # not used in GT loop
    seeds = range(config.get("num_seeds", 1))
    all_results = []
    for seed in seeds:
        res = asyncio.run(run_single_user(config, pool, test_loader, train_loader, seed))
        all_results.extend(res)
    if use_gt_only:
        print(f"GT mode: completed {len(all_results)} steps")
    else:
        correct = sum(1 for r in all_results if r.get("is_correct"))
        total = len(all_results)
        acc = correct / total if total else 0
        print(f"Accuracy: {correct}/{total} = {acc:.4f}")
    log_path = config.get("log_path", "./logs/ppgbp_recruiter")
    os.makedirs(log_path, exist_ok=True)
    with open(os.path.join(log_path, "results.json"), "w") as f:
        json.dump({
            "use_gt_only": use_gt_only,
            "accuracy": (sum(1 for r in all_results if r.get("is_correct")) / len(all_results)) if all_results and not use_gt_only else None,
            "correct": sum(1 for r in all_results if r.get("is_correct")) if not use_gt_only else None,
            "total": len(all_results),
            "results": all_results,
        }, f, indent=2)


if __name__ == "__main__":
    Fire(main)

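The config handling in `main` above (environment-variable expansion, then the fallback from `models` to `model_path`, then the switch to GT-only mode when no model is given) can be sketched in isolation. This is a minimal sketch rather than the project's code: `expand_env`, `resolve_model_path`, and the `DEMO_DATA_ROOT` variable are hypothetical names introduced here for illustration.

```python
import os
from typing import Any, Dict, Tuple


def expand_env(obj: Any) -> Any:
    """Recursively expand $VAR and ${VAR} inside dicts, lists, and strings."""
    if isinstance(obj, dict):
        return {k: expand_env(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [expand_env(v) for v in obj]
    if isinstance(obj, str):
        return os.path.expandvars(obj)
    return obj


def resolve_model_path(config: Dict[str, Any]) -> Tuple[str, bool]:
    """Return (model_path, use_gt_only), following the fallback order in main()."""
    model_path = config.get("model_path")
    if not model_path and config.get("models"):
        models = config["models"]
        # Fall back to the first entry of `models` when model_path is unset
        model_path = models[0] if isinstance(models, list) else models
    model_path = model_path or ""
    if isinstance(model_path, str):
        model_path = model_path.strip()
    # An empty path means no LLM is loaded and the ground-truth scorer is used
    return model_path, not model_path


os.environ["DEMO_DATA_ROOT"] = "/tmp/ppgbp"  # hypothetical variable for the demo
cfg = expand_env({
    "data_path": "$DEMO_DATA_ROOT/raw",
    "models": ["  my-model  "],
    "model_path": None,
})
model_path, use_gt_only = resolve_model_path(cfg)
no_model = resolve_model_path({"model_path": "   "})
```

Note that a whitespace-only `model_path` is stripped to an empty string and therefore also selects GT-only mode, which matches the behaviour of the original `main`.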

@@ -1,467 +0,0 @@
"""
Self-Consistency Experiment Runner for Sleep Stage Classification
This module implements Self-Consistency methodology for sleep stage classification:
- Sample N times with the same prompt
- Use majority voting for final answer
- Support confidence-based tie-breaking
Usage:
python -m sc.run_sc sc/config/sleepedf_sc.yaml
"""
import os
import sys
import asyncio
import yaml
import json
import numpy as np
from datetime import datetime
from typing import List, Dict, Any
import time
from glob import glob
from fire import Fire
# Add project root to path for relative imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.core.data_loader import DataLoader
from sc.core.scagent import SCAgent
from sc.core.model import load_models
from sc.core.example_queue import Queue
from sc.core.agent_pool import AgentPool
from sc import debug_log
async def run_single_task(
seed: int,
dataloader: DataLoader,
model_pool,
config: Dict[str, Any],
user_id: str = None,
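) -> List[Dict[str, Any]]:
"""
Run the queue-based classification loop for one user and one seed.
Args:
seed: Seed for NumPy's global random state
dataloader: DataLoader providing the user's test samples and ICL example pool
model_pool: Async model pool for LLM inference
config: Experiment configuration (queue_size, sample_rate, tracking_window, log_path)
user_id: User identifier, used only for logging
Returns:
List of per-sample result dictionaries with answer, confidence, consistency, and accuracy metrics (empty if there are no examples)
"""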
np.random.seed(seed)
example_dataset = dataloader.get_examples()
if len(example_dataset) == 0:
debug_log.warn_no_examples()
return []
# Build class_indices: Dict[str, List[int]] for Queue initialization
class_indices = {}
for idx, example in enumerate(example_dataset):
label = example["label"]
if label not in class_indices:
class_indices[label] = []
class_indices[label].append(idx)
ex_queue = Queue(class_indices, config["queue_size"])
sample_rate = config.get("sample_rate", 1)
results = []
cumulative_correct = 0
window_size = config.get("tracking_window", 10) # Recent window size for accuracy tracking
recent_results = [] # Windowed recent results for accuracy calculation
confidence_history = [] # Confidence history for tracking
processed_count = 0
for i, sample in enumerate(dataloader):
if sample_rate > 1 and i % sample_rate != 0:
continue
ex_queue.set_current_time(processed_count) # Track current sample index for survival time
user_info = f"User: {user_id}" if user_id else "Unknown User"
debug_log.log_queue_state_before(processed_count, user_info, ex_queue, config)
agent_pool = AgentPool(log_path=config["log_path"])
for queue_idx, ex_idcs in enumerate(ex_queue):
examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]
task_info = dataloader.get_task_info()
sensor_info = dataloader.get_sensor_info()
classes_info = dataloader.get_classes_info()
agent = SCAgent(
name="EEG sensing",
index=queue_idx,
model_pool=model_pool,
task_info=task_info,
classes_info=classes_info,
sensor_info=sensor_info,
sample=sample,
examples=examples,
log_path=config["log_path"],
)
agent_pool.add_agent(agent)
# Check if any agents were added
if len(agent_pool.agents) == 0:
debug_log.warn_no_agents(processed_count)
continue
interpretation_result = await agent_pool.run_parallel_interpretation()
# Handle case where interpretation failed
if interpretation_result is None:
debug_log.warn_interpretation_failed(processed_count)
continue
answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
ground_truth = sample["label"]
is_correct = answer == ground_truth
cumulative_correct += 1 if is_correct else 0
cumulative_accuracy = cumulative_correct / (processed_count + 1)
recent_results.append(1 if is_correct else 0)
if len(recent_results) > window_size:
recent_results.pop(0)
window_accuracy = sum(recent_results) / len(recent_results)
# Confidence history
confidence_history.append(avg_confidence)
avg_confidence_so_far = sum(confidence_history) / len(confidence_history)
debug_log.log_tracking(
processed_count,
user_info,
answer,
ground_truth,
is_correct,
avg_confidence,
consistency,
cumulative_accuracy,
cumulative_correct,
window_accuracy,
recent_results,
avg_confidence_so_far,
config,
)
# Update queue based on Confidence (Priority Queue)
if responses:
# Create confidence map {queue_idx: confidence}
confidence_map = {idx: r.get("CONFIDENCE", 0) for idx, r in responses.items()}
ex_queue.update_by_confidence(confidence_map)
ex_queue.increment_usage(list(responses.keys()))
debug_log.log_queue_state_after(processed_count, ex_queue, config)
elif queue_idcs:
debug_log.warn_no_responses(processed_count)
result = {
"sample_idx": processed_count,
"answer": answer,
"ground_truth": ground_truth,
"is_correct": is_correct,
"confidence": avg_confidence,
"consistency": consistency,
"cumulative_accuracy": cumulative_accuracy,
"window_accuracy": window_accuracy,
"avg_confidence_so_far": avg_confidence_so_far,
}
results.append(result)
processed_count += 1
# Final Queue survival statistics
survival_summary = ex_queue.get_survival_summary()
debug_log.log_final_queue_stats(user_info, survival_summary, config)
# Add Queue statistics to results (in the last result)
if results:
results[-1]["queue_survival_stats"] = survival_summary
results[-1]["queue_survival_details"] = ex_queue.get_survival_stats()
return results
async def run_parallel(
config: Dict[str, Any],
model_pool,
) -> List[Dict[str, Any]]:
"""
Run one classification task per (user, seed) pair concurrently.
Args:
config: Experiment configuration (data_path, num_seeds, example_pool, continuous)
model_pool: Async model pool for LLM inference
Returns:
List with one entry per task: that task's list of result dictionaries, or the exception it raised
"""
data_path = config["data_path"]
user_paths = glob(os.path.join(data_path, "*"))
# Filter to only include directories (exclude files like info.json)
users = [os.path.basename(p) for p in user_paths if os.path.isdir(p) and os.path.basename(p) != "info.json"]
if not users:
debug_log.warn_no_user_dirs(data_path)
return []
debug_log.log_found_users(users, config)
seeds = range(config.get("num_seeds", 1))
tasks = []
for user in users[:1]: # <- User Selection for testing
# for user in users:
for seed in seeds:
np.random.seed(seed)
dataloader = DataLoader(
data_path=data_path,
user_id=user,
example_pool=config.get("example_pool", "out"),
continuous=config.get("continuous", True),
)
# Check if dataloader was properly initialized
if not hasattr(dataloader, 'test_dataset') or len(dataloader) == 0:
debug_log.warn_skip_user_no_test_data(user)
continue
if len(dataloader.get_examples()) == 0:
debug_log.warn_skip_user_no_example_data(user)
continue
task = asyncio.create_task(
run_single_task(
seed,
dataloader,
model_pool,
config,
user_id=user,
)
)
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
return results
# =============================================================================
# Statistics and Results
# =============================================================================
def compute_statistics(results: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Compute experiment statistics from results.
Args:
results: List of result dictionaries
Returns:
Statistics dictionary containing accuracy, confidence, consistency metrics
"""
if not results:
return {}
# Overall accuracy
correct = sum(1 for r in results if r.get("is_correct", False))
total = len(results)
accuracy = correct / total if total > 0 else 0
# Average confidence
confidences = [r.get("confidence", 0) for r in results]
avg_confidence = np.mean(confidences) if confidences else 0
# Average consistency
consistencies = [r.get("consistency", 0) for r in results]
avg_consistency = np.mean(consistencies) if consistencies else 0
# Per-class accuracy
class_correct = {}
class_total = {}
for r in results:
gt = r.get("ground_truth", "UNKNOWN")
class_total[gt] = class_total.get(gt, 0) + 1
if r.get("is_correct", False):
class_correct[gt] = class_correct.get(gt, 0) + 1
class_accuracy = {
cls: class_correct.get(cls, 0) / class_total[cls]
for cls in class_total
}
# High consistency accuracy analysis
high_consistency_results = [r for r in results if r.get("consistency", 0) >= 0.8]
high_consistency_accuracy = (
sum(1 for r in high_consistency_results if r.get("is_correct", False))
/ len(high_consistency_results)
) if high_consistency_results else 0
# ================================================================
# TIME DEBUGGING
mid_point = len(results) // 2
if mid_point > 0:
first_half = results[:mid_point]
second_half = results[mid_point:]
first_half_acc = sum(1 for r in first_half if r.get("is_correct", False)) / len(first_half)
second_half_acc = sum(1 for r in second_half if r.get("is_correct", False)) / len(second_half)
improvement = second_half_acc - first_half_acc
else:
first_half_acc = 0
second_half_acc = 0
improvement = 0
quartile_accs = []
quartile_size = len(results) // 4 if len(results) >= 4 else len(results)
if quartile_size > 0:
for q in range(4):
start = q * quartile_size
end = start + quartile_size if q < 3 else len(results)
quartile = results[start:end]
if quartile:
q_acc = sum(1 for r in quartile if r.get("is_correct", False)) / len(quartile)
quartile_accs.append(q_acc)
if len(confidences) > 1:
first_half_conf = np.mean(confidences[:mid_point]) if mid_point > 0 else 0
second_half_conf = np.mean(confidences[mid_point:]) if mid_point > 0 else 0
conf_improvement = second_half_conf - first_half_conf
else:
first_half_conf = avg_confidence
second_half_conf = avg_confidence
conf_improvement = 0
queue_survival_stats = {}
for r in results:
if "queue_survival_stats" in r:
queue_survival_stats = r["queue_survival_stats"]
break
# ================================================================
return {
"total_samples": total,
"correct": correct,
"accuracy": accuracy,
"avg_confidence": float(avg_confidence),
"avg_consistency": float(avg_consistency),
"class_accuracy": class_accuracy,
"high_consistency_accuracy": high_consistency_accuracy,
"high_consistency_samples": len(high_consistency_results),
"temporal_analysis": {
"first_half_accuracy": first_half_acc,
"second_half_accuracy": second_half_acc,
"accuracy_improvement": improvement,
"quartile_accuracies": quartile_accs,
"first_half_confidence": float(first_half_conf),
"second_half_confidence": float(second_half_conf),
"confidence_improvement": float(conf_improvement),
},
"queue_stats": queue_survival_stats,
}
def save_results(
results: List[Dict[str, Any]],
stats: Dict[str, Any],
config: Dict[str, Any]
) -> None:
"""
Save experiment results and statistics to files.
Args:
results: List of result dictionaries
stats: Statistics dictionary
config: Experiment configuration
"""
log_path = config["log_path"]
os.makedirs(log_path, exist_ok=True)
# Save statistics
stats_path = os.path.join(log_path, "statistics.json")
with open(stats_path, "w", encoding="utf-8") as f:
json.dump(stats, f, indent=2, ensure_ascii=False)
debug_log.log_save_statistics(stats_path, config)
# Save all results
results_to_save = []
for r in results:
r_copy = {k: v for k, v in r.items() if k != "all_responses"}
results_to_save.append(r_copy)
results_path = os.path.join(log_path, "all_results.json")
with open(results_path, "w", encoding="utf-8") as f:
json.dump(results_to_save, f, indent=2, ensure_ascii=False)
debug_log.log_save_results(results_path, config)
# Save configuration for reproducibility
config_path = os.path.join(log_path, "config.yaml")
with open(config_path, "w", encoding="utf-8") as f:
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
debug_log.log_save_config(config_path, config)
# =============================================================================
# CLI Commands
# =============================================================================
def main(config_path: str) -> None:
"""
Run experiment.
Args:
config_path: Path to YAML configuration file
Example:
python -m sc.run_sc sc/config/sleepedf_sc.yaml
"""
debug_log.log_main_loading_config(config_path)
with open(config_path, "r", encoding="utf-8") as f:
config = yaml.load(f, Loader=yaml.SafeLoader)
# Add timestamp to log path for unique experiment runs
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
if "log_path" in config:
config["log_path"] = f"{config['log_path']}_{timestamp}"
# Print experiment configuration
debug_log.log_main_config(config, config.get("debug", True))
model_pool = load_models(
config["models"],
temperature=config.get("temperature", 0.0),
max_new_tokens=config.get("max_new_tokens", 1024),
)
# Run experiments
debug_log.log_main_start(config)
all_results = asyncio.run(run_parallel(config, model_pool))
# Flatten results: run_parallel returns list of lists (one per user/seed)
# Each element is either a list of results or an exception
flattened_results = []
for result in all_results:
if isinstance(result, Exception):
debug_log.error_task_failed(result)
continue
if isinstance(result, list):
flattened_results.extend(result)
else:
flattened_results.append(result)
debug_log.log_total_results(len(flattened_results), config)
# Compute and display statistics
stats = compute_statistics(flattened_results)
debug_log.log_experiment_results(stats, config)
# Time-based analysis output
temporal = stats.get("temporal_analysis", {})
debug_log.log_temporal_analysis(temporal, config)
queue_stats = stats.get("queue_stats", {})
debug_log.log_queue_stats(queue_stats, config)
# Save results
save_results(flattened_results, stats, config)
debug_log.log_results_saved(config["log_path"], config)
if __name__ == "__main__":
Fire(main)

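`run_parallel` above gathers its tasks with `return_exceptions=True`, so `main` receives a mix of result lists and exception objects and has to flatten them. A minimal sketch of that pattern, assuming a stand-in coroutine (`fake_task`, `run_all`, and `flatten` are hypothetical names; no model or data loader is involved):

```python
import asyncio
from typing import Any, Dict, List


async def fake_task(user: str, fail: bool = False) -> List[Dict[str, Any]]:
    """Stand-in for run_single_task: returns two results or raises."""
    await asyncio.sleep(0)
    if fail:
        raise RuntimeError(f"task for {user} failed")
    return [
        {"user": user, "is_correct": True},
        {"user": user, "is_correct": False},
    ]


async def run_all() -> List[Any]:
    tasks = [
        asyncio.create_task(fake_task("u01")),
        asyncio.create_task(fake_task("u02", fail=True)),
        asyncio.create_task(fake_task("u03")),
    ]
    # With return_exceptions=True, a failed task yields its exception
    # as a value instead of cancelling the whole gather.
    return await asyncio.gather(*tasks, return_exceptions=True)


def flatten(raw: List[Any]) -> List[Dict[str, Any]]:
    """Drop failed tasks and merge the per-task result lists."""
    flat: List[Dict[str, Any]] = []
    for item in raw:
        if isinstance(item, Exception):
            continue  # the original logs the failure here
        if isinstance(item, list):
            flat.extend(item)
        else:
            flat.append(item)
    return flat


raw_results = asyncio.run(run_all())
flat_results = flatten(raw_results)
```

One failing user therefore costs only that user's results; the statistics are computed over whatever the remaining tasks returned.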

@@ -1,557 +0,0 @@
"""
Queue Random Baseline Experiment Runner
This module implements the QUEUE RANDOM BASELINE policy:
- Queue structure is maintained (size=5)
- But every step, ALL 5 queue elements are replaced with fresh random samples
- This tests whether improvements come from queue structure itself
or from the cumulative learning effect
Purpose:
- Ablation study to distinguish:
1. Benefits from using 5 ICL example sets (queue structure)
2. Benefits from retaining good examples over time (cumulative learning)
Key Difference from Other Policies:
- Confidence/Consistency: Queue updated by evicting lowest scoring examples
- Pure Random (no queue): Samples fresh examples each time, no structure
- Queue Random (this): Queue structure exists, but fully refreshed each step
Usage:
python -m sc.run_sc_queue_random --user_id=5 --shuffle_seed=42
python -m sc.run_sc_queue_random --user_id=10 --shuffle_seed=123
"""
import os
import sys
import asyncio
import yaml
import json
import random
import numpy as np
from datetime import datetime
from typing import List, Dict, Any, Optional
from sklearn.metrics import f1_score, precision_score, recall_score
from fire import Fire
# Add project root to path for relative imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
from sc.core.scagent import SCAgent
from sc.core.model import load_models
from sc.core.agent_pool import AgentPool
from sc import debug_log
log = debug_log.log
class QueueRandomSampler:
    """
    Queue-based random sampler: keeps the queue structure but refreshes every element each step.

    This sampler:
    - Maintains a queue of size N (same as the Confidence/Consistency policies)
    - Replaces ALL N elements with fresh random samples at each step
    - Keeps no memory between steps
    - Serves as a baseline for testing whether the queue structure itself helps
    """

    def __init__(self, example_dataset, queue_size: int = 5, seed: Optional[int] = None):
        """
        Initialize the Queue Random sampler.

        Args:
            example_dataset: Dataset containing all available examples
            queue_size: Size of the queue (number of ICL example sets)
            seed: Random seed for reproducibility
        """
        self.example_dataset = example_dataset
        self.queue_size = queue_size
        self.base_seed = seed
        # Build class indices for balanced sampling
        self.class_indices = {}
        for idx, example in enumerate(example_dataset):
            label = example["label"]
            if label not in self.class_indices:
                self.class_indices[label] = []
            self.class_indices[label].append(idx)
        self.classes = list(self.class_indices.keys())
        self._step_counter = 0
        # Initialize queue with random samples
        self._queue = self._sample_all_new()
        # Tracking for statistics (mimics Queue class interface)
        self._total_refreshed = 0
        debug_log.log_queue_random_sampler_init(
            len(example_dataset),
            self.classes,
            queue_size,
        )

    def _sample_one_set(self) -> List[int]:
        """Sample one ICL example set (one example per class)."""
        example_set = []
        for cls in self.classes:
            if self.class_indices[cls]:
                idx = random.choice(self.class_indices[cls])
                example_set.append(idx)
        return example_set

    def _sample_all_new(self) -> List[List[int]]:
        """Sample completely new queue contents."""
        return [self._sample_one_set() for _ in range(self.queue_size)]

    def refresh_all(self) -> None:
        """Replace ALL queue elements with fresh random samples."""
        # Reseed the global RNG per step: reproducible, but different at every step
        if self.base_seed is not None:
            random.seed(self.base_seed + self._step_counter * 1000)
        self._queue = self._sample_all_new()
        self._total_refreshed += self.queue_size
        self._step_counter += 1

    def __iter__(self):
        """Iterate over queue contents (mimics Queue interface)."""
        return iter(self._queue)

    def __len__(self):
        """Return queue size."""
        return len(self._queue)

    def get_queue_state(self) -> List[List[int]]:
        """Get a copy of the current queue state."""
        # Copy the inner lists too, so callers cannot mutate the live queue
        return [example_set.copy() for example_set in self._queue]

    def get_statistics(self) -> Dict[str, Any]:
        """Get sampling statistics."""
        return {
            "total_steps": self._step_counter,
            "total_refreshed": self._total_refreshed,
            "queue_size": self.queue_size,
            "avg_refresh_per_step": self.queue_size,  # Always full refresh
            "policy": "queue_random",
        }
async def run_queue_random_experiment(
dataloader: ShuffledDataLoader,
model_pool,
config: Dict[str, Any],
user_id: int,
shuffle_seed: int,
) -> List[Dict[str, Any]]:
"""
Run Queue Random baseline experiment.
Queue Policy:
- Every step, ALL queue elements are replaced with fresh random samples
- No cumulative learning or retention of good examples
- Tests whether queue structure alone provides benefits
Args:
dataloader: ShuffledDataLoader with pre-shuffled test data
model_pool: Async model pool for LLM inference
config: Experiment configuration
user_id: User ID being processed
shuffle_seed: Shuffle seed for reproducibility
Returns:
List of result dictionaries
"""
# Set seeds for reproducibility
random.seed(shuffle_seed)
np.random.seed(shuffle_seed)
example_dataset = dataloader.get_examples()
if len(example_dataset) == 0:
log(f"[ERROR] No examples found for user {user_id}")
return []
# Initialize Queue Random Sampler (instead of regular Queue)
queue_size = config.get("queue_size", 5)
queue_sampler = QueueRandomSampler(
example_dataset=example_dataset,
queue_size=queue_size,
seed=shuffle_seed,
)
# Tracking variables
results = []
cumulative_correct = 0
window_size = config.get("tracking_window", 20)
recent_results = []
confidence_history = []
consistency_history = []
all_predictions = []
all_ground_truths = []
debug_log.log_policy_experiment_header(
"QUEUE RANDOM BASELINE",
user_id,
shuffle_seed,
len(dataloader),
queue_size,
policy_note=f"Policy: ALL {queue_size} queue elements refreshed EVERY step",
)
for processed_count, sample in enumerate(dataloader):
# CRITICAL: Refresh ALL queue elements before each inference
queue_sampler.refresh_all()
# Log queue state (all new random samples)
debug_log.log_queue_random_queue_state(processed_count, user_id, queue_sampler)
# Create agent pool
agent_pool = AgentPool(log_path=config["log_path"])
try:
for queue_idx, ex_idcs in enumerate(queue_sampler):
examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]
agent = SCAgent(
name="EEG sensing",
index=queue_idx,
model_pool=model_pool,
task_info=dataloader.get_task_info(),
classes_info=dataloader.get_classes_info(),
sensor_info=dataloader.get_sensor_info(),
sample=sample,
examples=examples,
log_path=config["log_path"],
)
agent_pool.add_agent(agent)
except Exception as e:
log(f"[ERROR] Failed to create agents: {e}")
import traceback
traceback.print_exc()
continue
if len(agent_pool.agents) == 0:
log(f"[WARN] No agents created for sample {processed_count}")
continue
# Run parallel interpretation
try:
interpretation_result = await agent_pool.run_parallel_interpretation()
except Exception as e:
log(f"[ERROR] Interpretation failed: {e}")
import traceback
traceback.print_exc()
continue
if interpretation_result is None:
log(f"[WARN] Interpretation failed for sample {processed_count}")
continue
answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
# Evaluate result
ground_truth = sample["label"]
is_correct = answer == ground_truth
cumulative_correct += 1 if is_correct else 0
cumulative_accuracy = cumulative_correct / (processed_count + 1)
recent_results.append(1 if is_correct else 0)
if len(recent_results) > window_size:
recent_results.pop(0)
window_accuracy = sum(recent_results) / len(recent_results)
confidence_history.append(avg_confidence)
consistency_history.append(consistency)
all_predictions.append(answer)
all_ground_truths.append(ground_truth)
# Performance logging
debug_log.log_queue_random_result(
processed_count,
answer,
ground_truth,
is_correct,
avg_confidence,
consistency,
cumulative_accuracy,
cumulative_correct,
window_accuracy,
recent_results,
)
# NO QUEUE UPDATE based on scores - just fresh random next time
# (This is the key difference from Confidence/Consistency policies)
# Store result
result = {
"sample_idx": processed_count,
"original_idx": sample.get("original_idx", processed_count),
"shuffled_idx": sample.get("shuffled_idx", processed_count),
"answer": answer,
"ground_truth": ground_truth,
"is_correct": is_correct,
"confidence": avg_confidence,
"consistency": consistency,
"cumulative_accuracy": cumulative_accuracy,
"window_accuracy": window_accuracy,
"experiment_type": "queue_random",
"user_id": user_id,
"shuffle_seed": shuffle_seed,
}
results.append(result)
# Final statistics
sampler_stats = queue_sampler.get_statistics()
debug_log.log_queue_random_stats(user_id, shuffle_seed, sampler_stats)
if results:
results[-1]["queue_random_stats"] = sampler_stats
return results
def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict[str, Any]:
"""Compute comprehensive experiment statistics."""
if not results:
return {}
# Overall metrics
correct = sum(1 for r in results if r.get("is_correct", False))
total = len(results)
accuracy = correct / total if total > 0 else 0
# Confidence and consistency averages
confidences = [r.get("confidence", 0) for r in results]
consistencies = [r.get("consistency", 0) for r in results]
avg_confidence = np.mean(confidences) if confidences else 0
avg_consistency = np.mean(consistencies) if consistencies else 0
# Per-stage accuracy
stage_correct = {}
stage_total = {}
for r in results:
gt = r.get("ground_truth", "UNKNOWN")
stage_total[gt] = stage_total.get(gt, 0) + 1
if r.get("is_correct", False):
stage_correct[gt] = stage_correct.get(gt, 0) + 1
stage_accuracy = {}
for stage in stages:
if stage in stage_total:
stage_accuracy[stage] = stage_correct.get(stage, 0) / stage_total[stage]
else:
stage_accuracy[stage] = 0.0
# F1 Score and Macro metrics
predictions = [r.get("answer", "") for r in results]
ground_truths = [r.get("ground_truth", "") for r in results]
try:
macro_f1 = f1_score(ground_truths, predictions, average='macro', zero_division=0)
macro_precision = precision_score(ground_truths, predictions, average='macro', zero_division=0)
macro_recall = recall_score(ground_truths, predictions, average='macro', zero_division=0)
except Exception:
macro_f1 = macro_precision = macro_recall = 0.0
# Temporal analysis
mid_point = len(results) // 2
if mid_point > 0:
first_half = results[:mid_point]
second_half = results[mid_point:]
first_half_acc = sum(1 for r in first_half if r.get("is_correct", False)) / len(first_half)
second_half_acc = sum(1 for r in second_half if r.get("is_correct", False)) / len(second_half)
improvement = second_half_acc - first_half_acc
else:
first_half_acc = second_half_acc = accuracy
improvement = 0
# Learning curve (every 10 samples)
learning_curve = []
for i in range(0, len(results), 10):
chunk = results[:i+10]
chunk_acc = sum(1 for r in chunk if r.get("is_correct", False)) / len(chunk)
learning_curve.append({
"sample_idx": min(i+10, len(results)),
"cumulative_accuracy": chunk_acc,
})
# Convergence speed
final_accuracy = accuracy
convergence_threshold = 0.9 * final_accuracy if final_accuracy > 0 else 0.5
convergence_idx = None
running_correct = 0
for i, r in enumerate(results):
running_correct += 1 if r.get("is_correct", False) else 0
running_acc = running_correct / (i + 1)
if running_acc >= convergence_threshold:
convergence_idx = i
break
# Queue random statistics
queue_stats = {}
for r in results:
if "queue_random_stats" in r:
queue_stats = r["queue_random_stats"]
break
return {
"experiment_type": "queue_random",
"user_id": results[0].get("user_id") if results else None,
"shuffle_seed": results[0].get("shuffle_seed") if results else None,
"total_samples": total,
"correct": correct,
"accuracy": accuracy,
"avg_confidence": float(avg_confidence),
"avg_consistency": float(avg_consistency),
"stage_accuracy": stage_accuracy,
"stage_sample_counts": stage_total,
"macro_f1": float(macro_f1),
"macro_precision": float(macro_precision),
"macro_recall": float(macro_recall),
"temporal_analysis": {
"first_half_accuracy": first_half_acc,
"second_half_accuracy": second_half_acc,
"improvement": improvement,
},
"learning_curve": learning_curve,
"convergence_idx": convergence_idx,
"queue_stats": queue_stats,
}


def save_results(
    results: List[Dict[str, Any]],
    stats: Dict[str, Any],
    config: Dict[str, Any],
    user_id: int,
    shuffle_seed: int,
) -> None:
    """Save experiment results and statistics."""
    log_path = config["log_path"]

    # Create user/seed specific directory
    output_dir = os.path.join(log_path, f"user{user_id:02d}_seed{shuffle_seed}")
    os.makedirs(output_dir, exist_ok=True)

    # Save statistics
    stats_path = os.path.join(output_dir, "statistics.json")
    with open(stats_path, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2, ensure_ascii=False)
    log(f"[SAVE] Statistics: {stats_path}")

    # Save results
    results_path = os.path.join(output_dir, "results.json")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    log(f"[SAVE] Results: {results_path}")

    # Save config
    config_path = os.path.join(output_dir, "config.yaml")
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
    log(f"[SAVE] Config: {config_path}")


def main(
    config_path: str = "sc/config/experiment_queue_random.yaml",
    user_id: int = 5,
    shuffle_seed: int = 42,
) -> None:
    """
    Run Queue Random baseline experiment.

    This baseline maintains queue structure but refreshes ALL elements each step.
    Tests whether performance gains come from:
    1. Queue structure itself (5 ICL examples)
    2. Cumulative learning (retaining good examples over time)

    Args:
        config_path: Path to YAML configuration file
        user_id: User ID to process (5, 10, or 15)
        shuffle_seed: Shuffle seed for data order (42 or 123)

    Example:
        python -m sc.run_sc_queue_random --user_id=5 --shuffle_seed=42
        python -m sc.run_sc_queue_random --user_id=15 --shuffle_seed=123
    """
    log(f"[MAIN] Loading config: {config_path}")
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)

    # Override with CLI arguments
    config["user_id"] = user_id
    config["shuffle_seed"] = shuffle_seed

    # Create unique log path
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    config["log_path"] = f"{config['log_path']}/user{user_id:02d}_seed{shuffle_seed}_{timestamp}"
    os.makedirs(config["log_path"], exist_ok=True)

    # Print experiment info
    debug_log.log_policy_main_header(
        "QUEUE RANDOM BASELINE",
        user_id,
        shuffle_seed,
        config.get("queue_size", 5),
        len(config.get("models", [])),
        config["log_path"],
        policy_note=" Policy: Queue Random (full refresh every step)",
    )

    # Load shuffled data
    debug_log.log_policy_loading_data(user_id, shuffle_seed)
    dataloader = ShuffledDataLoader(
        data_path=config["data_path"],
        user_id=user_id,
        seed=shuffle_seed,
        example_pool=config.get("example_pool", "out"),
    )

    # Load models
    debug_log.log_policy_loading_models()
    model_pool = load_models(
        config["models"],
        temperature=config.get("temperature", 0.0),
        max_new_tokens=config.get("max_new_tokens", 1024),
    )

    # Run experiment
    debug_log.log_policy_start(label="Queue Random experiment")
    results = asyncio.run(run_queue_random_experiment(
        dataloader=dataloader,
        model_pool=model_pool,
        config=config,
        user_id=user_id,
        shuffle_seed=shuffle_seed,
    ))

    # Compute statistics
    stages = config.get("stages", ["W", "N1", "N2", "N3", "REM"])
    stats = compute_statistics(results, stages)

    # Print final summary
    temporal = stats.get("temporal_analysis", {})
    debug_log.log_policy_complete_summary(
        "QUEUE RANDOM BASELINE",
        user_id,
        shuffle_seed,
        stats,
        stats.get("stage_accuracy", {}),
        stats.get("stage_sample_counts", {}),
        temporal,
        expected_note="\n [EXPECTED] Improvement should be ~0 (no cumulative learning)",
    )

    # Save results
    save_results(results, stats, config, user_id, shuffle_seed)
    log(f"\n[MAIN] Results saved to: {config['log_path']}")


if __name__ == "__main__":
    Fire(main)
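
The convergence index reported in the statistics above can be isolated as a small helper. This is a minimal sketch under stated assumptions, not code from the repository: the name `convergence_index` and the `ratio` parameter are hypothetical, and the toy results are hand-written. It keeps the same fallback threshold of 0.5 when the final accuracy is zero.

```python
from typing import Any, Dict, List, Optional


def convergence_index(results: List[Dict[str, Any]], ratio: float = 0.9) -> Optional[int]:
    """Return the first index where running accuracy reaches ratio * final accuracy."""
    if not results:
        return None
    flags = [1 if r.get("is_correct", False) else 0 for r in results]
    final_accuracy = sum(flags) / len(flags)
    # Same fallback as the statistics code: use 0.5 when final accuracy is zero.
    threshold = ratio * final_accuracy if final_accuracy > 0 else 0.5
    running_correct = 0
    for i, flag in enumerate(flags):
        running_correct += flag
        if running_correct / (i + 1) >= threshold:
            return i
    return None


# Hand-made toy run: two misses followed by three hits (final accuracy 0.6).
toy_results = [{"is_correct": c} for c in (False, False, True, True, True)]
print(convergence_index(toy_results))  # prints 4
```

One property worth keeping in mind when reading this metric: if the very first sample is correct, the running accuracy is 1.0 and the index is 0, so the value is sensitive to the order of the earliest samples.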


@@ -1,412 +0,0 @@
"""
Consistency-based Queue Policy Experiment Runner
This module implements the CONSISTENCY-based queue update policy:
- Queue is updated based on SC consensus/agreement scores
- Higher consistency (more agents agree) examples are retained
- Lower consistency examples are evicted
Consistency Score = (Number of agents agreeing with majority) / (Total agents)
Usage:
python -m sc.run_consistency --config_path=path/to/config.yaml
(user_id, seed, and shuffle are read from the YAML config)
"""
import os
import sys
import asyncio
import yaml
import json
import random
import numpy as np
from datetime import datetime
from typing import List, Dict, Any, Optional
from sklearn.metrics import f1_score, precision_score, recall_score
from fire import Fire
# Add project root to path for relative imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
from sc.core.data_loader import DataLoader
from sc.core.scagent import SCAgent
from sc.core.model import load_models
from sc.core.example_queue import Queue
from sc.core.agent_pool import AgentPool
from sc.logger import Logger
from sc import debug_log
log = debug_log.log


async def run_consistency_experiment(
    dataloader: ShuffledDataLoader,
    model_pool,
    config: Dict[str, Any],
    user_id: int,
    shuffle_seed: int,
) -> List[Dict[str, Any]]:
    """
    Run consistency-based queue policy experiment.

    Queue Update Policy:
    - After each inference, calculate per-agent consistency
    - Consistency = fraction of agents (including this one) that gave the same answer
    - Rank queue elements by consistency score
    - Evict the lowest consistency element
    - Add a new random ICL example set

    Args:
        dataloader: ShuffledDataLoader with pre-shuffled test data
        model_pool: Async model pool for LLM inference
        config: Experiment configuration
        user_id: User ID being processed
        shuffle_seed: Shuffle seed for reproducibility

    Returns:
        List of result dictionaries
    """
    # Set seeds for reproducibility within experiment
    random.seed(shuffle_seed)
    np.random.seed(shuffle_seed)

    example_dataset = dataloader.get_examples()
    if len(example_dataset) == 0:
        log(f"[ERROR] No examples found for user {user_id}")
        return []

    # Build class_indices for Queue initialization
    class_indices = {}
    for idx, example in enumerate(example_dataset):
        label = example["label"]
        if label not in class_indices:
            class_indices[label] = []
        class_indices[label].append(idx)

    # Initialize Queue
    queue_size = config.get("queue_size", 5)
    ex_queue = Queue(class_indices, queue_size)

    # Tracking variables
    results = []
    cumulative_correct = 0
    window_size = config.get("tracking_window", 20)
    recent_results = []
    confidence_history = []
    consistency_history = []
    all_predictions = []
    all_ground_truths = []

    debug_log.log_policy_experiment_header(
        "CONSISTENCY-BASED",
        user_id,
        shuffle_seed,
        len(dataloader),
        queue_size,
    )

    for processed_count, sample in enumerate(dataloader):
        ex_queue.set_current_time(processed_count)

        # Log queue state before processing
        debug_log.log_policy_queue_state(
            "CONSISTENCY", processed_count, user_id, ex_queue
        )

        # Create agent pool: one agent per queue element (ICL example set)
        agent_pool = AgentPool(log_path=config["log_path"])
        try:
            for queue_idx, ex_idcs in enumerate(ex_queue):
                examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]
                agent = SCAgent(
                    name="EEG sensing",
                    index=queue_idx,
                    model_pool=model_pool,
                    task_info=dataloader.get_task_info(),
                    classes_info=dataloader.get_classes_info(),
                    sensor_info=dataloader.get_sensor_info(),
                    sample=sample,
                    examples=examples,
                    log_path=config["log_path"],
                )
                agent_pool.add_agent(agent)
        except Exception as e:
            log(f"[ERROR] Failed to create agents: {e}")
            import traceback
            traceback.print_exc()
            continue

        if len(agent_pool.agents) == 0:
            log(f"[WARN] No agents created for sample {processed_count}")
            continue

        # Run parallel interpretation
        try:
            interpretation_result = await agent_pool.run_parallel_interpretation()
        except Exception as e:
            log(f"[ERROR] Interpretation failed: {e}")
            import traceback
            traceback.print_exc()
            continue

        if interpretation_result is None:
            log(f"[WARN] Interpretation failed for sample {processed_count}")
            continue

        answer, queue_idcs, avg_confidence, consistency, responses = (
            interpretation_result
        )

        # Evaluate result
        ground_truth = sample["label"]
        is_correct = answer == ground_truth
        cumulative_correct += 1 if is_correct else 0
        cumulative_accuracy = cumulative_correct / (processed_count + 1)

        recent_results.append(1 if is_correct else 0)
        if len(recent_results) > window_size:
            recent_results.pop(0)
        window_accuracy = sum(recent_results) / len(recent_results)

        confidence_history.append(avg_confidence)
        consistency_history.append(consistency)
        all_predictions.append(answer)
        all_ground_truths.append(ground_truth)

        # Performance logging
        debug_log.log_policy_result(
            "CONSISTENCY",
            processed_count,
            answer,
            ground_truth,
            is_correct,
            avg_confidence,
            consistency,
            cumulative_accuracy,
            cumulative_correct,
            window_accuracy,
            recent_results,
        )

        # CONSISTENCY-BASED Queue Update
        if responses:
            # Per-agent consistency: fraction of agents that share this agent's answer
            all_answers = [r.get("ANSWER") for r in responses.values()]
            consistency_map = {}
            for idx, response in responses.items():
                agent_answer = response.get("ANSWER")
                # Consistency = ratio of agents (including self) that agree
                agreement_count = all_answers.count(agent_answer)
                agent_consistency = (
                    agreement_count / len(all_answers) if all_answers else 0
                )
                consistency_map[idx] = agent_consistency
            debug_log.log_consistency_map(responses, consistency_map)

            # Update queue by consistency (reuses confidence method with consistency scores)
            ex_queue.update_by_confidence(consistency_map)
            ex_queue.increment_usage(list(responses.keys()))

        debug_log.log_policy_queue_state_after(
            "Consistency", processed_count, ex_queue
        )

        # Store result
        result = {
            "sample_idx": processed_count,
            "original_idx": sample.get("original_idx", processed_count),
            "shuffled_idx": sample.get("shuffled_idx", processed_count),
            "answer": answer,
            "ground_truth": ground_truth,
            "is_correct": is_correct,
            "confidence": avg_confidence,
            "consistency": consistency,
            "cumulative_accuracy": cumulative_accuracy,
            "window_accuracy": window_accuracy,
            "experiment_type": "usc",
            "user_id": user_id,
            "shuffle_seed": shuffle_seed,
        }
        results.append(result)

    # Final statistics
    survival_summary = ex_queue.get_survival_summary()
    debug_log.log_final_policy_survival("usc", user_id, shuffle_seed, survival_summary)
    if results:
        results[-1]["queue_survival_stats"] = survival_summary
        results[-1]["queue_survival_details"] = ex_queue.get_survival_stats()

    return results


def compute_statistics(
    results: List[Dict[str, Any]], stages: List[str]
) -> Dict[str, Any]:
    """Compute comprehensive experiment statistics."""
    if not results:
        return {}

    # Overall metrics
    correct = sum(1 for r in results if r.get("is_correct", False))
    total = len(results)
    accuracy = correct / total if total > 0 else 0

    # Confidence and consistency averages
    confidences = [r.get("confidence", 0) for r in results]
    consistencies = [r.get("consistency", 0) for r in results]
    avg_confidence = np.mean(confidences) if confidences else 0
    avg_consistency = np.mean(consistencies) if consistencies else 0

    # Per-stage accuracy
    stage_correct = {}
    stage_total = {}
    for r in results:
        gt = r.get("ground_truth", "UNKNOWN")
        stage_total[gt] = stage_total.get(gt, 0) + 1
        if r.get("is_correct", False):
            stage_correct[gt] = stage_correct.get(gt, 0) + 1

    stage_accuracy = {}
    for stage in stages:
        if stage in stage_total:
            stage_accuracy[stage] = stage_correct.get(stage, 0) / stage_total[stage]
        else:
            stage_accuracy[stage] = 0.0

    # F1 Score and Macro metrics
    predictions = [r.get("answer", "") for r in results]
    ground_truths = [r.get("ground_truth", "") for r in results]
    try:
        macro_f1 = f1_score(
            ground_truths, predictions, average="macro", zero_division=0
        )
        macro_precision = precision_score(
            ground_truths, predictions, average="macro", zero_division=0
        )
        macro_recall = recall_score(
            ground_truths, predictions, average="macro", zero_division=0
        )
    except Exception:
        macro_f1 = macro_precision = macro_recall = 0.0

    # Temporal analysis
    mid_point = len(results) // 2
    if mid_point > 0:
        first_half = results[:mid_point]
        second_half = results[mid_point:]
        first_half_acc = sum(1 for r in first_half if r.get("is_correct", False)) / len(
            first_half
        )
        second_half_acc = sum(
            1 for r in second_half if r.get("is_correct", False)
        ) / len(second_half)
        improvement = second_half_acc - first_half_acc
    else:
        first_half_acc = second_half_acc = accuracy
        improvement = 0

    # Learning curve (every 10 samples)
    learning_curve = []
    for i in range(0, len(results), 10):
        chunk = results[: i + 10]
        chunk_acc = sum(1 for r in chunk if r.get("is_correct", False)) / len(chunk)
        learning_curve.append(
            {
                "sample_idx": min(i + 10, len(results)),
                "cumulative_accuracy": chunk_acc,
            }
        )

    # Convergence speed (90% of final accuracy)
    final_accuracy = accuracy
    convergence_threshold = 0.9 * final_accuracy
    convergence_idx = None
    running_correct = 0
    for i, r in enumerate(results):
        running_correct += 1 if r.get("is_correct", False) else 0
        running_acc = running_correct / (i + 1)
        if running_acc >= convergence_threshold:
            convergence_idx = i
            break

    # Queue statistics
    queue_stats = {}
    for r in results:
        if "queue_survival_stats" in r:
            queue_stats = r["queue_survival_stats"]
            break

    return {
        "experiment_type": "consistency",
        "user_id": results[0].get("user_id") if results else None,
        "shuffle_seed": results[0].get("shuffle_seed") if results else None,
        "total_samples": total,
        "correct": correct,
        "accuracy": accuracy,
        "avg_confidence": float(avg_confidence),
        "avg_consistency": float(avg_consistency),
        "stage_accuracy": stage_accuracy,
        "stage_sample_counts": stage_total,
        "macro_f1": float(macro_f1),
        "macro_precision": float(macro_precision),
        "macro_recall": float(macro_recall),
        "temporal_analysis": {
            "first_half_accuracy": first_half_acc,
            "second_half_accuracy": second_half_acc,
            "improvement": improvement,
        },
        "learning_curve": learning_curve,
        "convergence_idx": convergence_idx,
        "queue_stats": queue_stats,
    }


def main(config_path: str) -> None:
    """
    Run the consistency-based queue policy experiment.

    Args:
        config_path: Path to YAML configuration file
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)

    # The experiment takes the user ID and seed explicitly; read both from the config.
    user_id = config["user_id"]
    shuffle_seed = config["seed"]

    logger = Logger(config["log_path"])
    logger.log_config(config)

    dataloader = DataLoader(
        config["data_path"],
        user_id,
        shuffle=config["shuffle"],
        seed=shuffle_seed,
    )
    model_pool = load_models(
        config["models"],
        temperature=config["temperature"],
        max_new_tokens=config.get("max_new_tokens", 1024),
    )

    results = asyncio.run(
        run_consistency_experiment(
            dataloader=dataloader,
            model_pool=model_pool,
            config=config,
            user_id=user_id,
            shuffle_seed=shuffle_seed,
        )
    )
    log(f"[MAIN] Processed {len(results)} samples")


if __name__ == "__main__":
    Fire(main)
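
The per-agent consistency score used for the queue update above can be sketched on its own. This is an illustrative helper rather than repository code: the name `per_agent_consistency` is hypothetical, and the input is assumed to be a mapping from agent index to that agent's answer string (or `None` when no answer was parsed).

```python
from collections import Counter
from typing import Dict, Optional


def per_agent_consistency(answers: Dict[int, Optional[str]]) -> Dict[int, float]:
    """Map each agent index to the fraction of agents (itself included) sharing its answer."""
    if not answers:
        return {}
    counts = Counter(answers.values())
    total = len(answers)
    return {idx: counts[ans] / total for idx, ans in answers.items()}


# Hand-made toy answers from five agents.
toy_answers = {0: "N2", 1: "N2", 2: "REM", 3: "N2", 4: "W"}
scores = per_agent_consistency(toy_answers)
print(scores)
```

In this toy case the three agents answering "N2" each score 3/5 and the other two score 1/5, so a lowest-score eviction would pick one of the dissenting agents. Note that agents with a missing answer (`None`) count as agreeing with each other under this definition, which also holds for the `all_answers.count(...)` logic above; filtering them out first may be preferable.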


@@ -1,124 +0,0 @@
"""
Call transformers serve (chat) and our HF API (logprobs) from Python with requests.
Summary:
--------
- transformers serve (transformers chat localhost:8000 --model-name-or-path ...):
- Exposes OpenAI-compatible endpoints: /v1/chat/completions, /v1/responses, /v1/models.
- It does NOT return logits or logprobs; the response chunks have "logprobs": null.
- You can still use it for chat from Python via requests (see chat_with_transformers_serve).
- For answer + logprobs (top-k per token):
- Use our custom API in sc/hf_api.py:
MODEL_DIR=/path/to/model uvicorn sc.hf_api:app --host 0.0.0.0 --port 8000
- Then call POST /generate with requests (see get_logprobs_via_hf_api).
Usage:
------
# Chat only (transformers serve on 8000):
python -c "
from sc.transformers_serve_client_example import chat_with_transformers_serve
print(chat_with_transformers_serve('Hello!'))
"
# Answer + logprobs (sc/hf_api on 8000):
python -c "
from sc.transformers_serve_client_example import get_logprobs_via_hf_api
r = get_logprobs_via_hf_api([{'role':'user','content':'Hello!'}])
print('text:', r['generated_text'])
print('first step top-k:', r['steps'][0]['topk'])
"
"""
from __future__ import annotations
import json
from typing import Any, Dict, List
import requests
# Default base URL for transformers serve (OpenAI-compatible)
TRANSFORMERS_SERVE_URL = "http://localhost:8000/v1"
# Default base URL for our sc/hf_api (generate + logprobs)
HF_API_URL = "http://localhost:8000"


def chat_with_transformers_serve(
    user_message: str,
    *,
    base_url: str = TRANSFORMERS_SERVE_URL,
    model: str = "openai/gpt-oss-20b",
    max_tokens: int = 256,
    stream: bool = False,
) -> str:
    """
    Send a chat message to a server running `transformers serve`.
    Returns the assistant reply text. No logits/logprobs (server does not provide them).
    Only non-streaming responses are parsed; stream=True is rejected.
    """
    if stream:
        # A streamed reply arrives in chunks, which resp.json() below cannot parse.
        raise NotImplementedError("stream=True is not supported by this helper")

    url = f"{base_url.rstrip('/')}/chat/completions"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": user_message}],
        "max_tokens": max_tokens,
        "stream": stream,
    }
    resp = requests.post(url, json=payload, timeout=60)
    resp.raise_for_status()
    data = resp.json()

    # Non-stream: choices[0].message.content
    choices = data.get("choices", [])
    if not choices:
        return ""
    msg = choices[0].get("message", {})
    return msg.get("content", "") or ""


def get_logprobs_via_hf_api(
    messages: List[Dict[str, str]],
    *,
    base_url: str = HF_API_URL,
    max_new_tokens: int = 64,
    top_k: int = 10,
) -> Dict[str, Any]:
    """
    Call our sc/hf_api POST /generate endpoint.
    Returns generated text and per-token top-k logprobs (no raw logits over the wire).
    """
    url = f"{base_url.rstrip('/')}/generate"
    payload = {
        "messages": messages,
        "max_new_tokens": max_new_tokens,
        "top_k": top_k,
        "temperature": 0.0,
        "do_sample": False,
    }
    resp = requests.post(url, json=payload, timeout=120)
    resp.raise_for_status()
    return resp.json()


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python -m sc.transformers_serve_client_example chat|logprobs [message]")
        print(" chat -> call transformers serve /v1/chat/completions (no logprobs)")
        print(" logprobs -> call sc/hf_api /generate (returns top-k logprobs)")
        sys.exit(0)

    cmd = sys.argv[1].lower()
    message = (sys.argv[2] if len(sys.argv) > 2 else "Hello, how are you?").strip()

    if cmd == "chat":
        text = chat_with_transformers_serve(message)
        print("Reply:", text)
    elif cmd == "logprobs":
        out = get_logprobs_via_hf_api([{"role": "user", "content": message}])
        print("Generated:", out.get("generated_text", ""))
        print("Steps (first 3):")
        for s in out.get("steps", [])[:3]:
            top_tokens = [t.get("token") for t in s.get("topk", [])[:5]]
            print(" token:", repr(s.get("token")), "logprob:", s.get("logprob"), "topk:", top_tokens)
    else:
        print("Unknown command. Use 'chat' or 'logprobs'.")
        sys.exit(1)
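
Once a step's top-k logprobs are available, they can be turned into a normalized distribution over the candidate tokens. The following is a sketch under assumptions: the helper name `topk_probabilities` is hypothetical, and each `topk` entry is assumed to carry both a `token` and a `logprob` key (the code above only confirms `token` on those entries). The sample step is hand-written, not real server output.

```python
import math
from typing import Any, Dict, List


def topk_probabilities(step: Dict[str, Any]) -> Dict[str, float]:
    """Renormalize a step's top-k logprobs into probabilities that sum to 1."""
    entries: List[Dict[str, Any]] = step.get("topk", [])
    if not entries:
        return {}
    # Subtract the max logprob before exponentiating for numerical stability.
    max_lp = max(e["logprob"] for e in entries)
    weights = {e["token"]: math.exp(e["logprob"] - max_lp) for e in entries}
    total = sum(weights.values())
    return {tok: w / total for tok, w in weights.items()}


# Hand-written step in the assumed response shape.
toy_step = {
    "token": "N2",
    "logprob": math.log(0.6),
    "topk": [
        {"token": "N2", "logprob": math.log(0.6)},
        {"token": "REM", "logprob": math.log(0.2)},
    ],
}
probs = topk_probabilities(toy_step)
print(probs)
```

Because only the top-k candidates are returned, the result is a distribution restricted to those tokens, not the model's full-vocabulary probabilities; in the toy step the 0.6 and 0.2 masses renormalize to roughly 0.75 and 0.25.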