5 Commits
main ... case

Author SHA1 Message Date
AadharshAadhithya
216bd2ecc3 case init 2026-03-10 20:31:05 -05:00
AadharshAadhithya
df5e6db8b6 ff 2026-02-13 19:59:42 -06:00
Hynugjun Yoon
31f4af7106 implemented new version of llm initialization 2026-02-12 17:42:09 +09:00
Hyungjun Yoon
8a26346b5e updated model to take hf format 2026-02-12 16:52:46 +09:00
Hyungjun Yoon
a7c8e43f89 updated model to huggingface framework 2026-02-12 14:42:25 +09:00
25 changed files with 5477 additions and 540 deletions

198
.claude/CLAUDE.md Normal file

@@ -0,0 +1,198 @@
# Project Memory: tsllm_personalization_icl
## Project Overview
Research project on **personalized time-series classification using LLMs with In-Context Learning (ICL)**. The core idea is to use LLMs (via Ollama or HuggingFace) to classify sensor/physiological time-series data (e.g., EEG sleep stages, PPG blood pressure), where ICL examples are selected via different strategies (random vs. similarity-based) to study personalization.
Current branch: `case` (main branch: `main`)
---
## Repository Structure
```
tsllm_personalization_icl/
├── run.py # Main entry point (ICL experiment runner)
├── config/
│ └── sleepedf.yaml # Config: data path, models, selection criteria, log path
├── core/ # Base pipeline (ICL experiments)
│ ├── model.py # Model/AsyncModelPool (Ollama or LangChain init_chat_model)
│ ├── agent.py # Base Agent class (memory, logging, JSON parsing)
│ ├── sensing_agent.py # SensingAgent: ICL classify + evaluate + reflect
│ ├── data_loader.py # DataLoader: test/example splits + similarity-based selection
│ └── embedding_index.py # EmbeddingIndex: cosine similarity over Chronos-2 embeddings
├── sc/ # Self-Consistency (SC) variant pipeline
│ ├── run_sc.py # SC experiment runner (main entry: `python -m sc.run_sc config.yaml`)
│ ├── run_confidence_based.py
│ ├── run_consistency_based.py
│ ├── run_sc_queue_random.py
│ ├── run_usc.py
│ ├── core/
│ │ ├── model.py # HuggingFace CausalLM wrapper (returns text + logits)
│ │ ├── agent.py # SC base agent
│ │ ├── scagent.py # SCAgent: single-pass interpret() with REASON/CONFIDENCE/ANSWER
│ │ ├── agent_pool.py # AgentPool: parallel interpret + majority voting
│ │ ├── example_queue.py # Queue: priority queue of example sets, updated by confidence
│ │ ├── data_loader.py # DataLoader + InMemoryDataLoader + PPGBPLoader
│ │ ├── judge_agent.py # JudgeAgent
│ │ ├── majority_voting.py # MajorityVoting utilities
│ │ └── model_utils.py
│ ├── analysis/
│ │ └── analyze_sc_results.py
│ ├── preprocess/
│ │ └── shuffle_data.py
│ ├── debug_log.py # Structured debug logging helpers
│ ├── logger.py
│ ├── hf_api.py
│ └── ollama_test.py
├── analysis/
│ ├── ppgbp_loader.py
│ ├── user_similarity/ # Embedding analysis scripts
│ │ ├── chronos2/ # Chronos-2 embeddings for SleepEDF
│ │ ├── chronos2_ppgbp/ # Chronos-2 embeddings for PPGBP
│ │ ├── labram/ # LaBraM embeddings
│ │ ├── sbert/
│ │ ├── sbert_metadata/
│ │ └── sbert_metadata_ppgbp/ # SBERT metadata embeddings for PPGBP
│ ├── analyze_data.ipynb
│ └── analyze_preliminary.ipynb
├── preprocess/
│ ├── dhedfreader.py
│ └── preprocess_SleepEDF.py
├── utils/
│ ├── kill_ollamas.sh
│ └── launch_ollamas.sh
└── requirements.txt
```
---
## Key Concepts
### Datasets
- **SleepEDF**: EEG-based sleep stage classification (W, N1, N2, N3, REM).
- **PPGBP**: PPG-based blood pressure prediction.
- Data format: `<data_path>/<user_id>/1/` (train/examples, HuggingFace dataset on disk) and `<data_path>/<user_id>/2/` (test split). Metadata in `info.json` (keys: `task`, `class`, `feature`).
### ICL Selection Strategies (core/data_loader.py)
- `out_random`: Random examples from OTHER users (cross-user, random)
- `in_random`: Random examples from SAME user (personalized, random)
- `out_similar`: Most similar examples from OTHER users (Chronos-2 embedding similarity)
- `in_similar`: Most similar examples from SAME user (embedding similarity)
Similarity uses cosine similarity over Chronos-2 embeddings, one example per class (balanced).
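
The similarity-based strategies above can be sketched roughly as follows. This is a minimal illustration with NumPy, not the code in `core/data_loader.py` or `core/embedding_index.py`; the function name `select_similar_per_class` and its arguments are hypothetical, and it assumes the embeddings are already available as plain arrays.

```python
import numpy as np

def select_similar_per_class(query_emb, pool_embs, pool_labels, pool_users,
                             query_user, criteria="out_similar"):
    """Return {class label: pool index} with the most similar example per class."""
    query_emb = np.asarray(query_emb, dtype=np.float64)
    pool_embs = np.asarray(pool_embs, dtype=np.float64)
    pool_labels = np.asarray(pool_labels)
    pool_users = np.asarray(pool_users)

    # L2-normalize so that a dot product equals cosine similarity.
    q = query_emb / np.linalg.norm(query_emb)
    p = pool_embs / np.linalg.norm(pool_embs, axis=1, keepdims=True)
    sims = p @ q

    # "in_*" keeps the query user's own examples, "out_*" keeps everyone else's.
    same_user = pool_users == query_user
    allowed = same_user if criteria.startswith("in_") else ~same_user

    selected = {}
    for label in np.unique(pool_labels):
        candidates = np.flatnonzero(allowed & (pool_labels == label))
        if candidates.size == 0:
            continue  # no eligible example for this class
        selected[str(label)] = int(candidates[np.argmax(sims[candidates])])
    return selected
```

Picking one index per class keeps the prompt balanced, at the cost of sometimes including an example that is not very similar to the query.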
### Models
- **core/model.py**: Supports Ollama (local/remote via `ollama:url:<host>/<model>`) and any LangChain-supported model (OpenAI, Together, etc.)
- **sc/core/model.py**: HuggingFace CausalLM loaded via `transformers`, returns text + logits. Used for SC experiments.
### Self-Consistency (sc/) Pipeline
The SC pipeline uses a **priority queue of example sets**:
- `Queue` (sc/core/example_queue.py): holds `capacity` example-sets (one set = one example per class). Updated each step by agent confidence scores — highest-confidence sets are kept, lowest evicted and replaced with a random new set.
- `AgentPool` (sc/core/agent_pool.py): runs one `SCAgent` per queue slot in parallel, aggregates via majority vote, tracks confidence and consistency.
- `SCAgent` (sc/core/scagent.py): single LLM call, returns `{REASON, CONFIDENCE, ANSWER}` JSON.
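
A rough sketch of the two aggregation steps described above, majority voting and the confidence-based queue update. The helper names `majority_vote` and `update_queue` are hypothetical, and evicting exactly one set per step (`n_evict=1`) is an assumption; the actual policy lives in `sc/core/example_queue.py` and `sc/core/agent_pool.py`.

```python
from collections import Counter

def majority_vote(answers):
    """Return the most common answer; ties go to the answer seen first."""
    return Counter(answers).most_common(1)[0][0]

def update_queue(example_sets, confidences, sample_new_set, n_evict=1):
    """Keep the highest-confidence sets and refill the evicted slots."""
    capacity = len(example_sets)
    n_keep = max(0, capacity - n_evict)
    # Rank slots by confidence, highest first (stable for equal scores).
    ranked = sorted(range(capacity), key=lambda i: confidences[i], reverse=True)
    kept = [example_sets[i] for i in ranked[:n_keep]]
    # Replace every evicted slot with a freshly sampled example set.
    return kept + [sample_new_set() for _ in range(capacity - n_keep)]
```

Here `sample_new_set` is any zero-argument callable that draws a new example set (one example per class) from the pool.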
### Base Agent (core/agent.py)
- Three memory tiers: `long_term_memory`, `short_term_memory`, `volatile_memory`
- `invoke()`: async, appends messages to memory, calls model pool
- `safe_parse_json()` / `safe_parse_json_list()`: robust JSON parsing with cleanup
- Token counting via `tiktoken` (gpt-3.5-turbo encoding as proxy)
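
As an illustration of what "robust JSON parsing with cleanup" can look like, here is a minimal sketch. It is not the implementation of `safe_parse_json()`; the name `parse_json_leniently` and the specific cleanup steps (dropping Markdown fence markers, falling back to the outermost brace span) are assumptions.

```python
import json
import re

FENCE = "`" * 3  # Markdown code-fence marker (three backticks)

def parse_json_leniently(text):
    """Best-effort extraction of one JSON value from an LLM reply, else None."""
    # Drop fence markers, optionally followed by the word "json".
    cleaned = re.sub(re.escape(FENCE) + r"(?:json)?", "", text).strip()
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass
    # Fall back to the outermost {...} span in the reply.
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(cleaned[start:end + 1])
        except json.JSONDecodeError:
            return None
    return None
```

Returning `None` instead of raising lets the caller decide whether to retry the model call or skip the sample.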
### SensingAgent (core/sensing_agent.py)
Extends Agent. Methods:
- `solve()`: classify a sample with ICL examples → `{REASON, ANSWER}`
- `interpret()`: same as solve but without logging ground truth
- `evaluate()`: evaluate another agent's answer (for multi-agent debate)
- `reflect()`: refine answer based on peer evaluations
---
## Running Experiments
### Basic ICL run
```bash
python run.py run config/sleepedf.yaml
```
### Compare multiple selection criteria
```bash
python run.py compare config/sleepedf.yaml \
--criteria_list="out_random,in_random,out_similar,in_similar" \
--embedding_path="./embeddings_full"
```
### Self-Consistency run
```bash
python -m sc.run_sc sc/config/sleepedf_sc.yaml
```
---
## Config (sleepedf.yaml) Key Fields
- `data_path`: root of dataset (contains `info.json` and per-user dirs)
- `log_path`: output directory for results and logs
- `num_seeds`: number of random seeds to run
- `num_examples`: ICL examples per class (default 1)
- `selection_criteria`: `out_random` | `in_random` | `out_similar` | `in_similar`
- `embedding_path`: path to pre-computed Chronos-2 embeddings (required for `*_similar`)
- `models`: list of model specs (Ollama URLs or model IDs)
In SC configs, additional fields:
- `queue_size`: number of example sets to maintain in the priority queue
- `temperature`: LLM sampling temperature
- `max_new_tokens`: max tokens for generation
- `example_pool`: `"out"` (other users) or `"in"` (same user)
- `continuous`: whether queue persists across test samples
---
## Data Format Details
Each HuggingFace dataset sample:
```python
{
"user_id": str,
"session_id": str, # "1" = train, "2" = test
"idx": int,
"label": str, # class name from info.json["class"]
"features": dict, # str -> float/str (sensor features formatted for prompt)
"data": dict, # optional raw data
}
```
`info.json`:
```json
{
"task": "Sleep stage classification from EEG",
"class": {"W": "Wake", "N1": "...", ...},
"feature": "EEG channel description for the LLM..."
}
```
---
## PPGBP Dataset (sc/core/data_loader.py)
`PPGBPLoader`: reads xlsx metadata + signal `.txt` files from `0_subject/`. 80/20 train/test split at subject level. Metadata embeddings (SBERT) auto-generated and cached per subject as `.npy` files under `PPGBP_METADATA_EMBEDDINGS_ROOT` (set via env var or `.env`).
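
The subject-level 80/20 split can be sketched as below. This mirrors the logic of `_get_train_test_split` in `analysis/ppgbp_loader.py` (seeded permutation, `round(0.8 * n)` with at least one training subject), but the standalone function `split_subjects` is only an illustration and assumes the `sc/` loader behaves the same way.

```python
import numpy as np

def split_subjects(subject_ids, train_frac=0.8, seed=42):
    """Split subject IDs into train/test arrays with a seeded shuffle."""
    ids = np.asarray(sorted(subject_ids))
    if len(ids) == 0:
        return ids, ids
    rng = np.random.default_rng(seed)
    shuffled = ids[rng.permutation(len(ids))]
    # At least one subject always lands in the training split.
    n_train = max(1, int(round(train_frac * len(ids))))
    return shuffled[:n_train], shuffled[n_train:]
```

Splitting by subject rather than by segment keeps all segments of one person on the same side of the split, which avoids leaking a test subject's signal into the example pool.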
---
## Dependencies (requirements.txt)
- `langchain`, `langchain_ollama`, `langchain_openai`, `langchain_together` — LLM backends
- `datasets` — HuggingFace datasets (on-disk storage)
- `chronos` — Chronos-2 time-series embeddings
- `sentence_transformers` — SBERT for metadata embeddings
- `transformers`, `torch` — HuggingFace model loading (SC pipeline)
- `tiktoken` — token counting
- `fire` — CLI argument parsing
- `mne`, `neurokit2` — EEG/biosignal preprocessing
- `numpy`, `pandas`, `scikit_learn`, `scipy`, `matplotlib`
---
## Notes / Patterns
- Experiments sample every 10th test item (`if idx % 10 != 0: continue` in `run.py`).
- SC runner currently hardcoded to first user only for testing (`users[:1]`).
- Log structure: `<log_path>/<user_id>/<sample_idx>/<seed>/` with `summary.txt`, `log.txt`, `tokens.txt`.
- Embedding index uses cosine similarity (L2-normalized dot product), filtered by user and session.
- `DataLoader` in `sc/` is simpler (no similarity selection); similarity lives in `core/`.
- `InMemoryDataLoader` and `prepare_dataset_for_sc()` allow integrating new datasets without writing to disk.

0
.claude/plan.md Normal file

235
analysis/ppgbp_loader.py Normal file

@@ -0,0 +1,235 @@
"""
PPGBP Dataset Loader
A data loader/iterator for the PPG-BP dataset that reads xlsx metadata
and corresponding signal text files, with optional batching support.
"""
import os
import numpy as np
import pandas as pd
from typing import Iterator, Tuple, Dict, List, Optional, Union
class PPGBPLoader:
"""
Data loader for the PPG-BP dataset.
Reads metadata from xlsx file and corresponding PPG signal text files.
Supports iteration over single samples or batches.
Subject list is built from 0_subject/ and matched with metadata;
train/test is 80/20 at subject level (random, fixed by seed).
"""
COLUMN_MAPPING = {
"Sex(M/F)": "sex",
"Age(year)": "age",
"Systolic Blood Pressure(mmHg)": "sysbp",
"Diastolic Blood Pressure(mmHg)": "diasbp",
"Heart Rate(b/m)": "hr",
"BMI(kg/m^2)": "bmi"
}
def __init__(
self,
base_dir: str,
split: str = 'all',
batch_size: Optional[int] = None,
shuffle: bool = False,
num_segments: int = 3,
seed: int = 42
):
self.base_dir = base_dir
self.split = split
self.batch_size = batch_size
self.shuffle = shuffle
self.num_segments = num_segments
self.seed = seed
self.rng = np.random.default_rng(seed)
self.signal_dir = os.path.join(base_dir, "0_subject")
self.xlsx_path = self._find_xlsx_file()
self.metadata = self._load_metadata()
self._all_subject_ids = self._get_unified_subject_ids()
self._train_ids, self._test_ids = self._get_train_test_split(self._all_subject_ids)
self.subject_ids = self._get_split_ids()
self.metadata = self.metadata[self.metadata['subject_ID'].isin(self.subject_ids)]
self._current_idx = 0
self._indices = np.arange(len(self.subject_ids))
if self.shuffle:
self.rng.shuffle(self._indices)
def _find_xlsx_file(self) -> str:
for f in os.listdir(self.base_dir):
if f.endswith('.xlsx'):
return os.path.join(self.base_dir, f)
raise FileNotFoundError(f"No xlsx file found in {self.base_dir}")
def _load_metadata(self) -> pd.DataFrame:
# Use header=1: xlsx has a title row, then row with column names (subject_ID, etc.)
df = pd.read_excel(self.xlsx_path, header=1)
df = df.rename(columns=self.COLUMN_MAPPING)
df = df.fillna(0)
return df
def _get_unified_subject_ids(self) -> np.ndarray:
"""Subject IDs present in both 0_subject/ (from listdir) and metadata."""
if not os.path.isdir(self.signal_dir):
return np.array([], dtype=np.int64)
# Files are named {subject_id}_{segment}.txt
ids_from_files = set()
for f in os.listdir(self.signal_dir):
if f.endswith(".txt"):
try:
sid = int(f.replace(".txt", "").split("_")[0])
ids_from_files.add(sid)
except (ValueError, IndexError):
continue
meta_ids = set(self.metadata["subject_ID"].astype(int).values)
unified = sorted(ids_from_files & meta_ids)
return np.array(unified)
def _get_train_test_split(self, ids: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""80% train, 20% test at subject level; shuffle fixed by seed."""
if len(ids) == 0:
return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
perm = self.rng.permutation(len(ids))
ids_perm = ids[perm]
n_train = max(1, int(round(0.8 * len(ids))))
train_ids = ids_perm[:n_train]
test_ids = ids_perm[n_train:]
return train_ids, test_ids
def _get_split_ids(self) -> np.ndarray:
if self.split == "train":
return self._train_ids
elif self.split == "test":
return self._test_ids
elif self.split == "all":
return self._all_subject_ids
else:
raise ValueError(f"Unknown split: {self.split}. Use 'train', 'test', or 'all'.")
@property
def train_ids(self) -> np.ndarray:
"""Subject IDs in the train split (80%)."""
return self._train_ids
@property
def test_ids(self) -> np.ndarray:
"""Subject IDs in the test split (20%)."""
return self._test_ids
def _load_signal(self, subject_id: int) -> np.ndarray:
segments = []
for s in range(1, self.num_segments + 1):
filepath = os.path.join(self.signal_dir, f"{subject_id}_{s}.txt")
signal = pd.read_csv(filepath, sep='\t', header=None)
signal = signal.values.squeeze()
if len(signal) > 1:
signal = signal[:-1]
segments.append(signal)
return np.array(segments, dtype=object)
def _get_metadata_dict(self, subject_id: int) -> Dict:
row = self.metadata[self.metadata['subject_ID'] == subject_id].iloc[0]
return row.to_dict()
def __len__(self) -> int:
return len(self.subject_ids)
def __iter__(self) -> 'PPGBPLoader':
self._current_idx = 0
if self.shuffle:
self.rng.shuffle(self._indices)
return self
def __next__(self) -> Union[Tuple[Dict, np.ndarray], Tuple[List[Dict], List[np.ndarray]]]:
if self._current_idx >= len(self.subject_ids):
raise StopIteration
if self.batch_size is None:
idx = self._indices[self._current_idx]
subject_id = self.subject_ids[idx]
metadata = self._get_metadata_dict(subject_id)
signal = self._load_signal(subject_id)
self._current_idx += 1
return metadata, signal
else:
end_idx = min(self._current_idx + self.batch_size, len(self.subject_ids))
batch_indices = self._indices[self._current_idx:end_idx]
metadata_list = []
signal_list = []
for idx in batch_indices:
subject_id = self.subject_ids[idx]
metadata_list.append(self._get_metadata_dict(subject_id))
signal_list.append(self._load_signal(subject_id))
self._current_idx = end_idx
return metadata_list, signal_list
def __getitem__(self, idx: int) -> Tuple[Dict, np.ndarray]:
if idx < 0 or idx >= len(self.subject_ids):
raise IndexError(f"Index {idx} out of range")
subject_id = self.subject_ids[idx]
metadata = self._get_metadata_dict(subject_id)
signal = self._load_signal(subject_id)
return metadata, signal
def get_by_subject_id(self, subject_id: int) -> Tuple[Dict, np.ndarray]:
if subject_id not in self.subject_ids:
raise ValueError(f"Subject ID {subject_id} not found in this split.")
metadata = self._get_metadata_dict(subject_id)
signal = self._load_signal(subject_id)
return metadata, signal
def reset(self):
self._current_idx = 0
if self.shuffle:
self.rng.shuffle(self._indices)
def iter_batches(self, batch_size: int) -> Iterator[Tuple[List[Dict], List[np.ndarray]]]:
indices = np.arange(len(self.subject_ids))
if self.shuffle:
self.rng.shuffle(indices)
for start_idx in range(0, len(self.subject_ids), batch_size):
end_idx = min(start_idx + batch_size, len(self.subject_ids))
batch_indices = indices[start_idx:end_idx]
metadata_list = []
signal_list = []
for idx in batch_indices:
subject_id = self.subject_ids[idx]
metadata_list.append(self._get_metadata_dict(subject_id))
signal_list.append(self._load_signal(subject_id))
yield metadata_list, signal_list
def get_loaders(
    base_dir: str,
    batch_size: Optional[int] = None,
    shuffle_train: bool = True,
    seed: int = 42
) -> Tuple[PPGBPLoader, PPGBPLoader]:
    """Convenience function to get train and test loaders (80/20 split)."""
    train_loader = PPGBPLoader(
        base_dir=base_dir, split="train",
        batch_size=batch_size, shuffle=shuffle_train, seed=seed
    )
    test_loader = PPGBPLoader(
        base_dir=base_dir, split="test",
        batch_size=batch_size, shuffle=False, seed=seed
    )
    return train_loader, test_loader


if __name__ == "__main__":
    print("PPGBPLoader ready to use.")


@@ -0,0 +1,982 @@
"""
Chronos-2 PPG-BP Embedding Extraction and Visualization Pipeline
This module provides functionality to:
1. Extract embeddings from univariate PPG signals (PPG-BP dataset) using Chronos-2
2. Visualize embeddings using t-SNE colored by various metadata categories
Chronos-2 Overview:
Chronos-2 is a time series foundation model developed by Amazon. It splits the
input series into patches, encodes them with a transformer encoder, and
produces probabilistic forecasts from the encoded representation.
Embedding Strategy:
We use Chronos-2's internal encoder hidden states as embeddings. For univariate
PPG signals, each signal is treated as a single-variate time series, producing
a 512-dimensional embedding after pooling.
Processing Pipeline:
1. Load PPG signals via PPGBPLoader (3 segments per subject)
2. Each segment is embedded independently through Chronos-2 encoder
3. Encoder hidden states are pooled to produce a fixed-size embedding
4. Metadata from the dataset xlsx is attached for visualization
Usage:
# Extract embeddings
python gen_plot.py extract --base_dir /path/to/ppgbp/DataFile --out_dir ./embeddings
# Visualize with t-SNE
python gen_plot.py plot --emb_dir ./embeddings --out_dir ./plots
# LazyPredict: predict hypertension (or other target) from Chronos embeddings
python gen_plot.py lazy_eval --emb_dir ./embeddings --target hypertension --out_csv ./lazy_hypertension.csv
"""
import os
import sys
from typing import Dict, Any, List, Tuple, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from datasets import Dataset, load_from_disk
from chronos import Chronos2Pipeline
from fire import Fire
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder
try:
from lazypredict.Supervised import LazyClassifier, LazyRegressor
HAS_LAZYPREDICT = True
except ImportError:
HAS_LAZYPREDICT = False
# Add parent directories to path for importing PPGBPLoader
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from ppgbp_loader import PPGBPLoader
# Target column options for lazy_eval: classification vs regression
LAZY_EVAL_CLASSIFICATION_TARGETS = [
"hypertension",
"sex",
"age_category",
"height_category",
"weight_category",
"systolic_category",
"diastolic_category",
"heart_rate_category",
"bmi_category",
]
LAZY_EVAL_REGRESSION_TARGETS = [
"age",
"height",
"weight",
"systolic",
"diastolic",
"heart_rate",
"bmi",
]
LAZY_EVAL_ALL_TARGETS = LAZY_EVAL_CLASSIFICATION_TARGETS + LAZY_EVAL_REGRESSION_TARGETS
# =============================================================================
# Metadata Categorization Helpers
# =============================================================================
def _categorize_age(age: Optional[int]) -> Optional[str]:
if age is None:
return None
if age <= 35:
return "Young (21-35)"
elif age <= 55:
return "Middle (36-55)"
elif age <= 75:
return "Senior (56-75)"
else:
return "Old (76+)"
def _categorize_height(height: Optional[int]) -> Optional[str]:
if height is None:
return None
if height <= 157:
return "Short (145-157)"
elif height <= 170:
return "Below Avg (158-170)"
elif height <= 183:
return "Above Avg (171-183)"
else:
return "Tall (184+)"
def _categorize_weight(weight: Optional[int]) -> Optional[str]:
if weight is None:
return None
if weight <= 52:
return "Light (36-52)"
elif weight <= 69:
return "Below Avg (53-69)"
elif weight <= 86:
return "Above Avg (70-86)"
else:
return "Heavy (87+)"
def _categorize_systolic(sbp: Optional[int]) -> Optional[str]:
if sbp is None:
return None
if sbp < 90:
return "Low (<90)"
elif sbp < 120:
return "Normal (90-119)"
elif sbp < 140:
return "Elevated (120-139)"
else:
return "High (140+)"
def _categorize_diastolic(dbp: Optional[int]) -> Optional[str]:
if dbp is None:
return None
if dbp < 60:
return "Low (<60)"
elif dbp < 80:
return "Normal (60-79)"
elif dbp < 90:
return "Elevated (80-89)"
else:
return "High (90+)"
def _categorize_heart_rate(hr: Optional[int]) -> Optional[str]:
if hr is None:
return None
if hr <= 65:
return "Low (52-65)"
elif hr <= 79:
return "Normal (66-79)"
elif hr <= 93:
return "Elevated (80-93)"
else:
return "High (94+)"
def _categorize_bmi(bmi: Optional[float]) -> Optional[str]:
if bmi is None:
return None
if bmi < 18.5:
return "Underweight (<18.5)"
elif bmi < 25:
return "Normal (18.5-24.9)"
elif bmi < 30:
return "Overweight (25-29.9)"
else:
return "Obese (30+)"
def _safe_int(val, sentinel=0):
    """Convert a value to int, returning None if it equals the sentinel or is invalid."""
    if val is None or (isinstance(val, float) and pd.isna(val)):
        return None
    try:
        v = int(val)
        return None if v == sentinel else v
    except (ValueError, TypeError):
        return None


def _safe_float(val, sentinel=0.0):
    """Convert a value to float, returning None if it equals the sentinel or is invalid."""
    if val is None or (isinstance(val, float) and pd.isna(val)):
        return None
    try:
        v = float(val)
        return None if v == sentinel else v
    except (ValueError, TypeError):
        return None
def _extract_metadata_fields(metadata: Dict) -> Dict[str, Any]:
"""
Extract and sanitize metadata fields from a PPGBPLoader metadata dict.
PPGBPLoader renames some columns (sex, age, sysbp, diasbp, hr, bmi) but
leaves others with original names (Height(cm), Weight(kg), Hypertension).
fillna(0) is applied, so 0 may represent missing values for some fields.
Returns a dict with standardized keys.
"""
# Sex: "M"/"F" string, or 0 if missing (from fillna)
sex = metadata.get("sex", None)
if sex is not None and sex != 0 and str(sex).strip() in ("M", "F"):
sex = str(sex).strip()
else:
sex = None
age = _safe_int(metadata.get("age", None))
height = _safe_int(metadata.get("Height(cm)", None))
weight = _safe_int(metadata.get("Weight(kg)", None))
sysbp = _safe_int(metadata.get("sysbp", None))
diasbp = _safe_int(metadata.get("diasbp", None))
hr = _safe_int(metadata.get("hr", None))
bmi = _safe_float(metadata.get("bmi", None))
# Hypertension: may be string (e.g. "Stage 2 hypertension") or numeric; keep as string
hypertension = metadata.get("Hypertension", None)
if hypertension is not None and not (isinstance(hypertension, float) and pd.isna(hypertension)):
hypertension = str(hypertension).strip()
else:
hypertension = None
return {
"sex": sex,
"age": age,
"height": height,
"weight": weight,
"sysbp": sysbp,
"diasbp": diasbp,
"hr": hr,
"bmi": bmi,
"hypertension": hypertension,
"age_category": _categorize_age(age),
"height_category": _categorize_height(height),
"weight_category": _categorize_weight(weight),
"systolic_category": _categorize_systolic(sysbp),
"diastolic_category": _categorize_diastolic(diasbp),
"heart_rate_category": _categorize_heart_rate(hr),
"bmi_category": _categorize_bmi(bmi),
}
# =============================================================================
# Embedding Extractor Class
# =============================================================================
class Chronos_2_Embedder:
"""
Extracts fixed-dimensional embeddings from univariate PPG signals using Chronos-2.
Architecture:
Input PPG Signal -> Patching -> Encoder (6 layers) -> Hidden States -> Pooling -> Embedding
For univariate PPG:
Input shape: (1, 1, L) -- batch=1, variates=1, length=L
Output shape: (512,) -- d_model=512 after pooling
Pooling Strategies:
- mean: Average across all patch tokens (excluding special tokens)
- cls: Use the [REG] token embedding (similar to BERT's [CLS])
"""
def __init__(
self,
model_name: str = "amazon/chronos-2",
device_map: Optional[str] = None,
pooling_strategy: str = "mean",
):
self.pooling_strategy = pooling_strategy
if device_map is None:
device_map = "cuda" if torch.cuda.is_available() else "cpu"
self.device_map = device_map
self.device = torch.device(
"cuda" if device_map.startswith("cuda") and torch.cuda.is_available() else "cpu"
)
print(f"[INFO] Loading Chronos-2 model: {model_name}")
print(f"[INFO] Device: {device_map}")
print(f"[INFO] Pooling strategy: {pooling_strategy}")
self.pipeline: Chronos2Pipeline = Chronos2Pipeline.from_pretrained(
model_name,
device_map=device_map,
)
@torch.no_grad()
def compute_embedding(self, signal: np.ndarray) -> np.ndarray:
"""
Compute embedding for a single univariate PPG signal.
Args:
signal: 1D numpy array of PPG signal values (variable length OK)
Returns:
Embedding vector of shape (d_model,) = (512,)
"""
# Reshape to (B=1, V=1, L) for Chronos-2 input
x_input = signal.reshape(1, 1, -1).astype(np.float32)
# Get encoder hidden states
# Returns: list of tensors (one per batch item), each (V, num_patches+2, d_model)
embeddings_list, _ = self.pipeline.embed(x_input)
emb = embeddings_list[0] # (1, num_patches+2, 512) for V=1
if self.pooling_strategy == "cls":
# Use the [REG] token (first token)
pooled = emb[0, 0, :] # (d_model,)
else: # "mean" pooling (default)
# Average across all patch tokens, skip [REG] (first) and masked patch (last)
pooled = emb[0, 1:-1, :].mean(dim=0) # (d_model,)
return pooled.cpu().numpy().astype(np.float32)
def extract_embeddings(
self,
base_dir: str,
split: str = "all",
) -> Dataset:
"""
Extract embeddings from all subjects/segments in the PPG-BP dataset.
Each subject has 3 PPG segments. Each segment is embedded independently
via Chronos-2, resulting in 3 data points per subject sharing the same metadata.
Args:
base_dir: Root directory of PPG-BP dataset (containing 0_subject/ and .xlsx)
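split: Data split ('train', 'test', or 'all'). Currently unused: all subjects
are loaded and each row is tagged with its split in the `split` column.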
Returns:
HuggingFace Dataset with columns:
- subject_id, segment_id, split (identifiers and train/test label)
- sex, age, height, weight, systolic, diastolic, heart_rate, bmi, hypertension
- age_category, height_category, weight_category, systolic_category, etc.
- embedding (512-dim vector)
"""
loader = PPGBPLoader(base_dir=base_dir, split="all")
n_subjects = len(loader)
train_id_set = set(int(x) for x in loader.train_ids)
test_id_set = set(int(x) for x in loader.test_ids)
print(f"[INFO] Loaded {n_subjects} subjects (train={len(train_id_set)}, test={len(test_id_set)})")
# Accumulator lists
all_embeddings = []
all_subject_ids = []
all_segment_ids = []
all_splits = []
all_sexes = []
all_ages = []
all_heights = []
all_weights = []
all_sysbps = []
all_diasbps = []
all_hrs = []
all_bmis = []
all_hypertensions = []
all_age_categories = []
all_height_categories = []
all_weight_categories = []
all_systolic_categories = []
all_diastolic_categories = []
all_hr_categories = []
all_bmi_categories = []
total_segments = 0
for subj_idx in range(n_subjects):
metadata, signals = loader[subj_idx]
subject_id_int = int(metadata.get("subject_ID", 0))
subject_id = str(subject_id_int)
split_label = "train" if subject_id_int in train_id_set else "test"
# Extract and categorize metadata once per subject
fields = _extract_metadata_fields(metadata)
# Process each PPG segment
for seg_idx, signal in enumerate(signals):
emb = self.compute_embedding(signal)
all_embeddings.append(emb.tolist())
all_subject_ids.append(subject_id)
all_segment_ids.append(seg_idx + 1)
all_splits.append(split_label)
all_sexes.append(fields["sex"])
all_ages.append(fields["age"])
all_heights.append(fields["height"])
all_weights.append(fields["weight"])
all_sysbps.append(fields["sysbp"])
all_diasbps.append(fields["diasbp"])
all_hrs.append(fields["hr"])
all_bmis.append(fields["bmi"])
all_hypertensions.append(fields["hypertension"])
all_age_categories.append(fields["age_category"])
all_height_categories.append(fields["height_category"])
all_weight_categories.append(fields["weight_category"])
all_systolic_categories.append(fields["systolic_category"])
all_diastolic_categories.append(fields["diastolic_category"])
all_hr_categories.append(fields["heart_rate_category"])
all_bmi_categories.append(fields["bmi_category"])
total_segments += 1
if (subj_idx + 1) % 20 == 0 or subj_idx == n_subjects - 1:
print(
f"[INFO] Processed {subj_idx + 1}/{n_subjects} subjects "
f"({total_segments} segments)"
)
# Create HuggingFace Dataset
result_dataset = Dataset.from_dict({
"subject_id": all_subject_ids,
"segment_id": all_segment_ids,
"split": all_splits,
"sex": all_sexes,
"age": all_ages,
"age_category": all_age_categories,
"height": all_heights,
"height_category": all_height_categories,
"weight": all_weights,
"weight_category": all_weight_categories,
"systolic": all_sysbps,
"systolic_category": all_systolic_categories,
"diastolic": all_diasbps,
"diastolic_category": all_diastolic_categories,
"heart_rate": all_hrs,
"heart_rate_category": all_hr_categories,
"bmi": all_bmis,
"bmi_category": all_bmi_categories,
"hypertension": all_hypertensions,
"embedding": all_embeddings,
})
print(f"[INFO] Total segments: {len(result_dataset)}")
if len(result_dataset) > 0:
print(f"[INFO] Embedding dim: {len(result_dataset[0]['embedding'])}")
return result_dataset
def save_embeddings(self, dataset: Dataset, output_dir: str) -> None:
"""Save embeddings dataset to disk in HuggingFace format."""
os.makedirs(output_dir, exist_ok=True)
dataset.save_to_disk(output_dir)
print(f"[DONE] Saved embeddings dataset: {output_dir}")
if len(dataset) > 0:
print(
f"[DONE] Total samples: {len(dataset)}, "
f"Embedding dim: {len(dataset[0]['embedding'])}"
)
@staticmethod
def load_embeddings(embedding_dir: str) -> Dataset:
"""Load saved embeddings dataset from disk."""
dataset = load_from_disk(embedding_dir)
print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
return dataset
# =============================================================================
# Visualization Functions
# =============================================================================
def reduce_to_2d_tsne(
    embeddings: np.ndarray,
    perplexity: float = 30.0,
) -> np.ndarray:
    """
    Reduce high-dimensional embeddings to 2D using t-SNE.

    Args:
        embeddings: Array of shape (num_samples, embedding_dim)
        perplexity: t-SNE perplexity (typically 5-50). Rule of thumb: ~ sqrt(N)

    Returns:
        2D coordinates of shape (num_samples, 2)
    """
    print(f"[INFO] Running t-SNE with perplexity={perplexity}...")
    tsne = TSNE(
        n_components=2,
        random_state=0,
        perplexity=perplexity,
        max_iter=1000,
        init="random",
        learning_rate="auto",
    )
    return tsne.fit_transform(embeddings)
def create_scatter_plot_continuous(
coordinates: np.ndarray,
values: np.ndarray,
title: str,
output_path: str,
label: str = "Value",
cmap: str = "viridis",
) -> None:
"""
Create a 2D scatter plot colored by a continuous variable.
Args:
coordinates: 2D array of shape (num_points, 2)
values: Values for each point (may contain None)
title: Plot title
output_path: File path to save the plot (PDF)
label: Colorbar label
cmap: Matplotlib colormap name
"""
values_float = np.array([float(v) if v is not None else np.nan for v in values])
valid_mask = ~np.isnan(values_float)
valid_coords = coordinates[valid_mask]
valid_values = values_float[valid_mask]
if len(valid_values) == 0:
print(f"[WARN] No valid {label} data found. Skipping plot: {output_path}")
return
fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(
valid_coords[:, 0],
valid_coords[:, 1],
c=valid_values,
cmap=cmap,
alpha=0.7,
s=50,
)
cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label(label)
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.close()
print(f"[DONE] Saved plot: {output_path}")
print(f"[INFO] {label} range: {valid_values.min():.1f} - {valid_values.max():.1f}")
def create_scatter_plot_categorical(
coordinates: np.ndarray,
categories: np.ndarray,
title: str,
output_path: str,
label: str = "Category",
) -> None:
"""
Create a 2D scatter plot colored by a categorical variable.
Args:
coordinates: 2D array of shape (num_points, 2)
categories: Category values for each point (may contain None)
title: Plot title
output_path: File path to save the plot (PDF)
label: Legend title
"""
valid_mask = np.array([c is not None and str(c) != "None" for c in categories])
valid_coords = coordinates[valid_mask]
valid_cats = np.array(categories)[valid_mask]
if len(valid_cats) == 0:
print(f"[WARN] No valid {label} data found. Skipping plot: {output_path}")
return
unique_cats = sorted(set(str(c) for c in valid_cats))
colormap = plt.cm.tab10 if len(unique_cats) <= 10 else plt.cm.tab20
colors = {cat: colormap(i % 20) for i, cat in enumerate(unique_cats)}
fig, ax = plt.subplots(figsize=(10, 8))
for cat in unique_cats:
mask = np.array([str(c) == cat for c in valid_cats])
ax.scatter(
valid_coords[mask, 0],
valid_coords[mask, 1],
c=[colors[cat]],
label=cat,
alpha=0.7,
s=50,
)
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
ax.legend(title=label, loc="best", markerscale=2)
plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.close()
print(f"[DONE] Saved plot: {output_path}")
print(f"[INFO] {label} categories: {unique_cats}")
# =============================================================================
# LazyPredict multi-seed helper
# =============================================================================
def _run_lazy_one_split(
X: np.ndarray,
y_raw: List[Any],
subject_ids_valid: List[int],
train_id_set: set,
test_id_set: set,
is_classification: bool,
verbose: int = 0,
ignore_warnings: bool = True,
) -> pd.DataFrame:
"""Run LazyClassifier or LazyRegressor for one train/test split; return metrics DataFrame."""
train_mask = np.array([int(s) in train_id_set for s in subject_ids_valid])
test_mask = np.array([int(s) in test_id_set for s in subject_ids_valid])
if train_mask.sum() == 0 or test_mask.sum() == 0:
raise ValueError("Empty train or test set for this split.")
X_train, X_test = X[train_mask], X[test_mask]
y_train_raw = [y_raw[i] for i in range(len(y_raw)) if train_mask[i]]
y_test_raw = [y_raw[i] for i in range(len(y_raw)) if test_mask[i]]
if is_classification:
le = LabelEncoder()
y_train = le.fit_transform([str(v) for v in y_train_raw])
y_test = le.transform([str(v) for v in y_test_raw])
clf = LazyClassifier(verbose=verbose, ignore_warnings=ignore_warnings, custom_metric=None)
models, _ = clf.fit(X_train, X_test, y_train, y_test)
else:
y_train = np.array([float(v) for v in y_train_raw], dtype=np.float64)
y_test = np.array([float(v) for v in y_test_raw], dtype=np.float64)
reg = LazyRegressor(verbose=verbose, ignore_warnings=ignore_warnings, custom_metric=None)
models, _ = reg.fit(X_train, X_test, y_train, y_test)
# Ensure model name is index for aggregation across seeds
if "Model" in models.columns:
models = models.set_index("Model")
return models
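Review note on `_run_lazy_one_split`: the masks are built per sample from subject IDs, so all recordings of one subject land on the same side of the split. The following is a minimal sketch of that masking on toy data, not the code in this PR; `split_by_subject` and the toy arrays are hypothetical names used only for illustration.

```python
import numpy as np


def split_by_subject(X, y, subject_ids, train_ids, test_ids):
    """Split samples so that each subject appears in exactly one partition."""
    subject_ids = np.asarray([int(s) for s in subject_ids])
    train_mask = np.isin(subject_ids, list(train_ids))
    test_mask = np.isin(subject_ids, list(test_ids))
    if not train_mask.any() or not test_mask.any():
        raise ValueError("Empty train or test set for this split.")
    y = np.asarray(y)
    return X[train_mask], X[test_mask], y[train_mask], y[test_mask]


# Toy data: 5 samples from 3 subjects, 2-dimensional stand-in "embeddings".
X_toy = np.arange(10, dtype=np.float64).reshape(5, 2)
y_toy = ["a", "b", "a", "b", "a"]
ids_toy = [1, 1, 2, 3, 3]

X_tr, X_te, y_tr, y_te = split_by_subject(X_toy, y_toy, ids_toy, {1, 2}, {3})
print(X_tr.shape, X_te.shape)
```

Boolean masks also let the label lists be indexed the same way as `X`, which avoids the separate list comprehensions over `y_raw`.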
# =============================================================================
# Command Line Interface
# =============================================================================
class CLI:
"""
CLI for Chronos-2 PPG-BP embedding extraction and t-SNE visualization.
Commands:
extract - Compute Chronos-2 embeddings from PPG signals
plot - Generate t-SNE scatter plots colored by metadata
lazy_eval - LazyPredict: X=embedding, y=target (default hypertension); single split
lazy_eval_multi_seed - LazyPredict over 10 seeds for all targets; save {target}.csv (mean +/- std)
"""
def extract(
self,
base_dir: str,
out_dir: str,
model: str = "amazon/chronos-2",
pooling: str = "mean",
split: str = "all",
) -> None:
"""
Extract Chronos-2 embeddings from PPG-BP signals.
Args:
base_dir: PPG-BP dataset root (contains 0_subject/ folder and .xlsx)
out_dir: Output directory for the HuggingFace embeddings dataset
model: Chronos-2 model name (default: amazon/chronos-2)
pooling: Pooling strategy - 'mean' or 'cls' (default: mean)
split: Data split - 'train', 'val', 'test', 'all' (default: all)
"""
if pooling not in ("mean", "cls"):
raise ValueError(f"Invalid pooling: {pooling}. Use 'mean' or 'cls'.")
embedder = Chronos_2_Embedder(
model_name=model,
pooling_strategy=pooling,
)
# Honor the --split argument instead of always extracting every split.
dataset = embedder.extract_embeddings(base_dir, split)
embedder.save_embeddings(dataset, out_dir)
def plot(
self,
emb_dir: str,
out_dir: str,
perplexity: float = 30.0,
subjects: str = None,
num_subjects: int = 0,
) -> None:
"""
Visualize embeddings with t-SNE, colored by each metadata column.
Generates plots for (1) train, (2) test, (3) train+test, each with all categories.
Args:
emb_dir: Directory containing the saved HuggingFace embeddings dataset
out_dir: Output directory for PDF plots (subdirs: train/, test/, all/)
perplexity: t-SNE perplexity parameter (default: 30.0)
subjects: Comma-separated subject IDs to include (e.g., '2,6,8')
num_subjects: Include only first N subjects, 0 = all (default: 0)
"""
# Load saved embeddings (must have "split" column)
full_dataset = Chronos_2_Embedder.load_embeddings(emb_dir)
if "split" not in full_dataset.column_names:
raise ValueError(
"Embeddings dataset missing 'split' column. Re-run extract with updated PPGBPLoader."
)
# Optional subject filtering
if subjects:
subject_list = [s.strip() for s in subjects.split(",")]
# subject_list holds strings from the CLI, so compare against the string form of subject_id.
full_dataset = full_dataset.filter(lambda x: str(x["subject_id"]) in subject_list)
print(f"[INFO] Filtered to subjects: {subject_list}")
elif num_subjects > 0:
all_subjects = sorted(set(full_dataset["subject_id"]))
selected = all_subjects[:num_subjects]
full_dataset = full_dataset.filter(lambda x: x["subject_id"] in selected)
print(f"[INFO] Selected first {num_subjects} subjects: {selected}")
continuous_vars = [
("age", "Age (years)", "viridis"),
("height", "Height (cm)", "plasma"),
("weight", "Weight (kg)", "cividis"),
("systolic", "Systolic BP (mmHg)", "Reds"),
("diastolic", "Diastolic BP (mmHg)", "Blues"),
("heart_rate", "Heart Rate (bpm)", "Oranges"),
("bmi", "BMI (kg/m²)", "Greens"),
]
categorical_vars = [
("subject_id", "Subject ID"),
("sex", "Sex"),
("hypertension", "Hypertension"),
("age_category", "Age Category"),
("height_category", "Height Category"),
("weight_category", "Weight Category"),
("systolic_category", "Systolic BP Category"),
("diastolic_category", "Diastolic BP Category"),
("heart_rate_category", "Heart Rate Category"),
("bmi_category", "BMI Category"),
]
for split_name in ["train", "test", "all"]:
if split_name == "all":
dataset = full_dataset
else:
dataset = full_dataset.filter(lambda x: x["split"] == split_name)
n = len(dataset)
if n == 0:
print(f"[WARN] No samples for split '{split_name}', skipping.")
continue
split_out_dir = os.path.join(out_dir, split_name)
os.makedirs(split_out_dir, exist_ok=True)
print(f"[INFO] t-SNE for split '{split_name}': {n} samples -> {split_out_dir}")
embeddings = np.array(dataset["embedding"])
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
for var_name, var_label, cmap in continuous_vars:
if var_name in dataset.column_names:
values = dataset[var_name]
create_scatter_plot_continuous(
coordinates_2d,
values,
f"Chronos-2 t-SNE [{split_name}] (Colored by {var_label})",
os.path.join(split_out_dir, f"tsne_by_{var_name}.pdf"),
label=var_label,
cmap=cmap,
)
for var_name, var_label in categorical_vars:
if var_name in dataset.column_names:
categories = dataset[var_name]
create_scatter_plot_categorical(
coordinates_2d,
categories,
f"Chronos-2 t-SNE [{split_name}] (Colored by {var_label})",
os.path.join(split_out_dir, f"tsne_by_{var_name}.pdf"),
label=var_label,
)
print(f"\n[DONE] All plots saved under: {out_dir} (train/, test/, all/)")
def lazy_eval(
self,
emb_dir: str,
target: str = "hypertension",
out_csv: str = None,
verbose: int = 0,
ignore_warnings: bool = True,
) -> None:
"""
Run LazyPredict (many classifiers or regressors) on Chronos embeddings (X) vs target (y).
Uses the same train/test split as PPGBPLoader (split column in the embeddings dataset).
Args:
emb_dir: Directory containing the saved HuggingFace embeddings dataset
target: Target column for y. Default 'hypertension'.
Classification: hypertension, sex, age_category, height_category,
weight_category, systolic_category, diastolic_category,
heart_rate_category, bmi_category
Regression: age, height, weight, systolic, diastolic, heart_rate, bmi
out_csv: If set, save the models comparison DataFrame to this path (e.g. lazy_hypertension.csv)
verbose: LazyPredict verbose (0 or 1)
ignore_warnings: LazyPredict ignore_warnings
"""
if not HAS_LAZYPREDICT:
raise ImportError("lazypredict is required. Install with: pip install lazypredict")
dataset = Chronos_2_Embedder.load_embeddings(emb_dir)
if "split" not in dataset.column_names:
raise ValueError(
"Embeddings dataset missing 'split' column. Re-run extract with updated PPGBPLoader."
)
if target not in dataset.column_names:
raise ValueError(
f"Target '{target}' not in dataset. Available: {dataset.column_names}"
)
X = np.array(dataset["embedding"], dtype=np.float64)
y_raw = list(dataset[target])
# Drop rows where target is missing
valid = [
i for i, v in enumerate(y_raw)
if v is not None and not (isinstance(v, float) and pd.isna(v)) and str(v).strip() != ""
]
if len(valid) == 0:
raise ValueError(f"No valid non-missing values for target '{target}'.")
X = X[valid]
y_raw = [y_raw[i] for i in valid]
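# Read the column once; indexing dataset["split"] inside the loop can re-materialize it per row.
split_col = dataset["split"]
splits = [split_col[i] for i in valid]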
train_mask = np.array([s == "train" for s in splits])
test_mask = np.array([s == "test" for s in splits])
X_train, X_test = X[train_mask], X[test_mask]
y_train_raw = [y_raw[i] for i in range(len(y_raw)) if splits[i] == "train"]
y_test_raw = [y_raw[i] for i in range(len(y_raw)) if splits[i] == "test"]
# Only targets listed solely as regression targets are regressed; everything else is classified.
is_classification = target in LAZY_EVAL_CLASSIFICATION_TARGETS or target not in LAZY_EVAL_REGRESSION_TARGETS
if is_classification:
le = LabelEncoder()
y_train = le.fit_transform([str(v) for v in y_train_raw])
y_test = le.transform([str(v) for v in y_test_raw])
print(f"[INFO] LazyClassifier: X=embedding, y={target} (classes: {list(le.classes_)})")
print(f"[INFO] Train: {len(y_train)}, Test: {len(y_test)}")
clf = LazyClassifier(verbose=verbose, ignore_warnings=ignore_warnings, custom_metric=None)
models, predictions = clf.fit(X_train, X_test, y_train, y_test)
else:
y_train = np.array([float(v) for v in y_train_raw], dtype=np.float64)
y_test = np.array([float(v) for v in y_test_raw], dtype=np.float64)
print(f"[INFO] LazyRegressor: X=embedding, y={target}")
print(f"[INFO] Train: {len(y_train)}, Test: {len(y_test)}")
reg = LazyRegressor(verbose=verbose, ignore_warnings=ignore_warnings, custom_metric=None)
models, predictions = reg.fit(X_train, X_test, y_train, y_test)
print("\n[RESULTS] Models (sorted by score):")
print(models)
if out_csv:
d = os.path.dirname(out_csv)
if d:
os.makedirs(d, exist_ok=True)
models.to_csv(out_csv)
print(f"\n[DONE] Saved model comparison to {out_csv}")
return None
def lazy_eval_multi_seed(
self,
emb_dir: str,
base_dir: str,
out_dir: str = "./lazy_multi_seed",
n_seeds: int = 10,
verbose: int = 0,
ignore_warnings: bool = True,
targets: str = "all",
) -> None:
"""
Run LazyPredict with n_seeds different train/test splits (from PPGBPLoader seeds).
For each target (classification + regression), saves {target}.csv with mean and std
for all metrics across seeds.
This can run for a long time (10 seeds x 16 targets x many models). For a quick test
use e.g. --n_seeds 2 --targets hypertension
"""
if not HAS_LAZYPREDICT:
raise ImportError("lazypredict is required. Install with: pip install lazypredict")
dataset = Chronos_2_Embedder.load_embeddings(emb_dir)
if "subject_id" not in dataset.column_names:
raise ValueError("Embeddings dataset missing 'subject_id' column.")
target_list = (
[t.strip() for t in targets.split(",")] if targets != "all" else LAZY_EVAL_ALL_TARGETS
)
os.makedirs(out_dir, exist_ok=True)
n_targets = len(target_list)
print(f"[INFO] lazy_eval_multi_seed: {n_targets} targets, {n_seeds} seeds each (can take a long time).")
for ti, target in enumerate(target_list):
if target not in dataset.column_names:
print(f"[WARN] Skipping target '{target}' (not in dataset).")
continue
print(f"[INFO] Target '{target}' ({ti + 1}/{n_targets}) ...")
y_raw = list(dataset[target])
valid = [
i
for i, v in enumerate(y_raw)
if v is not None
and not (isinstance(v, float) and pd.isna(v))
and str(v).strip() != ""
]
if len(valid) == 0:
print(f"[WARN] Skipping target '{target}' (no valid values).")
continue
X = np.array(dataset["embedding"], dtype=np.float64)[valid]
y_valid = [y_raw[i] for i in valid]
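# Read the column once; indexing dataset["subject_id"] inside the loop can re-materialize it per row.
subject_id_col = dataset["subject_id"]
subject_ids_valid = [subject_id_col[i] for i in valid]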
# Only targets listed solely as regression targets are regressed; everything else is classified.
is_classification = target in LAZY_EVAL_CLASSIFICATION_TARGETS or target not in LAZY_EVAL_REGRESSION_TARGETS
run_dfs = []
for seed in range(n_seeds):
print(f"[INFO] seed {seed + 1}/{n_seeds} ...", flush=True)
loader = PPGBPLoader(base_dir=base_dir, split="all", seed=seed)
train_id_set = set(int(x) for x in loader.train_ids)
test_id_set = set(int(x) for x in loader.test_ids)
try:
df = _run_lazy_one_split(
X,
y_valid,
subject_ids_valid,
train_id_set,
test_id_set,
is_classification,
verbose=verbose,
ignore_warnings=ignore_warnings,
)
run_dfs.append(df)
except Exception as e:
print(f"[WARN] Seed {seed} for target '{target}' failed: {e}")
continue
if len(run_dfs) == 0:
print(f"[WARN] No successful runs for target '{target}', skipping.")
continue
combined = pd.concat(run_dfs, axis=0)
metric_cols = [
c for c in combined.columns
if c != "Model" and np.issubdtype(combined[c].dtype, np.number)
]
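# Aggregate numeric metrics only; an object-dtype column (e.g. a metric reported as None) would make mean()/std() fail.
agg_mean = combined[metric_cols].groupby(level=0).mean()
agg_std = combined[metric_cols].groupby(level=0).std().fillna(0)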
out = pd.DataFrame(index=agg_mean.index)
out.index.name = "Model"
for col in metric_cols:
out[col + "_mean"] = agg_mean[col].values
out[col + "_std"] = agg_std[col].values
m = agg_mean[col].round(4).astype(str)
s = agg_std[col].round(4).astype(str)
out[col + "_mean_pm_std"] = m + " +/- " + s
out_path = os.path.join(out_dir, target + ".csv")
out.to_csv(out_path)
print(f"[DONE] {target}: {len(run_dfs)} seeds -> {out_path}")
if __name__ == "__main__":
Fire(CLI)
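
Review note on `lazy_eval_multi_seed`: the per-seed metric tables are stacked and then grouped by model name to get a mean and standard deviation per metric. The following is a minimal sketch of that aggregation on two hand-made tables, assuming each table is indexed by model name; `aggregate_runs` and the toy accuracy values are hypothetical and not results from this project.

```python
import pandas as pd


def aggregate_runs(run_dfs):
    """Combine per-seed metric tables (indexed by model name) into mean/std columns."""
    combined = pd.concat(run_dfs, axis=0)
    # Keep numeric metric columns only; object-dtype columns are skipped.
    metric_cols = [c for c in combined.columns if pd.api.types.is_numeric_dtype(combined[c])]
    grouped = combined[metric_cols].groupby(level=0)
    mean = grouped.mean()
    std = grouped.std().fillna(0)  # std is undefined for a model seen in one seed only
    out = pd.DataFrame(index=mean.index)
    out.index.name = "Model"
    for col in metric_cols:
        out[col + "_mean"] = mean[col]
        out[col + "_std"] = std[col]
    return out


# Toy per-seed tables standing in for LazyPredict output.
seed0 = pd.DataFrame({"Accuracy": [0.80, 0.60]}, index=["SVC", "KNN"])
seed1 = pd.DataFrame({"Accuracy": [0.90, 0.70]}, index=["SVC", "KNN"])
summary = aggregate_runs([seed0, seed1])
print(summary)
```

Selecting the numeric columns before grouping keeps the aggregation from failing when a metric column holds missing or non-numeric entries.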


@@ -0,0 +1,55 @@
#!/bin/bash
# =============================================================================
# Chronos-2 PPG-BP Embedding Extraction & t-SNE Visualization
# =============================================================================
set -euo pipefail
# ---- Configuration (edit these variables) ------------------------------------
BASE_DIR="/scratch/10608/aadharsh_aadhithya/repos/data/tsllm"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
MODEL="amazon/chronos-2"
POOLING="mean"
SPLIT="all"
PERPLEXITY=30.0
EMB_DIR="${BASE_DIR}/chronos2_embeddings"
PLOT_DIR="${BASE_DIR}/chronos2_plots"
# ------------------------------------------------------------------------------
echo "============================================="
echo "Chronos-2 PPG-BP Pipeline"
echo "============================================="
echo "Base dir: ${BASE_DIR}"
echo "Model: ${MODEL}"
echo "Pooling: ${POOLING}"
echo "Split: ${SPLIT}"
echo "Embeddings dir: ${EMB_DIR}"
echo "Plots dir: ${PLOT_DIR}"
echo "============================================="
# Step 1: Extract Chronos-2 embeddings from PPG signals
echo ""
echo "[Step 1/2] Extracting Chronos-2 embeddings..."
python "${SCRIPT_DIR}/gen_plot.py" extract \
--base_dir "${BASE_DIR}" \
--out_dir "${EMB_DIR}" \
--model "${MODEL}" \
--pooling "${POOLING}" \
--split "${SPLIT}"
# Step 2: Generate t-SNE plots colored by metadata
echo ""
echo "[Step 2/2] Generating t-SNE plots..."
python "${SCRIPT_DIR}/gen_plot.py" plot \
--emb_dir "${EMB_DIR}" \
--out_dir "${PLOT_DIR}" \
--perplexity "${PERPLEXITY}"
echo ""
echo "============================================="
echo "Done! Outputs:"
echo " Embeddings: ${EMB_DIR}"
echo " Plots: ${PLOT_DIR}"
echo "============================================="


@@ -60,7 +60,7 @@ from sentence_transformers import SentenceTransformer
# Constants
# =============================================================================
SUBJECT_PATH = "LLM_Health/tsllm_personalization_icl/analysis/user_similarity/sbert_metadata_ppgbp/PPGBP_metadata.xlsx"
SUBJECT_PATH = "/scratch/10608/aadharsh_aadhithya/repos/data/tsllm/PPGBP_metadata.xlsx"
# =============================================================================
@@ -149,6 +149,45 @@ class SBERT_Metadata:
"""
self.model = SentenceTransformer("all-MiniLM-L6-v2")
def discover_user_files(self, data_root: str) -> List[Tuple[str, str, str]]:
"""
Discover all user data files under the data root.
Expected structure: data_root/0_subject/{user_id}_{session_id}.txt
Args:
data_root: Root directory containing data files
Returns:
List of (user_id, session_id, file_path) tuples
"""
discovered_files = []
# Data is in 0_subject folder
subject_root = os.path.join(data_root, "0_subject")
print(f"[DEBUG] Looking for user files in: {subject_root}")
print(f"[DEBUG] Directory exists: {os.path.exists(subject_root)}")
# Use glob to find all .txt files
for file_path in sorted(glob(os.path.join(subject_root, "*.txt"))):
# Extract user_id and session_id from filename
# Pattern: {user_id}_{session_id}.txt (e.g., 100_1.txt)
filename = os.path.basename(file_path)
name_without_ext = os.path.splitext(filename)[0]
parts = name_without_ext.split("_")
if len(parts) >= 2:
user_id = parts[0]
session_id = parts[1]
discovered_files.append((user_id, session_id, file_path))
print(f"[DEBUG] Found {len(discovered_files)} data files")
if discovered_files:
print(f"[DEBUG] First few files: {[(u, s) for u, s, _ in discovered_files[:5]]}")
return discovered_files
def textualize_metadata_ppg_bp(self,
sex: Optional[str],
age: Optional[int],
@@ -236,9 +275,9 @@ class SBERT_Metadata:
# Create sentence from metadata
if age is not None:
sentence = f"This is the information of the user, sex: {sex_text}, age: {age_text}, height: {height_text}, weight: {weight_text}, sbp: {sbp_text}, dbp: {dbp_text}, heart rate: {heart_rate_text} bpm, bmi: {bmi_text}, hypertension: {hypertension_text}."
sentence = f"This is the information of the user, sex: {sex_text}, age: {age_text}, height: {height_text}, weight: {weight_text}, systolic blood pressure: {sbp_text}, diastolic blood pressure: {dbp_text}, heart rate: {heart_rate_text} bpm, bmi: {bmi_text}, hypertension: {hypertension_text}."
else:
sentence = f"This is the information of the user, sex: {sex_text}, age: unknown, height: {height_text}, weight: {weight_text}, sbp: {sbp_text}, dbp: {dbp_text}, heart rate: {heart_rate_text} bpm, bmi: {bmi_text}, hypertension: {hypertension_text}."
sentence = f"This is the information of the user, sex: {sex_text}, age: unknown, height: {height_text}, weight: {weight_text}, systolic blood pressure: {sbp_text}, diastolic blood pressure: {dbp_text}, heart rate: {heart_rate_text} bpm, bmi: {bmi_text}, hypertension: {hypertension_text}."
return sentence
@@ -321,11 +360,15 @@ class SBERT_Metadata:
print(f"[INFO] Loading subject metadata from: {subject_path}")
subject_metadata = load_subject_metadata(subject_path)
print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")
print(f"[DEBUG] Sample metadata keys: {list(subject_metadata.keys())[:10]}")
# Discover user data files
user_files = self.discover_user_files(data_root)
print(f"[INFO] Discovered {len(user_files)} data files")
all_embeddings = []
all_user_ids = []
# all_session_ids = []
all_session_ids = []
all_idxs = []
all_labels = []
all_sexes = []
@@ -337,26 +380,24 @@ class SBERT_Metadata:
all_heart_rates = []
all_bmis = []
all_hypertensions = []
all_age_categories = []
all_height_categories = []
all_weight_categories = []
all_systolic_categories = []
all_diastolic_categories = []
all_heart_rate_categories = []
all_bmi_categories = []
all_file_paths = []
# Collect metadata for all samples first
for user_id, session_path in session_paths:
# Load HuggingFace dataset from disk
dataset = load_from_disk(session_path)
# Shuffle dataset for randomness
dataset = dataset.shuffle(seed=0)
# Filter by sleep stage label if specified
if label is not None and label != "all":
dataset = dataset.filter(lambda x: x["label"] == label)
num_samples = len(dataset)
if num_samples == 0:
continue
print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")
# Collect metadata for all files
matched_count = 0
unmatched_users = set()
for idx, (user_id, session_id, file_path) in enumerate(user_files):
# Get metadata for this user
# Convert user_id to string format that matches metadata keys
user_id_str = str(int(user_id))
if idx == 0:
print(f"[DEBUG] First file user_id: '{user_id}' -> lookup key: '{user_id_str}'")
try:
sex = subject_metadata[user_id_str]["sex"]
age = subject_metadata[user_id_str]["age"]
@@ -364,11 +405,11 @@ class SBERT_Metadata:
weight = subject_metadata[user_id_str]["weight"]
sbp = subject_metadata[user_id_str]["sbp"]
dbp = subject_metadata[user_id_str]["dbp"]
heart_rate = subject_metadata[user_id_str]["heart_rate"]
heart_rate = subject_metadata[user_id_str]["hr"]
bmi = subject_metadata[user_id_str]["bmi"]
hypertension = subject_metadata[user_id_str]["hypertension"]
matched_count += 1
except KeyError:
sex = None
age = None
@@ -379,16 +420,126 @@ class SBERT_Metadata:
heart_rate = None
bmi = None
hypertension = None
print(f"[WARN] No metadata found for user_id: {user_id_str}")
unmatched_users.add(user_id_str)
# Collect all samples for this user/session
for i in range(num_samples):
all_user_ids.append(str(dataset["user_id"][i]))
all_session_ids.append(str(dataset["session_id"][i]))
all_idxs.append(int(dataset["idx"][i]))
all_labels.append(str(dataset["label"][i]))
all_ages.append(age)
all_sexes.append(sex)
# Collect data for this file
all_user_ids.append(user_id)
all_session_ids.append(session_id)
all_idxs.append(idx)
all_labels.append(label if label else "all")
all_ages.append(age)
all_sexes.append(sex)
all_heights.append(height)
all_weights.append(weight)
all_systolics.append(sbp)
all_diastolics.append(dbp)
all_heart_rates.append(heart_rate)
all_bmis.append(bmi)
all_hypertensions.append(hypertension)
# Categorize age
if age is not None:
if age <= 35:
age_category = "Young (21-35)"
elif age <= 55:
age_category = "Middle (36-55)"
elif age <= 75:
age_category = "Senior (56-75)"
else:
age_category = "Old (76+)"
else:
age_category = None
all_age_categories.append(age_category)
# Categorize height (cm): four equal-width bins over the range 145-196 (not quartiles)
if height is not None:
if height <= 157:
height_category = "Short (145-157)"
elif height <= 170:
height_category = "Below Avg (158-170)"
elif height <= 183:
height_category = "Above Avg (171-183)"
else:
height_category = "Tall (184+)"
else:
height_category = None
all_height_categories.append(height_category)
# Categorize weight (kg): four equal-width bins over the range 36-103 (not quartiles)
if weight is not None:
if weight <= 52:
weight_category = "Light (36-52)"
elif weight <= 69:
weight_category = "Below Avg (53-69)"
elif weight <= 86:
weight_category = "Above Avg (70-86)"
else:
weight_category = "Heavy (87+)"
else:
weight_category = None
all_weight_categories.append(weight_category)
# Categorize systolic BP (mmHg): clinical categories
if sbp is not None:
if sbp < 90:
systolic_category = "Low (<90)"
elif sbp < 120:
systolic_category = "Normal (90-119)"
elif sbp < 140:
systolic_category = "Elevated (120-139)"
else:
systolic_category = "High (140+)"
else:
systolic_category = None
all_systolic_categories.append(systolic_category)
# Categorize diastolic BP (mmHg): clinical categories
if dbp is not None:
if dbp < 60:
diastolic_category = "Low (<60)"
elif dbp < 80:
diastolic_category = "Normal (60-79)"
elif dbp < 90:
diastolic_category = "Elevated (80-89)"
else:
diastolic_category = "High (90+)"
else:
diastolic_category = None
all_diastolic_categories.append(diastolic_category)
# Categorize heart rate (bpm): four roughly equal-width bins over the range 52-106 (not quartiles)
if heart_rate is not None:
if heart_rate <= 65:
hr_category = "Low (52-65)"
elif heart_rate <= 79:
hr_category = "Normal (66-79)"
elif heart_rate <= 93:
hr_category = "Elevated (80-93)"
else:
hr_category = "High (94+)"
else:
hr_category = None
all_heart_rate_categories.append(hr_category)
# Categorize BMI (kg/m²): clinical categories
if bmi is not None:
if bmi < 18.5:
bmi_category = "Underweight (<18.5)"
elif bmi < 25:
bmi_category = "Normal (18.5-24.9)"
elif bmi < 30:
bmi_category = "Overweight (25-29.9)"
else:
bmi_category = "Obese (30+)"
else:
bmi_category = None
all_bmi_categories.append(bmi_category)
all_file_paths.append(file_path)
print(f"[INFO] Collected metadata for {len(all_user_ids)} samples")
print(f"[INFO] Matched {matched_count} files with metadata, {len(unmatched_users)} unique users unmatched")
if unmatched_users:
print(f"[DEBUG] Unmatched user IDs: {sorted(unmatched_users)[:10]}{'...' if len(unmatched_users) > 10 else ''}")
# Generate embeddings from metadata in batches
print(f"[INFO] Generating embeddings from metadata for {len(all_ages)} samples...")
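
Review note on the categorization blocks in this hunk: the seven if/elif ladders share one pattern (sorted thresholds plus labels), so they could be driven by a small table. The following is a minimal sketch under that assumption; `categorize` and the `*_EDGES` / `*_LABELS` constants are hypothetical names, and only the age and systolic thresholds from this hunk are reproduced.

```python
from bisect import bisect_left, bisect_right
from typing import Optional, Sequence


def categorize(
    value: Optional[float],
    edges: Sequence[float],
    labels: Sequence[str],
    inclusive_upper: bool = True,
) -> Optional[str]:
    """Map a value to a bin label; missing values (None) stay None.

    NaN inputs are not handled here and would need an explicit check.
    """
    if value is None:
        return None
    # inclusive_upper=True mirrors `value <= edge`; False mirrors `value < edge`.
    idx = bisect_left(edges, value) if inclusive_upper else bisect_right(edges, value)
    return labels[idx]


AGE_EDGES = [35, 55, 75]
AGE_LABELS = ["Young (21-35)", "Middle (36-55)", "Senior (56-75)", "Old (76+)"]
SBP_EDGES = [90, 120, 140]
SBP_LABELS = ["Low (<90)", "Normal (90-119)", "Elevated (120-139)", "High (140+)"]

print(categorize(35, AGE_EDGES, AGE_LABELS))
print(categorize(120, SBP_EDGES, SBP_LABELS, inclusive_upper=False))
print(categorize(None, AGE_EDGES, AGE_LABELS))
```

The `inclusive_upper` flag exists because the hunk mixes `<=` comparisons (age, height, weight, heart rate) with `<` comparisons (systolic, diastolic, BMI).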
@@ -397,9 +548,20 @@ class SBERT_Metadata:
batch_sexes = all_sexes[batch_start:batch_end]
batch_ages = all_ages[batch_start:batch_end]
batch_heights = all_heights[batch_start:batch_end]
batch_weights = all_weights[batch_start:batch_end]
batch_systolics = all_systolics[batch_start:batch_end]
batch_diastolics = all_diastolics[batch_start:batch_end]
batch_heart_rates = all_heart_rates[batch_start:batch_end]
batch_bmis = all_bmis[batch_start:batch_end]
batch_hypertensions = all_hypertensions[batch_start:batch_end]
# Compute embeddings from metadata
embeddings = self.compute_embedding_from_metadata(batch_ages, batch_sexes)
embeddings = self.compute_embedding_from_metadata(
batch_sexes, batch_ages, batch_heights, batch_weights,
batch_systolics, batch_diastolics, batch_heart_rates,
batch_bmis, batch_hypertensions
)
# Collect embeddings
for i in range(embeddings.shape[0]):
@@ -408,18 +570,26 @@ class SBERT_Metadata:
# Create HuggingFace Dataset
result_dataset = Dataset.from_dict({
"user_id": all_user_ids,
# "session_id": all_session_ids,
"session_id": all_session_ids,
"idx": all_idxs,
"label": all_labels,
"sex": all_sexes,
"age": all_ages,
"age_category": all_age_categories,
"height": all_heights,
"height_category": all_height_categories,
"weight": all_weights,
"weight_category": all_weight_categories,
"systolic": all_systolics,
"systolic_category": all_systolic_categories,
"diastolic": all_diastolics,
"diastolic_category": all_diastolic_categories,
"heart_rate": all_heart_rates,
"heart_rate_category": all_heart_rate_categories,
"bmi": all_bmis,
"bmi_category": all_bmi_categories,
"hypertension": all_hypertensions,
"file_path": all_file_paths,
"embedding": all_embeddings,
})
@@ -451,16 +621,19 @@ class SBERT_Metadata:
dataset.save_to_disk(output_dir)
print(f"[DONE] Saved embeddings dataset: {output_dir}")
print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
if len(dataset) > 0:
print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
else:
print(f"[WARN] Dataset is empty - no samples were processed")
@staticmethod
def load_embeddings(embedding_dir: str) -> Dataset:
"""
Load saved embeddings dataset from disk.
Args:
embedding_dir: Directory path containing the saved HuggingFace dataset
Returns:
HuggingFace Dataset with embeddings and metadata
"""
@@ -469,6 +642,94 @@ class SBERT_Metadata:
return dataset
# =============================================================================
# Save metadata embeddings by user ID (for PPGBP dataloader)
# =============================================================================
def _normalize_sex_for_sbert(sex: Any) -> Optional[int]:
"""Convert metadata sex to int for SBERT textualize (1=Female, 0=Male)."""
if sex is None or (isinstance(sex, float) and np.isnan(sex)):
return None
s = str(sex).strip().upper()
if s in ("F", "1"):
return 1
if s in ("M", "0"):
return 0
return None
def _normalize_hypertension_for_sbert(ht: Any) -> Optional[int]:
"""Convert metadata hypertension to int if numeric."""
if ht is None or (isinstance(ht, float) and np.isnan(ht)):
return None
try:
return int(float(ht))
except (ValueError, TypeError):
return None
def save_metadata_embeddings_by_userid(
data_root: str,
subject_path: str = SUBJECT_PATH,
embeddings_root: Optional[str] = None,
) -> None:
"""
Compute SBERT metadata embeddings for each user and save one file per user
under embeddings_root as {user_id}.npy (e.g. 100.npy).
Uses PPGBP_METADATA_EMBEDDINGS_ROOT from environment if embeddings_root
is not provided. Load .env from project root if present so the env var is set.
Args:
data_root: Root directory containing 0_subject/ with PPG-BP data.
subject_path: Path to PPGBP metadata xlsx (subject_ID, age, sex, etc.).
embeddings_root: Directory to write {user_id}.npy files. Defaults to
os.environ["PPGBP_METADATA_EMBEDDINGS_ROOT"].
"""
try:
from dotenv import load_dotenv
# Load .env from repo root (parent of analysis/)
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
load_dotenv(os.path.join(_repo_root, ".env"))
except ImportError:
pass
root = embeddings_root or os.environ.get("PPGBP_METADATA_EMBEDDINGS_ROOT")
if not root:
raise ValueError(
"embeddings_root not provided and PPGBP_METADATA_EMBEDDINGS_ROOT not set. "
"Set the env var or pass embeddings_root."
)
os.makedirs(root, exist_ok=True)
subject_metadata = load_subject_metadata(subject_path)
embedder = SBERT_Metadata()
user_ids = sorted(subject_metadata.keys())
if not user_ids:
raise ValueError(f"No subjects found in {subject_path}")
for user_id in user_ids:
meta = subject_metadata[user_id]
sex = _normalize_sex_for_sbert(meta.get("sex"))
age = meta.get("age")
height = meta.get("height")
weight = meta.get("weight")
sbp = meta.get("sbp")
dbp = meta.get("dbp")
hr = meta.get("hr")
bmi = meta.get("bmi")
hypertension = _normalize_hypertension_for_sbert(meta.get("hypertension"))
if age is not None and isinstance(age, float) and np.isnan(age):
age = None
emb = embedder.compute_embedding_from_metadata(
[sex], [age], [height], [weight], [sbp], [dbp], [hr], [bmi], [hypertension]
)
path = os.path.join(root, f"{user_id}.npy")
np.save(path, emb[0])
print(f"[DONE] Saved {len(user_ids)} metadata embeddings under {root}")
# =============================================================================
# Data Loading Utilities
# =============================================================================
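
Review note on `save_metadata_embeddings_by_userid`: it writes one `{user_id}.npy` file per subject, so a consumer can load two users and compare them directly. The following is a minimal sketch of such a consumer, assuming cosine similarity is the comparison of interest; `load_user_embedding`, `cosine_similarity`, and the three-dimensional stand-in vectors are hypothetical and not part of this PR.

```python
import os
import tempfile

import numpy as np


def load_user_embedding(root: str, user_id: str) -> np.ndarray:
    """Load one per-user embedding saved as {user_id}.npy under root."""
    return np.load(os.path.join(root, f"{user_id}.npy"))


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two 1-D vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


with tempfile.TemporaryDirectory() as root:
    # Stand-in vectors; real files would hold SBERT metadata embeddings.
    np.save(os.path.join(root, "100.npy"), np.array([1.0, 0.0, 1.0]))
    np.save(os.path.join(root, "101.npy"), np.array([1.0, 0.0, 0.0]))
    emb_a = load_user_embedding(root, "100")
    emb_b = load_user_embedding(root, "101")
    similarity = cosine_similarity(emb_a, emb_b)

print(round(similarity, 4))
```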
@@ -625,11 +886,242 @@ def create_scatter_plot_by_age(
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
print(f"[INFO] Age range: {valid_ages.min():.0f} - {valid_ages.max():.0f} years")
print(f"[INFO] Points with valid age: {len(valid_ages)}/{len(ages)}")
def create_scatter_plot_continuous(
coordinates: np.ndarray,
values: np.ndarray,
title: str,
output_path: str,
label: str = "Value",
cmap: str = "viridis"
) -> None:
"""
Create and save a 2D scatter plot colored by a continuous variable.
Args:
coordinates: 2D array of shape (num_points, 2)
values: Values for each point (array, may contain None values)
title: Plot title
output_path: File path to save the plot
label: Label for the colorbar
cmap: Colormap name
"""
# Filter out points with missing data
values_float = np.array([float(v) if v is not None else np.nan for v in values])
valid_mask = ~np.isnan(values_float)
valid_coords = coordinates[valid_mask]
valid_values = values_float[valid_mask]
if len(valid_values) == 0:
print(f"[WARN] No valid {label} data found. Skipping plot: {output_path}")
return
fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(
valid_coords[:, 0],
valid_coords[:, 1],
c=valid_values,
cmap=cmap,
alpha=0.7,
s=50
)
cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label(label)
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
print(f"[INFO] {label} range: {valid_values.min():.1f} - {valid_values.max():.1f}")
def create_scatter_plot_categorical(
coordinates: np.ndarray,
categories: np.ndarray,
title: str,
output_path: str,
label: str = "Category"
) -> None:
"""
Create and save a 2D scatter plot colored by a categorical variable.
Args:
coordinates: 2D array of shape (num_points, 2)
categories: Category values for each point (array, may contain None values)
title: Plot title
output_path: File path to save the plot
label: Label for the legend
"""
# Filter out points with missing data
valid_mask = np.array([c is not None and str(c) != 'None' for c in categories])
valid_coords = coordinates[valid_mask]
valid_cats = np.array(categories)[valid_mask]
if len(valid_cats) == 0:
print(f"[WARN] No valid {label} data found. Skipping plot: {output_path}")
return
# Get unique categories and assign colors
unique_cats = sorted(set(str(c) for c in valid_cats))
colors = plt.cm.tab10(np.linspace(0, 1, len(unique_cats)))
cat_to_color = {cat: colors[i] for i, cat in enumerate(unique_cats)}
fig, ax = plt.subplots(figsize=(10, 8))
for cat in unique_cats:
mask = np.array([str(c) == cat for c in valid_cats])
ax.scatter(
valid_coords[mask, 0],
valid_coords[mask, 1],
c=[cat_to_color[cat]],
label=f"{cat}",
alpha=0.7,
s=50
)
ax.set_title(title)
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
ax.legend(title=label, loc='best')
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
print(f"[INFO] {label} categories: {unique_cats}")
def gen_2d_plot(
subject_path: str = SUBJECT_PATH,
out_dir: str = "./plots",
add_jitter: bool = True,
jitter_amount: float = 0.15
) -> None:
"""
Generate a simple 2D scatter plot with hypertension category on x-axis
and systolic blood pressure on y-axis.
This function reads metadata directly from the Excel file and creates
a scatter plot without any embedding or dimensionality reduction.
Args:
subject_path: Path to the PPGBP_metadata.xlsx file
out_dir: Output directory for the plot
add_jitter: Whether to add horizontal jitter for better visibility (default: True)
jitter_amount: Amount of jitter to add (default: 0.15)
"""
os.makedirs(out_dir, exist_ok=True)
# Load metadata
print(f"[INFO] Loading metadata from: {subject_path}")
subject_metadata = load_subject_metadata(subject_path)
print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")
# Extract hypertension and systolic BP values
hypertensions = []
systolics = []
for subject_id, metadata in subject_metadata.items():
hypertension = metadata.get("hypertension")
sbp = metadata.get("sbp")
# Only include if both values are present
if hypertension is not None and sbp is not None:
hypertensions.append(str(hypertension))
systolics.append(sbp)
if len(hypertensions) == 0:
print("[WARN] No valid data found with both hypertension and systolic BP values.")
return
print(f"[INFO] Found {len(hypertensions)} subjects with valid data")
# Map hypertension categories to numeric x positions
# Expected categories: "0" (Normal), "1" (Prehypertension), "2" (Stage 1), "3" (Stage 2)
unique_cats = sorted(set(hypertensions))
cat_to_x = {cat: i for i, cat in enumerate(unique_cats)}
# Create x positions (with optional jitter for visibility)
x_positions = np.array([cat_to_x[h] for h in hypertensions], dtype=float)
if add_jitter:
x_positions += np.random.uniform(-jitter_amount, jitter_amount, size=len(x_positions))
y_values = np.array(systolics)
# Assign colors by category
colors = plt.cm.tab10(np.linspace(0, 1, len(unique_cats)))
cat_to_color = {cat: colors[i] for i, cat in enumerate(unique_cats)}
point_colors = [cat_to_color[h] for h in hypertensions]
# Create hypertension labels for legend
hypertension_labels = {
"0": "Normal",
"1": "Prehypertension",
"2": "Stage 1 Hypertension",
"3": "Stage 2 Hypertension"
}
# Create the plot
fig, ax = plt.subplots(figsize=(10, 8))
# Plot each category separately for legend
for cat in unique_cats:
mask = np.array([h == cat for h in hypertensions])
cat_label = hypertension_labels.get(cat, cat)
ax.scatter(
x_positions[mask],
y_values[mask],
c=[cat_to_color[cat]],
label=cat_label,
alpha=0.7,
s=60,
edgecolors='white',
linewidth=0.5
)
# Configure axes
ax.set_xticks(range(len(unique_cats)))
ax.set_xticklabels([hypertension_labels.get(cat, cat) for cat in unique_cats], rotation=15, ha='right')
ax.set_xlabel("Hypertension Category")
ax.set_ylabel("Systolic Blood Pressure (mmHg)")
ax.set_title("Systolic Blood Pressure by Hypertension Category")
# Add horizontal grid lines
ax.yaxis.grid(True, linestyle='--', alpha=0.7)
ax.set_axisbelow(True)
# Add legend
ax.legend(title="Hypertension", loc='upper left')
# Save the plot
output_path = os.path.join(out_dir, "hypertension_vs_systolic.pdf")
plt.tight_layout()
plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.close()
print(f"[DONE] Saved plot: {output_path}")
print(f"[INFO] Hypertension categories: {unique_cats}")
print(f"[INFO] Systolic BP range: {y_values.min():.0f} - {y_values.max():.0f} mmHg")
# Print category statistics
print("\n[INFO] Category statistics:")
for cat in unique_cats:
cat_mask = np.array([h == cat for h in hypertensions])
cat_systolics = y_values[cat_mask]
cat_label = hypertension_labels.get(cat, cat)
print(f" {cat_label}: n={len(cat_systolics)}, "
f"mean={cat_systolics.mean():.1f}, "
f"std={cat_systolics.std():.1f}, "
f"range=[{cat_systolics.min():.0f}-{cat_systolics.max():.0f}]")
# =============================================================================
# Command Line Interface
# =============================================================================
@@ -729,19 +1221,81 @@ class CLI:
# Extract embeddings as numpy array for t-SNE
embeddings = np.array(dataset["embedding"])
# Extract ages for coloring
ages = np.array(dataset["age"])
# Reduce to 2D with t-SNE
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
# Generate visualization colored by age
create_scatter_plot_by_age(
coordinates_2d,
ages,
"t-SNE Visualization (Colored by Age)",
os.path.join(out_dir, "tsne_by_age.pdf")
)
# Generate visualizations for continuous variables
continuous_vars = [
("age", "Age (years)", "viridis"),
("height", "Height (cm)", "plasma"),
("weight", "Weight (kg)", "cividis"),
("systolic", "Systolic BP (mmHg)", "Reds"),
("diastolic", "Diastolic BP (mmHg)", "Blues"),
("heart_rate", "Heart Rate (bpm)", "Oranges"),
("bmi", "BMI (kg/m²)", "Greens"),
]
for var_name, var_label, cmap in continuous_vars:
if var_name in dataset.column_names:
values = np.array(dataset[var_name])
create_scatter_plot_continuous(
coordinates_2d,
values,
f"t-SNE Visualization (Colored by {var_label})",
os.path.join(out_dir, f"tsne_by_{var_name}.pdf"),
label=var_label,
cmap=cmap
)
# Generate visualizations for categorical variables
categorical_vars = [
("sex", "Sex"),
("hypertension", "Hypertension"),
("age_category", "Age Category"),
("height_category", "Height Category"),
("weight_category", "Weight Category"),
("systolic_category", "Systolic BP Category"),
("diastolic_category", "Diastolic BP Category"),
("heart_rate_category", "Heart Rate Category"),
("bmi_category", "BMI Category"),
]
for var_name, var_label in categorical_vars:
if var_name in dataset.column_names:
categories = np.array(dataset[var_name])
create_scatter_plot_categorical(
coordinates_2d,
categories,
f"t-SNE Visualization (Colored by {var_label})",
os.path.join(out_dir, f"tsne_by_{var_name}.pdf"),
label=var_label
)
def plot_2d(
self,
subject_path: str = SUBJECT_PATH,
out_dir: str = "./plots",
add_jitter: bool = True,
jitter_amount: float = 0.15
) -> None:
"""
Generate a simple 2D scatter plot with hypertension category on x-axis
and systolic blood pressure on y-axis.
This reads metadata directly from the Excel file and creates a simple
scatter plot without embeddings or dimensionality reduction.
Args:
subject_path: Path to the PPGBP_metadata.xlsx file
out_dir: Output directory for the plot
add_jitter: Whether to add horizontal jitter for better visibility (default: True)
jitter_amount: Amount of jitter to add (default: 0.15)
Usage:
python gen_plot.py plot_2d --out_dir ./plots
python gen_plot.py plot_2d --subject_path /path/to/metadata.xlsx --out_dir ./plots
"""
gen_2d_plot(subject_path, out_dir, add_jitter, jitter_amount)
if __name__ == "__main__":
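The horizontal jitter used by `gen_2d_plot` can be illustrated in isolation. The following is a minimal sketch, not the code from this diff: `jittered_positions` is a hypothetical helper, it uses the standard-library `random` module instead of NumPy, and it takes a `seed` (an assumption added here) so the layout is reproducible, whereas the function in the diff draws unseeded jitter.

```python
import random
from typing import Dict, List


def jittered_positions(
    categories: List[str], jitter_amount: float = 0.15, seed: int = 0
) -> List[float]:
    """Map each category to an integer slot, then add uniform horizontal jitter."""
    unique_cats = sorted(set(categories))
    cat_to_x: Dict[str, int] = {cat: i for i, cat in enumerate(unique_cats)}
    rng = random.Random(seed)  # seeded (assumption) so repeated runs give the same layout
    return [
        cat_to_x[c] + rng.uniform(-jitter_amount, jitter_amount)
        for c in categories
    ]


# Hypertension categories encoded as strings, as in gen_2d_plot
cats = ["0", "2", "1", "0", "3", "2"]
xs = jittered_positions(cats)
for cat, x in zip(cats, xs):
    print(f"category {cat} -> x = {x:.3f}")
```

Each point stays within `jitter_amount` of its category's tick, so the categories remain visually separable as long as `jitter_amount` is below 0.5.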

1029
sample_case.py Normal file

File diff suppressed because it is too large

@@ -0,0 +1,32 @@
# PPGBP Recruiter (MAB-based example set selection) config
# Usage: python -m sc.run_recruiter sc/config/ppgbp_recruiter.yaml
# When model_path is null: use ground-truth scorer (arm_id -> score), no LLM; run gt_steps to test MAB convergence.
data_path: $PPGBP_DATA_ROOT
feature_root: $DATA_INDEX_ROOT/metadata_onehot # index.json + embeddings.npy
log_path: ./logs/ppgbp_recruiter
model_path: null # set to a HF model path to use real LLM + self_certainty
queue_size: 5
recruit_size: 2
n_way: 2
k_shot: 3
num_arms: 50
# When model_path is null: deterministic reward = Gaussian over arm index, peak at gt_best_arm_id
gt_steps: 200
gt_best_arm_id: 20
gt_sigma: 5.0
n_das: 5
delta: 0.05
epsilon: 0.11
temperature: 0.7
max_new_tokens: 256
num_seeds: 1
seed: 42
task_info: "Classify the subject from PPG-BP metadata."
classes_info: ["Normal", "Prehypertension", "Stage 1 hypertension", "Stage 2 hypertension" ]
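The config comments describe the ground-truth mode as a deterministic Gaussian reward over the arm index, peaking at `gt_best_arm_id`. A minimal sketch of such a scorer follows. The exact form used by `sc.run_recruiter` is not shown in this diff, so the unnormalized Gaussian and the name `gaussian_reward` are assumptions made for illustration.

```python
import math


def gaussian_reward(arm_id: int, best_arm_id: int = 20, sigma: float = 5.0) -> float:
    """Deterministic reward that peaks at best_arm_id and decays with distance.

    Assumed form: exp(-(arm - best)^2 / (2 * sigma^2)), i.e. an unnormalized Gaussian.
    """
    return math.exp(-((arm_id - best_arm_id) ** 2) / (2.0 * sigma ** 2))


# Values mirror num_arms, gt_best_arm_id and gt_sigma from the config above
num_arms = 50
rewards = [gaussian_reward(a) for a in range(num_arms)]
best = max(range(num_arms), key=rewards.__getitem__)
print(f"best arm: {best}, reward: {rewards[best]:.3f}")
print(f"reward one sigma away: {gaussian_reward(25):.3f}")
```

With a reward of this shape, a bandit run for `gt_steps` iterations should concentrate its pulls around arm 20, which is what makes it usable as a convergence check without an LLM.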


@@ -65,14 +65,9 @@ temperature: 0.0
# ------------------------------------------------------------------------------
# Multiple Ollama instances provide model diversity even with T=0
models:
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
- hf:openai/gpt-oss-20b
# Add more entries for model pool (same model, multiple instances for self-consistency)
# - hf:/path/to/gpt-oss-20b
# ------------------------------------------------------------------------------
# Output Configuration


@@ -1,147 +1,142 @@
"""
Base agent for single-turn LLM inference with structured JSON parsing.
Each invoke call is stateless: system prompt + user prompt -> (text, logits).
Message format: {"role": "user"|"assistant"|"system", "content": str}
"""
import os
import re
import sys
import json
from typing import List, Optional, Tuple
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
import torch
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.model import AsyncModelPool, ChatMessage
class Agent:
"""Wraps an AsyncModelPool for single-turn inference."""
def __init__(
self,
name,
model_pool,
log_path,
):
name: str,
model_pool: AsyncModelPool,
log_path: str,
system_message: str = "",
) -> None:
self.name = name
self.model_pool = model_pool
self.log_path = log_path
self.system_message = system_message
# Logging directories
self.root_log_path = log_path
os.makedirs(log_path, exist_ok=True)
self.agent_log_path = os.path.join(log_path, name)
os.makedirs(self.agent_log_path, exist_ok=True)
self.long_term_memory = []
self.short_term_memory = []
self.volatile_memory = []
# -----------------------------------------------------------------
# Logging
# -----------------------------------------------------------------
def log(self, message, local=True):
path = os.path.join(self.root_log_path, "log.txt")
with open(path, "a", encoding="utf-8") as f:
message_type = "UNKNOWN"
if isinstance(message, SystemMessage):
message_type = "SYSTEM"
if isinstance(message, HumanMessage):
message_type = "HUMAN"
if isinstance(message, AIMessage):
message_type = "AI"
content = message.content.strip()
name = self.name
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
if local:
local_path = os.path.join(self.agent_log_path, "log.txt")
with open(local_path, "a", encoding="utf-8") as f:
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
def log(self, message: ChatMessage) -> None:
"""Append a message to the global and agent-local log."""
role = message.get("role", "unknown")
content = message.get("content", "").strip()
entry = f"[{self.name}] [{role}]\n{content}\n\n\n"
def update_memory(self):
self.long_term_memory.extend(self.short_term_memory)
self.clean_short_term_memory()
self.clean_volatile_memory()
for path in (
os.path.join(self.root_log_path, "log.txt"),
os.path.join(self.agent_log_path, "log.txt"),
):
with open(path, "a", encoding="utf-8") as f:
f.write(entry)
def clean_short_term_memory(self):
self.short_term_memory = []
# -----------------------------------------------------------------
# JSON helpers
# -----------------------------------------------------------------
def clean_volatile_memory(self):
self.volatile_memory = []
def clean_long_term_memory(self):
self.long_term_memory = []
def clean_json_text(self, text):
@staticmethod
def _clean_json_text(text: str) -> str:
"""Normalise common LLM quirks so the string is valid JSON."""
text = text.strip()
text = text.replace("‘", "'").replace("’", "'")
text = text.replace("“", "'").replace("”", "'")
text = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", text)
text = re.sub(r",\s*}", "}", text)
text = re.sub(r",\s*]", "]", text)
text = "".join(ch for ch in text if ch.isprintable())
text = text.replace("][", ",")
text = text.replace("\u2018", "'").replace("\u2019", "'") # smart single quotes
text = text.replace("\u201c", "'").replace("\u201d", "'") # smart double quotes
text = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", text) # escape stray backslashes
text = re.sub(r",\s*}", "}", text) # trailing comma in object
text = re.sub(r",\s*]", "]", text) # trailing comma in array
text = "".join(ch for ch in text if ch.isprintable()) # strip control chars
text = text.replace("][", ",") # merge adjacent arrays
return text
def safe_parse_json(self, text):
def safe_parse_json(self, text: str) -> Optional[dict]:
"""Best-effort extraction of a JSON object from an LLM response."""
if not text:
return None
text = text.strip()
match = re.search(r"\{.*\}", text, re.DOTALL)
if not match and not text.endswith("}"):
match = re.search(r"\{.*\}", text + "}", re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
return json.loads(self._clean_json_text(match.group(0)))
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
elif not text.endswith("}"):
text += "}"
match = re.search(r"\{.*\}", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
print("[!] JSON parse failed")
print("[!] JSON parse failed: no object found")
return None
async def validate_response(self, response, fields, volatile=False):
if (
not response
or not isinstance(response, dict)
or not all(field in response for field in fields)
):
print("[!] The JSON failed to be parsed. Trying again.")
content = (
"Failed to parse the JSON from the previous response. Please try again."
# -----------------------------------------------------------------
# Response validation
# -----------------------------------------------------------------
async def validate_response(
self, response: Optional[dict], fields: List[str]
) -> Optional[dict]:
"""Validate that *response* is a dict containing all *fields*; retry once if not."""
if not isinstance(response, dict) or not all(f in response for f in fields):
print("[!] JSON validation failed. Retrying...")
text, _ = await self.invoke(
"Failed to parse the JSON from the previous response. Please try again.",
)
response = await self.invoke(content, volatile=volatile)
response = self.safe_parse_json(response)
if (
not response
or not isinstance(response, dict)
or not all(field in response for field in fields)
):
response = self.safe_parse_json(text)
if not isinstance(response, dict) or not all(f in response for f in fields):
print("[!] Retry failed.")
return None
return response
def get_last_response(self):
if len(self.long_term_memory) >= 2:
last_msg = self.long_term_memory[-1]
if isinstance(last_msg, AIMessage):
return self.safe_parse_json(last_msg.content)
return None
# -----------------------------------------------------------------
# Core invoke
# -----------------------------------------------------------------
def set_system_message(self, content, local=True):
system_message = SystemMessage(content=content)
self.log(system_message, local)
self.long_term_memory.append(system_message)
async def invoke(
self,
content: str,
) -> Tuple[Optional[str], Optional[torch.Tensor]]:
"""
Single-turn call: [system_message] + user content -> (text, logits).
Returns:
text: Generated reply string (or None on error).
logits: Tensor of shape (1, num_tokens, vocab_size) on CPU (or None).
"""
messages: List[ChatMessage] = []
if self.system_message:
messages.append({"role": "system", "content": self.system_message})
user_msg: ChatMessage = {"role": "user", "content": content}
messages.append(user_msg)
async def invoke(self, content, volatile=False, local=True):
messages = self.long_term_memory.copy()
if volatile:
messages.extend(self.volatile_memory)
else:
messages.extend(self.short_term_memory)
messages.append(HumanMessage(content=content))
try:
response = await self.model_pool.invoke(messages)
if volatile:
self.volatile_memory.extend([HumanMessage(content=content), response])
else:
self.short_term_memory.extend([HumanMessage(content=content), response])
local_ = not volatile and local
self.log(HumanMessage(content=content), local=local_)
self.log(response, local=local_)
return response.content.strip()
except Exception as e: # pylint: disable=broad-exception-caught
print(f"[Error] Error occurred while invoking LLM: {e}")
text, logits = await self.model_pool.invoke(messages)
assistant_msg: ChatMessage = {"role": "assistant", "content": text}
self.log(user_msg)
self.log(assistant_msg)
return text.strip(), logits
except Exception as e:
print(f"[Error] invoke failed: {e}")
return None, None
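The JSON recovery path in `safe_parse_json` can be exercised without a model. The sketch below is a simplified standalone version, assuming only the brace-matching regex, the missing-closing-brace fallback, and the trailing-comma cleanup from the diff. `extract_json_object` is a hypothetical name, and the quote and backslash normalisation steps are intentionally left out.

```python
import json
import re
from typing import Optional


def extract_json_object(text: str) -> Optional[dict]:
    """Return the JSON object spanning the first '{' to the last '}', or None."""
    if not text:
        return None
    text = text.strip()
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if not match and not text.endswith("}"):
        # Tolerate a reply that was cut off before the closing brace.
        match = re.search(r"\{.*\}", text + "}", re.DOTALL)
    if not match:
        return None
    candidate = match.group(0)
    candidate = re.sub(r",\s*}", "}", candidate)  # trailing comma in object
    candidate = re.sub(r",\s*]", "]", candidate)  # trailing comma in array
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


reply = 'Sure, here is the answer:\n{"label": "N1", "confidence": 0.9,}\nHope this helps.'
print(extract_json_object(reply))
print(extract_json_object('{"label": "W"'))  # truncated reply
print(extract_json_object("no json here"))
```

Because the regex is greedy, any prose before the first brace and after the last brace is discarded, but two separate objects in one reply would be captured as a single invalid span and rejected.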


@@ -1,72 +1,630 @@
"""
Data loader for personalized ICL experiments.
Loads a target user's test split and example data from all other users.
Expects the following directory structure:
<path>/
info.json # dataset metadata (task, classes, features)
<user_id>/
1/ # examples (train split) - HuggingFace dataset on disk
2/ # test split - HuggingFace dataset on disk
<other_user>/
1/
...
--------------------------------------------------------------------------------
Integrating a new dataset
--------------------------------------------------------------------------------
The pipeline expects:
1) Directory layout (for DataLoader / ShuffledDataLoader):
- <data_path>/info.json
- <data_path>/<user_id>/1/ and <data_path>/<user_id>/2/ (HuggingFace datasets saved with datasets.Dataset.save_to_disk)
2) info.json schema:
- "task": str (e.g. "Sleep stage classification from EEG")
- "class": dict mapping label -> description (e.g. {"W": "Wake", "N1": "Non-REM 1", ...})
- "feature": str (sensor/feature description for the LLM)
3) Each sample (in test and example datasets) must be a dict with:
- "label": str (class name, must be one of the keys in info.json["class"])
- "features": dict of str -> value (e.g. {"Fpz-Cz": "0.12, -0.05, ...", "Pz-Oz": "..."}); values are formatted for the prompt
Option A - Convert your dataset to this format:
- Create info.json and per-user dirs; build HuggingFace Dataset with columns "label" and "features", then save_to_disk to <user_id>/1 and <user_id>/2.
- Use prepare_dataset_for_sc() in this module to write from in-memory lists.
Option B - Use in-memory data without changing directory layout:
- Use InMemoryDataLoader(metadata, test_samples, example_samples) which implements the same interface as DataLoader.
- In run_sc, construct it and pass it to run_single_task. The run_confidence_based, run_consistency_based, and run_sc_queue_random runners currently require the directory layout (or must be extended to accept a loader instance).
"""
import os
import json
import datasets
import numpy as np
from glob import glob
from typing import Optional, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Union
import datasets
class DataLoader:
"""Loads a target user's test set and ICL examples from other users."""
def __init__(
self,
data_path,
user_id,
example_pool="out",
continuous=True,
):
if not os.path.exists(os.path.join(data_path, "info.json")):
return
if not os.path.exists(os.path.join(data_path, f"{user_id}", "1")):
return
if not os.path.exists(os.path.join(data_path, f"{user_id}", "2")):
return
self.metadata = json.load(open(os.path.join(data_path, "info.json"), "r", encoding="utf-8"))
self.test_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user_id}", "2"))
self.example_dataset = datasets.Dataset.from_list([])
users = glob(os.path.join(data_path, "*"))
users = [path.split("/")[-1] for path in users]
if "info.json" in users:
users.remove("info.json")
for user in users:
if example_pool == "out" and user == user_id:
continue
if example_pool == "in" and user != user_id:
continue
user_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user}", "1"))
self.example_dataset = datasets.concatenate_datasets([self.example_dataset, user_dataset])
self,
path: Optional[str] = None,
user_id: Optional[str] = None,
shuffle: bool = False,
seed: int = 0,
data_path: Optional[str] = None,
**kwargs: Any,
) -> None:
# Support both path= and data_path= for compatibility with run_sc
root = path or data_path
if root is None or user_id is None:
raise ValueError("path (or data_path) and user_id are required")
metadata_path = os.path.join(root, "info.json")
target_user_path = os.path.join(root, user_id, "2")
if not continuous:
self.test_dataset = self.test_dataset.shuffle(seed=0)
self.example_dataset = self.example_dataset.shuffle(seed=0)
if not os.path.exists(metadata_path) or not os.path.exists(target_user_path):
return
def __len__(self):
# Metadata
with open(metadata_path, "r", encoding="utf-8") as f:
self.metadata: Dict[str, Any] = json.load(f)
# Test set for the target user
self.test_dataset = datasets.load_from_disk(target_user_path)
# Example set: train splits from all *other* users
example_paths = [
os.path.join(user_path, "1")
for user_path in glob(os.path.join(root, "*"))
if os.path.isdir(user_path) and os.path.basename(user_path) != user_id
]
example_parts = [datasets.load_from_disk(p) for p in example_paths]
if example_parts:
self.example_dataset = datasets.concatenate_datasets(example_parts)
else:
self.example_dataset = datasets.Dataset.from_list([])
# Shuffle
if shuffle:
self.test_dataset = self.test_dataset.shuffle(seed=seed)
self.example_dataset = self.example_dataset.shuffle(seed=seed)
def __len__(self) -> int:
return len(self.test_dataset)
def __getitem__(self, idx):
sample = self.test_dataset[idx]
return sample
def __getitem__(self, idx: int) -> Dict[str, Any]:
return self.test_dataset[idx]
def __iter__(self):
for sample in self.test_dataset:
yield sample
yield from self.test_dataset
def get_examples(self):
def get_examples(self) -> datasets.Dataset:
return self.example_dataset
def get_metadata(self):
def get_metadata(self) -> Dict[str, Any]:
return self.metadata
def get_sensor_info(self):
def get_sensor_info(self) -> str:
return self.metadata["feature"]
def get_task_info(self):
task_info = f"**Task**:\n{self.metadata['task']}\n\n"
classes_info = [f" - {k}: {v}" for k, v in self.metadata["class"].items()]
classes_info = "\n".join(classes_info)
task_info += f"**Classes**:\n{classes_info}"
return task_info
def get_task_info(self) -> str:
"""Return a formatted string describing the task and its classes."""
classes_info = "\n".join(
f" - {k}: {v}" for k, v in self.metadata["class"].items()
)
return f"**Task**:\n{self.metadata['task']}\n\n**Classes**:\n{classes_info}"
def get_classes_info(self):
classes_info = [k for k in self.metadata["class"].keys()]
return classes_info
def get_classes_info(self) -> List[str]:
return list(self.metadata["class"].keys())
class InMemoryDataLoader:
"""
DataLoader interface backed by in-memory lists. Use this to run the SC
pipeline on a new dataset without writing the directory layout to disk.
Implements the same interface as DataLoader: __len__, __getitem__, __iter__,
get_examples(), get_metadata(), get_sensor_info(), get_task_info(), get_classes_info().
"""
def __init__(
self,
metadata: Dict[str, Any],
test_samples: List[Dict[str, Any]],
example_samples: List[Dict[str, Any]],
) -> None:
"""
Args:
metadata: Must have "task", "class" (dict label -> description), "feature" (str).
test_samples: List of dicts with "label" and "features" (dict).
example_samples: Same schema; used as ICL example pool.
"""
self.metadata = metadata
self.test_dataset = datasets.Dataset.from_list(test_samples)
self.example_dataset = datasets.Dataset.from_list(example_samples)
def __len__(self) -> int:
return len(self.test_dataset)
def __getitem__(self, idx: int) -> Dict[str, Any]:
return self.test_dataset[idx]
def __iter__(self):
yield from self.test_dataset
def get_examples(self) -> datasets.Dataset:
return self.example_dataset
def get_metadata(self) -> Dict[str, Any]:
return self.metadata
def get_sensor_info(self) -> str:
return self.metadata["feature"]
def get_task_info(self) -> str:
classes_info = "\n".join(
f" - {k}: {v}" for k, v in self.metadata["class"].items()
)
return f"**Task**:\n{self.metadata['task']}\n\n**Classes**:\n{classes_info}"
def get_classes_info(self) -> List[str]:
return list(self.metadata["class"].keys())
def prepare_dataset_for_sc(
output_path: str,
metadata: Dict[str, Any],
per_user_splits: Dict[str, Dict[str, List[Dict[str, Any]]]],
) -> None:
"""
Write a new dataset to the directory layout expected by DataLoader and
ShuffledDataLoader. Call this once; then point config data_path to output_path.
Args:
output_path: Root directory to create (e.g. /path/to/MyDataset).
metadata: info.json contents: "task", "class" (dict), "feature" (str).
per_user_splits: { user_id: { "train": [samples], "test": [samples] } }.
Each sample is a dict with "label" and "features".
"""
os.makedirs(output_path, exist_ok=True)
with open(os.path.join(output_path, "info.json"), "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
for user_id, splits in per_user_splits.items():
user_dir = os.path.join(output_path, str(user_id))
os.makedirs(os.path.join(user_dir, "1"), exist_ok=True)
os.makedirs(os.path.join(user_dir, "2"), exist_ok=True)
train_ds = datasets.Dataset.from_list(splits.get("train", []))
test_ds = datasets.Dataset.from_list(splits.get("test", []))
train_ds.save_to_disk(os.path.join(user_dir, "1"))
test_ds.save_to_disk(os.path.join(user_dir, "2"))
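Before calling `prepare_dataset_for_sc` or constructing an `InMemoryDataLoader`, it can help to check samples against the schema described in the module docstring. The sketch below is one possible check, not part of this diff: `validate_samples` is a hypothetical helper, and the metadata and sample values are illustrative only.

```python
from typing import Any, Dict, List


def validate_samples(metadata: Dict[str, Any], samples: List[Dict[str, Any]]) -> None:
    """Raise ValueError if the metadata or a sample does not match the documented schema."""
    for key in ("task", "class", "feature"):
        if key not in metadata:
            raise ValueError(f"metadata is missing required key: {key!r}")
    valid_labels = set(metadata["class"])
    for i, sample in enumerate(samples):
        label = sample.get("label")
        if label not in valid_labels:
            raise ValueError(f"sample {i}: label {label!r} not in {sorted(valid_labels)}")
        if not isinstance(sample.get("features"), dict):
            raise ValueError(f"sample {i}: 'features' must be a dict")


# Illustrative values following the info.json schema
metadata = {
    "task": "Sleep stage classification from EEG",
    "class": {"W": "Wake", "N1": "Non-REM 1"},
    "feature": "Single-channel EEG (Fpz-Cz)",
}
samples = [
    {"label": "W", "features": {"Fpz-Cz": "0.12, -0.05, 0.03"}},
    {"label": "N1", "features": {"Fpz-Cz": "0.01, 0.02, -0.04"}},
]
validate_samples(metadata, samples)
print("all samples match the schema")
```

Running a check like this up front surfaces a mislabeled sample before it is written to disk or placed in an ICL prompt.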
# PPGBP data loader
"""
PPGBP Dataset Loader
A data loader/iterator for the PPG-BP dataset that reads xlsx metadata
and corresponding signal text files, with optional batching support.
Metadata embeddings live under PPGBP_METADATA_EMBEDDINGS_ROOT (one .npy per
subject ID). If missing, they are generated automatically in PPGBPLoader.__init__
using SBERT. Set PPGBP_METADATA_EMBEDDINGS_ROOT in .env or environment.
"""
import os
import sys
import numpy as np
import pandas as pd
from typing import Iterator, Tuple, Dict, List, Optional, Union
try:
from dotenv import load_dotenv
_ppgbp_env_loaded = False
def _ensure_ppgbp_env():
global _ppgbp_env_loaded
if not _ppgbp_env_loaded:
load_dotenv()
_ppgbp_env_loaded = True
except ImportError:
def _ensure_ppgbp_env():
pass
def _get_ppgbp_embedding_root() -> str:
_ensure_ppgbp_env()
root = os.environ.get("PPGBP_METADATA_EMBEDDINGS_ROOT")
if not root:
raise ValueError(
"PPGBP_METADATA_EMBEDDINGS_ROOT is not set. Set it in .env or environment "
"(e.g. PPGBP_METADATA_EMBEDDINGS_ROOT=/path/to/metadata_embeddings)."
)
return root
def _get_ppgbp_embedding_root_optional() -> Optional[str]:
"""Return PPGBP_METADATA_EMBEDDINGS_ROOT if set, else None. Used when embeddings are not required (e.g. recruiter with feature_root)."""
_ensure_ppgbp_env()
root = os.environ.get("PPGBP_METADATA_EMBEDDINGS_ROOT") or ""
return root.strip() or None
def _get_missing_embedding_ids(embedding_root: str, subject_ids: np.ndarray) -> Optional[List[int]]:
"""
Return list of subject_ids that do not have a .npy file in embedding_root.
Return None if embedding_root is not a directory (so caller can create and generate).
"""
if not os.path.isdir(embedding_root):
return None
missing = []
for sid in subject_ids:
path = os.path.join(embedding_root, f"{int(sid)}.npy")
if not os.path.isfile(path):
missing.append(int(sid))
return missing
def _check_embedding_root_populated(embedding_root: str, subject_ids: np.ndarray) -> None:
"""Raise if embedding_root is missing or does not contain .npy for every subject_id."""
missing_or_none = _get_missing_embedding_ids(embedding_root, subject_ids)
if missing_or_none is None:
raise FileNotFoundError(
f"Metadata embeddings root is not a directory or does not exist: {embedding_root}."
)
if missing_or_none:
raise FileNotFoundError(
f"Metadata embeddings root {embedding_root} is missing .npy files for subject IDs: "
f"{missing_or_none[:10]}{'...' if len(missing_or_none) > 10 else ''}."
)
def _generate_and_save_ppgbp_metadata_embeddings(
embedding_root: str,
subject_path: str,
subject_ids: np.ndarray,
) -> None:
"""
Generate SBERT metadata embeddings for the given subject IDs and save as
{subject_id}.npy under embedding_root. Uses tqdm for progress.
"""
try:
from tqdm import tqdm
except ImportError:
tqdm = lambda x, **kw: x # noqa: E731
# Lazy import to keep SBERT/gen_plot deps out of normal import path
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from analysis.user_similarity.sbert_metadata_ppgbp.gen_plot import (
load_subject_metadata,
SBERT_Metadata,
_normalize_sex_for_sbert,
_normalize_hypertension_for_sbert,
)
os.makedirs(embedding_root, exist_ok=True)
subject_metadata = load_subject_metadata(subject_path)
embedder = SBERT_Metadata()
ids_to_generate = [int(sid) for sid in subject_ids]
for user_id in tqdm(ids_to_generate, desc="PPGBP metadata embeddings", unit="subject"):
user_id_str = str(user_id)
meta = subject_metadata.get(user_id_str, {})
sex = _normalize_sex_for_sbert(meta.get("sex"))
age = meta.get("age")
height = meta.get("height")
weight = meta.get("weight")
sbp = meta.get("sbp")
dbp = meta.get("dbp")
hr = meta.get("hr")
bmi = meta.get("bmi")
hypertension = _normalize_hypertension_for_sbert(meta.get("hypertension"))
if age is not None and isinstance(age, float) and np.isnan(age):
age = None
emb = embedder.compute_embedding_from_metadata(
[sex], [age], [height], [weight], [sbp], [dbp], [hr], [bmi], [hypertension]
)
path = os.path.join(embedding_root, f"{user_id}.npy")
np.save(path, emb[0])
def load_ppgbp_metadata_embedding(embedding_root: str, subject_id: int) -> np.ndarray:
"""Load a single metadata embedding for subject_id from embedding_root (utility)."""
path = os.path.join(embedding_root, f"{int(subject_id)}.npy")
if not os.path.isfile(path):
raise FileNotFoundError(f"Metadata embedding not found: {path}")
return np.load(path)
def save_ppgbp_metadata_embeddings(
embedding_root: str,
subject_path: str,
subject_ids: Optional[np.ndarray] = None,
) -> None:
"""
Generate and save PPGBP metadata embeddings under embedding_root (one .npy per subject).
If subject_ids is None, uses all subject IDs present in the metadata file at subject_path.
Uses tqdm for progress.
"""
if subject_ids is not None and len(subject_ids) == 0:
return
# If subject_ids is None we need to get all from metadata; _generate_* expects an array
if subject_ids is None:
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from analysis.user_similarity.sbert_metadata_ppgbp.gen_plot import load_subject_metadata
subject_metadata = load_subject_metadata(subject_path)
subject_ids = np.array(sorted(int(k) for k in subject_metadata.keys()))
_generate_and_save_ppgbp_metadata_embeddings(embedding_root, subject_path, subject_ids)
class PPGBPLoader:
"""
Data loader for the PPG-BP dataset.
Reads metadata from xlsx file and corresponding PPG signal text files.
Supports iteration over single samples or batches.
Subject list is built from 0_subject/ and matched with metadata;
train/test is 80/20 at subject level (random, fixed by seed).
"""
COLUMN_MAPPING = {
"Sex(M/F)": "sex",
"Age(year)": "age",
"Systolic Blood Pressure(mmHg)": "sysbp",
"Diastolic Blood Pressure(mmHg)": "diasbp",
"Heart Rate(b/m)": "hr",
"BMI(kg/m^2)": "bmi"
}
def __init__(
self,
base_dir: str,
split: str = 'all',
batch_size: Optional[int] = None,
shuffle: bool = False,
num_segments: int = 3,
seed: int = 42,
return_metadata_embeddings: bool = False,
):
self.base_dir = base_dir
self.split = split
self.batch_size = batch_size
self.shuffle = shuffle
self.num_segments = num_segments
self.seed = seed
self.return_metadata_embeddings = return_metadata_embeddings
self.rng = np.random.default_rng(seed)
self.signal_dir = os.path.join(base_dir, "0_subject")
self.xlsx_path = self._find_xlsx_file()
self.metadata = self._load_metadata()
self._all_subject_ids = self._get_unified_subject_ids()
self._train_ids, self._test_ids = self._get_train_test_split(self._all_subject_ids)
self.subject_ids = self._get_split_ids()
self.metadata = self.metadata[self.metadata['subject_ID'].isin(self.subject_ids)]
self._embedding_root = _get_ppgbp_embedding_root_optional()
if self._embedding_root is not None:
missing = _get_missing_embedding_ids(self._embedding_root, self._all_subject_ids)
if len(self._all_subject_ids) > 0 and (missing is None or len(missing) > 0):
_generate_and_save_ppgbp_metadata_embeddings(
self._embedding_root, self.xlsx_path, self._all_subject_ids
)
_check_embedding_root_populated(self._embedding_root, self.subject_ids)
self._current_idx = 0
self._indices = np.arange(len(self.subject_ids))
if self.shuffle:
self.rng.shuffle(self._indices)
def _find_xlsx_file(self) -> str:
for f in os.listdir(self.base_dir):
if f.endswith('.xlsx'):
return os.path.join(self.base_dir, f)
raise FileNotFoundError(f"No xlsx file found in {self.base_dir}")
def _load_metadata(self) -> pd.DataFrame:
# Use header=1: xlsx has a title row, then row with column names (subject_ID, etc.)
df = pd.read_excel(self.xlsx_path, header=1)
df = df.rename(columns=self.COLUMN_MAPPING)
df = df.fillna(0)
return df
def _get_unified_subject_ids(self) -> np.ndarray:
"""Subject IDs present in both 0_subject/ (from listdir) and metadata."""
if not os.path.isdir(self.signal_dir):
return np.array([], dtype=np.int64)
# Files are named {subject_id}_{segment}.txt
ids_from_files = set()
for f in os.listdir(self.signal_dir):
if f.endswith(".txt"):
try:
sid = int(f.replace(".txt", "").split("_")[0])
ids_from_files.add(sid)
except (ValueError, IndexError):
continue
meta_ids = set(self.metadata["subject_ID"].astype(int).values)
unified = sorted(ids_from_files & meta_ids)
return np.array(unified)
def _get_train_test_split(self, ids: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""80% train, 20% test at subject level; shuffle fixed by seed."""
if len(ids) == 0:
return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
perm = self.rng.permutation(len(ids))
ids_perm = ids[perm]
n_train = max(1, int(round(0.8 * len(ids))))
train_ids = ids_perm[:n_train]
test_ids = ids_perm[n_train:]
return train_ids, test_ids
def _get_split_ids(self) -> np.ndarray:
if self.split == "train":
return self._train_ids
elif self.split == "test":
return self._test_ids
elif self.split == "all":
return self._all_subject_ids
else:
raise ValueError(f"Unknown split: {self.split}. Use 'train', 'test', or 'all'.")
@property
def train_ids(self) -> np.ndarray:
"""Subject IDs in the train split (80%)."""
return self._train_ids
@property
def test_ids(self) -> np.ndarray:
"""Subject IDs in the test split (20%)."""
return self._test_ids
def _load_signal(self, subject_id: int) -> np.ndarray:
segments = []
for s in range(1, self.num_segments + 1):
filepath = os.path.join(self.signal_dir, f"{subject_id}_{s}.txt")
signal = pd.read_csv(filepath, sep='\t', header=None)
signal = signal.values.squeeze()
if len(signal) > 1:
signal = signal[:-1]
segments.append(signal)
return np.array(segments, dtype=object)
def _get_metadata_dict(self, subject_id: int) -> Dict:
row = self.metadata[self.metadata['subject_ID'] == subject_id].iloc[0]
return row.to_dict()
def _load_metadata_embedding(self, subject_id: int) -> np.ndarray:
return load_ppgbp_metadata_embedding(self._embedding_root, subject_id)
def __len__(self) -> int:
return len(self.subject_ids)
def __iter__(self) -> 'PPGBPLoader':
self._current_idx = 0
if self.shuffle:
self.rng.shuffle(self._indices)
return self
def __next__(self) -> Union[
Tuple[Dict, np.ndarray],
Tuple[Dict, np.ndarray, np.ndarray],
Tuple[List[Dict], List[np.ndarray]],
Tuple[List[Dict], List[np.ndarray], List[np.ndarray]],
]:
if self._current_idx >= len(self.subject_ids):
raise StopIteration
if self.batch_size is None:
idx = self._indices[self._current_idx]
subject_id = self.subject_ids[idx]
metadata = self._get_metadata_dict(subject_id)
signal = self._load_signal(subject_id)
self._current_idx += 1
if self.return_metadata_embeddings and self._embedding_root is not None:
return metadata, signal, self._load_metadata_embedding(subject_id)
return metadata, signal
else:
end_idx = min(self._current_idx + self.batch_size, len(self.subject_ids))
batch_indices = self._indices[self._current_idx:end_idx]
metadata_list = []
signal_list = []
embedding_list = [] if (self.return_metadata_embeddings and self._embedding_root is not None) else None
for idx in batch_indices:
subject_id = self.subject_ids[idx]
metadata_list.append(self._get_metadata_dict(subject_id))
signal_list.append(self._load_signal(subject_id))
if self.return_metadata_embeddings and self._embedding_root is not None:
embedding_list.append(self._load_metadata_embedding(subject_id))
self._current_idx = end_idx
if embedding_list is not None:
return metadata_list, signal_list, embedding_list
return metadata_list, signal_list
def __getitem__(self, idx: int) -> Union[Tuple[Dict, np.ndarray], Tuple[Dict, np.ndarray, np.ndarray]]:
if idx < 0 or idx >= len(self.subject_ids):
raise IndexError(f"Index {idx} out of range")
subject_id = self.subject_ids[idx]
metadata = self._get_metadata_dict(subject_id)
signal = self._load_signal(subject_id)
if self.return_metadata_embeddings and self._embedding_root is not None:
return metadata, signal, self._load_metadata_embedding(subject_id)
return metadata, signal
def get_by_subject_id(self, subject_id: int) -> Union[Tuple[Dict, np.ndarray], Tuple[Dict, np.ndarray, np.ndarray]]:
if subject_id not in self.subject_ids:
raise ValueError(f"Subject ID {subject_id} not found in this split.")
metadata = self._get_metadata_dict(subject_id)
signal = self._load_signal(subject_id)
if self.return_metadata_embeddings and self._embedding_root is not None:
return metadata, signal, self._load_metadata_embedding(subject_id)
return metadata, signal
def reset(self):
self._current_idx = 0
if self.shuffle:
self.rng.shuffle(self._indices)
def iter_batches(
self, batch_size: int
) -> Iterator[Union[Tuple[List[Dict], List[np.ndarray]], Tuple[List[Dict], List[np.ndarray], List[np.ndarray]]]]:
indices = np.arange(len(self.subject_ids))
if self.shuffle:
self.rng.shuffle(indices)
for start_idx in range(0, len(self.subject_ids), batch_size):
end_idx = min(start_idx + batch_size, len(self.subject_ids))
batch_indices = indices[start_idx:end_idx]
metadata_list = []
signal_list = []
embedding_list = [] if (self.return_metadata_embeddings and self._embedding_root is not None) else None
for idx in batch_indices:
subject_id = self.subject_ids[idx]
metadata_list.append(self._get_metadata_dict(subject_id))
signal_list.append(self._load_signal(subject_id))
if self.return_metadata_embeddings and self._embedding_root is not None:
embedding_list.append(self._load_metadata_embedding(subject_id))
if embedding_list is not None:
yield metadata_list, signal_list, embedding_list
else:
yield metadata_list, signal_list
def get_loaders(
base_dir: str,
batch_size: Optional[int] = None,
shuffle_train: bool = True,
seed: int = 42
) -> Tuple[PPGBPLoader, PPGBPLoader]:
"""Convenience function to get train and test loaders (80/20 split)."""
train_loader = PPGBPLoader(
base_dir=base_dir, split="train",
batch_size=batch_size, shuffle=shuffle_train, seed=seed
)
test_loader = PPGBPLoader(
base_dir=base_dir, split="test",
batch_size=batch_size, shuffle=False, seed=seed
)
return train_loader, test_loader
if __name__ == "__main__":
print("PPGBPLoader ready to use.")
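
The subject-level 80/20 split performed by `_get_train_test_split` can be sketched on its own. This is a minimal illustration, not the loader's code: `split_subject_ids` and `train_frac` are hypothetical names, and the subject IDs 1..10 are made up.

```python
import numpy as np

def split_subject_ids(ids, seed=42, train_frac=0.8):
    """Hypothetical standalone version of a seeded subject-level split."""
    ids = np.asarray(ids)
    if len(ids) == 0:
        empty = np.array([], dtype=np.int64)
        return empty, empty
    # A fresh generator per call keeps the split reproducible for a given seed.
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(ids))
    # At least one subject always lands in the train split.
    n_train = max(1, int(round(train_frac * len(ids))))
    return ids[perm[:n_train]], ids[perm[n_train:]]

train_ids, test_ids = split_subject_ids(np.arange(1, 11))
print(len(train_ids), len(test_ids))  # 8 2
```

Because the split is drawn per subject, all segments of one subject stay on the same side of the split.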


@@ -151,4 +151,38 @@ class Queue:
"avg_survival": sum(durations) / len(durations),
"max_survival": max(durations),
"avg_usage": sum(usages) / len(usages) if usages else 0
}
}
def update_with_recruiter(self, results, new_example_sets, L: int = 1):
"""
Evict the L worst sets (highest self-certainty score in results) and push up to L new sets.
results: list of (case, response_dict_or_text, self_certainty_score) in queue order.
new_example_sets: list of L new cases (each case = list of indices, same as queue items).
A lower self_certainty score is treated as better, so the L highest-scoring sets are evicted.
"""
if not results or L <= 0:
return
current = list(self._queue)
if len(current) == 0:
for case in new_example_sets[: self._queue.maxlen]:
self._queue.append(list(case))
self._register_stats(case)
return
scores = []
for i, item in enumerate(current):
sc = -1.0
if i < len(results) and len(results[i]) >= 3:
sc = float(results[i][2])
scores.append((i, sc))
scores.sort(key=lambda x: x[1], reverse=True)
to_evict_idx = {scores[j][0] for j in range(min(L, len(scores)))}
new_queue = [current[i] for i in range(len(current)) if i not in to_evict_idx]
for idx in sorted(to_evict_idx, reverse=True):
if idx < len(current):
self._record_eviction_stats(current[idx])
for case in new_example_sets[:L]:
if len(new_queue) >= self._queue.maxlen:
break
new_queue.append(list(case))
self._register_stats(case)
self._queue = deque(new_queue, maxlen=self._queue.maxlen)
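
The eviction step above can be tried in isolation. The following is a simplified sketch under stated assumptions: `evict_and_refill` is a hypothetical helper, it skips the statistics bookkeeping, and it takes the scores as a plain list rather than reading them from result tuples.

```python
from collections import deque

def evict_and_refill(queue, scores, new_sets, L=1):
    """Drop the L entries with the highest score, then append up to L new sets."""
    current = list(queue)
    # Rank queue positions from worst (highest score) to best.
    ranked = sorted(range(len(current)), key=lambda i: scores[i], reverse=True)
    evict = set(ranked[:L])
    kept = [item for i, item in enumerate(current) if i not in evict]
    for case in new_sets[:L]:
        if queue.maxlen is not None and len(kept) >= queue.maxlen:
            break
        kept.append(list(case))
    return deque(kept, maxlen=queue.maxlen)

q = deque([[0, 1], [2, 3], [4, 5]], maxlen=3)
q = evict_and_refill(q, scores=[0.2, 0.9, 0.5], new_sets=[[6, 7]], L=1)
print(list(q))  # [[0, 1], [4, 5], [6, 7]]
```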


@@ -1,153 +1,160 @@
import os
"""
HuggingFace Causal LM wrapper for inference with logits.
Loads a model via transformers, generates text, and returns both
the generated text and the per-step logits (scores) collected during generation.
Message format: [{"role": "user"|"assistant"|"system", "content": str}, ...]
"""
from __future__ import annotations
import asyncio
import requests
from typing import Dict, List, Optional, Tuple
from langchain_ollama import ChatOllama
from langchain_together import ChatTogether
from langchain_openai import ChatOpenAI
from langchain.chat_models import init_chat_model
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# {"role": "user"|"assistant"|"system", "content": str}
ChatMessage = Dict[str, str]
def load_models(models, temperature=0.0, num_ctx=15000):
model_pool = AsyncModelPool()
for model in models:
model_pool.add_model(Model(model, temperature=temperature, num_ctx=num_ctx))
model_pool.init_models()
return model_pool
# -----------------------------------------------------------------------------
# Model
# -----------------------------------------------------------------------------
class Model:
def __init__(self, model, temperature, num_ctx):
if model.startswith("ollama:"):
model = model.replace("ollama:", "")
if "url:" in model: # custom parsing for local ollama models
model = model.replace("url:", "")
base_url = model.split("/")[0]
model_type = model.split("/")[1]
# self.model = ChatOllama(
# model=model_type,
# base_url=f"http://{base_url}",
# temperature=temperature,
# num_ctx=num_ctx,
# )
self.model = None
self.base_url = f"http://{base_url}/api/chat"
self.model_type = model_type
self.temperature = temperature
self.num_ctx = num_ctx
else:
self.model = ChatOllama(
model=model.replace("ollama:", ""),
temperature=temperature,
num_ctx=num_ctx,
)
elif model.startswith("together"):
if "TOGETHER_API_KEY" not in os.environ:
print("[!] TOGETHER_API_KEY is not set")
assert 0
self.model = ChatTogether(
model=model.replace("together:", ""),
temperature=temperature,
max_tokens=num_ctx,
max_retries=3,
)
elif model.startswith("openai"):
if "OPENAI_API_KEY" not in os.environ:
print("[!] OPENAI_API_KEY is not set")
assert 0
self.model = ChatOpenAI(
model=model.replace("openai:", ""),
temperature=temperature,
)
else:
self.model = init_chat_model(
model=model,
temperature=temperature,
"""HuggingFace causal LM. Returns generated text + full logits."""
def __init__(
self,
model_path: str,
temperature: float = 0.0,
max_new_tokens: int = 1024,
enable_thinking: bool = False,
) -> None:
self._tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
self._model = AutoModelForCausalLM.from_pretrained(
model_path,
dtype=torch.bfloat16,
device_map="auto",
).eval()
self.temperature = temperature
self.max_new_tokens = max_new_tokens
self.enable_thinking = enable_thinking
def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
"""
Generate a reply and return (text, logits).
Returns:
text: Generated text string.
logits: Logits tensor (1, num_generated_tokens, vocab_size) on CPU.
"""
# Build prompt from chat messages
prompt = self._tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=self.enable_thinking,
)
inputs = self._tokenizer(prompt, return_tensors="pt").to(self._model.device)
prompt_len = inputs["input_ids"].shape[1]
# Generate with scores in a single pass
with torch.inference_mode():
gen_out = self._model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
do_sample=self.temperature > 0,
temperature=self.temperature if self.temperature > 0 else 1.0,
return_dict_in_generate=True,
output_scores=True,
)
def invoke(self, messages, logprobs=False, top_logprobs=0):
try:
if self.model:
response = self.model.invoke(messages)
return response
else:
converted_messages = []
for msg in messages:
role = msg.type
role = "user" if role == "human" else "assistant"
content = msg.content
converted_messages.append({"role": role, "content": content})
response = requests.post(self.base_url, json={
"model": self.model_type,
"messages": converted_messages,
"stream": False,
"options": {
"temperature": self.temperature,
"num_ctx": self.num_ctx,
},
"logprobs": logprobs,
"top_logprobs": top_logprobs,
})
response = response.json()
resp_msg = AIMessage(content=response["message"]["content"])
if logprobs:
return resp_msg, response["logprobs"]
else:
return resp_msg
return resp_msg, response["logprobs"]
except Exception as e:
print(f"[Error] Error occurred while invoking LLM: {e}")
return e
# Decode generated tokens
gen_ids = gen_out.sequences[0][prompt_len:]
text = self._tokenizer.decode(gen_ids, skip_special_tokens=True)
# Strip internal thinking tags (e.g. "analysis...assistantfinal...")
if "assistantfinal" in text:
text = text.split("assistantfinal", 1)[1].strip()
# Stack per-step scores into logits: (1, num_generated_tokens, vocab_size)
logits = torch.stack(gen_out.scores, dim=1) # scores: tuple of (batch, vocab)
return text, logits.cpu()
# -----------------------------------------------------------------------------
# Async wrappers
# -----------------------------------------------------------------------------
class AsyncModel:
def __init__(self, model):
"""Runs Model.invoke in a thread executor for async usage."""
def __init__(self, model: Model) -> None:
self.model = model
async def invoke(self, content, logprobs=False, top_logprobs=0):
loop = asyncio.get_event_loop()
if logprobs:
response, logprobs = await loop.run_in_executor(
async def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None,
lambda: self.model.invoke(content, logprobs=logprobs, top_logprobs=top_logprobs),
)
return response, logprobs
else:
response = await loop.run_in_executor(
None,
lambda: self.model.invoke(content),
)
return response
lambda: self.model.invoke(messages),
)
class AsyncModelPool:
def __init__(self):
self.models = []
self._available_models = None
self._model_semaphore = None
"""Round-robin pool of async models."""
def add_model(self, model):
def __init__(self) -> None:
self.models: List[Model] = []
self._queue: Optional[asyncio.Queue[AsyncModel]] = None
def add_model(self, model: Model) -> None:
self.models.append(model)
def init_models(self):
print(f"Initializing {len(self.models)} models...")
self._available_models = asyncio.Queue()
def init_models(self) -> None:
print(f"Initializing {len(self.models)} model(s)...")
self._queue = asyncio.Queue()
for model in self.models:
async_model = AsyncModel(model)
self._available_models.put_nowait(async_model)
self._model_semaphore = asyncio.Semaphore(len(self.models))
self._queue.put_nowait(AsyncModel(model))
async def invoke(self, content, logprobs=False, top_logprobs=0):
if self._available_models is None:
raise RuntimeError("Model pool not initialized. Call init_models() first.")
async_model = await self._available_models.get()
async def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
"""Acquire a model, invoke, release. Returns (text, logits)."""
if self._queue is None:
raise RuntimeError("Call init_models() first.")
async_model = await self._queue.get()
try:
if logprobs:
response, logprobs = await async_model.invoke(content, logprobs=logprobs, top_logprobs=top_logprobs)
return response, logprobs
else:
response = await async_model.invoke(content)
return response
return await async_model.invoke(messages)
finally:
self._available_models.put_nowait(async_model)
self._queue.put_nowait(async_model)
# -----------------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------------
def load_models(
models: List[str],
temperature: float = 0.0,
max_new_tokens: int = 1024,
enable_thinking: bool = False,
) -> AsyncModelPool:
"""Create and initialize a model pool from a list of model paths."""
pool = AsyncModelPool()
for path in models:
pool.add_model(
Model(
path,
temperature=temperature,
max_new_tokens=max_new_tokens,
enable_thinking=enable_thinking,
)
)
pool.init_models()
return pool
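
The acquire/invoke/release pattern used by the pool can be exercised without loading any model. This is a minimal sketch under stated assumptions: `EchoModel` and `Pool` are hypothetical stand-ins, not classes from this repository, and the blocking call is a trivial string format instead of a generation pass.

```python
import asyncio

class EchoModel:
    """Hypothetical blocking model that echoes the last message."""
    def __init__(self, name):
        self.name = name

    def invoke(self, messages):
        return f"{self.name}: {messages[-1]['content']}"

class Pool:
    """Hands out one free model per call and returns it afterwards."""
    def __init__(self, models):
        self._queue = asyncio.Queue()
        for model in models:
            self._queue.put_nowait(model)

    async def invoke(self, messages):
        model = await self._queue.get()  # waits while every model is busy
        try:
            loop = asyncio.get_running_loop()
            # Run the blocking call in the default thread executor.
            return await loop.run_in_executor(None, lambda: model.invoke(messages))
        finally:
            self._queue.put_nowait(model)  # release even if invoke raised

async def main():
    pool = Pool([EchoModel("a"), EchoModel("b")])
    prompts = [[{"role": "user", "content": f"q{i}"}] for i in range(4)]
    return await asyncio.gather(*(pool.invoke(p) for p in prompts))

replies = asyncio.run(main())
```

With two models and four prompts, at most two calls run at once; `asyncio.gather` still returns the replies in prompt order.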

487
sc/core/recruiter_agent.py Normal file

@@ -0,0 +1,487 @@
"""
RecruiterAgent: MAB-based example set selection for personalized ICL.
Components:
- self_certainty(logits): KL-div of LLM output from uniform (reward signal)
- borda_vote(results): aggregate answers across example sets
- FeatureIndex: loads index.json + embeddings.npy for O(1) lookup
- MABSelector: contextual bandit adapted from CASE/sample_case.py top_m_arm
- ArmPool: manages candidate example sets and feature vectors
- RecruiterAgent: orchestrates everything
Feature file schema:
<feature_root>/index.json + <feature_root>/embeddings.npy
index.json:
{
"version": "1.0",
"embedding_type": "sbert_metadata",
"embedding_dim": 384,
"samples": [
{"row": 0, "user_id": "001", "sample_id": 42, "label": "Hypertension", "session": "1"},
...
]
}
embeddings.npy: float32 array shape (N_samples, embedding_dim)
Switch embedding type by pointing to a different feature_root directory.
"""
from __future__ import annotations
import collections, json, math, os, random
from random import sample as random_sample
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
try:
import torch
import torch.nn.functional as F
_TORCH = True
except ImportError:
_TORCH = False
# ── Utilities ──────────────────────────────────────────────────────────────────
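def self_certainty(logits) -> float:
"""
Mean KL(U || p) between the uniform distribution U and the next-token distribution p,
averaged over generated tokens. It is 0 when p is uniform and grows as p becomes more peaked,
so a larger value means a more confident model.
NOTE: callers in this module treat a lower score as better and use -score as the MAB reward;
check that convention against this definition.
"""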
if logits is None or not _TORCH:
return 0.0
logits = logits.squeeze(0).float()
log_probs = F.log_softmax(logits, dim=-1)
score = (-1.0 / logits.shape[-1]) * log_probs.sum(dim=-1) - math.log(logits.shape[-1])
return score.mean().item()
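
The quantity computed by `self_certainty` can be checked numerically with NumPy alone. This is a small sketch, assuming logits of shape (T, V); `kl_uniform_to_softmax` is a hypothetical name and the two logit rows are made-up inputs.

```python
import numpy as np

def kl_uniform_to_softmax(logits):
    """Per-position KL(U || softmax(logits)) for an array of shape (T, V)."""
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable log-softmax.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    vocab = logits.shape[-1]
    # KL(U || p) = -(1/V) * sum_i log p_i - log V
    return -log_probs.mean(axis=-1) - np.log(vocab)

flat = kl_uniform_to_softmax([[0.0, 0.0, 0.0, 0.0]])
peaked = kl_uniform_to_softmax([[10.0, 0.0, 0.0, 0.0]])
print(flat[0], peaked[0])
```

For equal logits the value is zero up to floating-point error, and it increases as one logit dominates the others.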
def borda_vote(
results: List[Tuple[int, Optional[Dict[str, Any]], float]],
valid_classes: Optional[List[str]] = None,
) -> Optional[str]:
"""
Borda-count vote over (arm_idx, response_dict, certainty_score).
Lower certainty_score = more certain = higher Borda rank.
"""
valid = [
(r[1]["ANSWER"], r[2]) for r in results
if r[1] is not None and "ANSWER" in r[1]
and (valid_classes is None or r[1]["ANSWER"] in valid_classes)
]
if not valid:
return None
sorted_v = sorted(valid, key=lambda x: x[1])
M = len(sorted_v)
borda: Dict[str, float] = {}
for rank, (ans, _) in enumerate(sorted_v):
borda[ans] = borda.get(ans, 0.0) + (M - rank)
top = max(borda.values())
cands = [a for a, s in borda.items() if s == top]
if len(cands) == 1:
return cands[0]
means = {a: float(np.mean([sc for x, sc in valid if x == a])) for a in cands}
return min(means, key=means.get)
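
A worked example of the Borda tally may help. The sketch below reproduces only the ranking and point assignment, not the validity filter or the tie-break; `borda_scores` is a hypothetical helper, and the sleep-stage labels and scores are invented for illustration.

```python
def borda_scores(answers_with_scores):
    """Give M points to the lowest-scoring answer, M-1 to the next, and so on."""
    ordered = sorted(answers_with_scores, key=lambda pair: pair[1])
    m = len(ordered)
    totals = {}
    for rank, (answer, _) in enumerate(ordered):
        totals[answer] = totals.get(answer, 0.0) + (m - rank)
    return totals

votes = [("N2", 0.1), ("Wake", 0.3), ("N2", 0.4)]
totals = borda_scores(votes)
print(totals)                       # {'N2': 4.0, 'Wake': 2.0}
print(max(totals, key=totals.get))  # N2
```

Here "N2" collects 3 points for the best-ranked vote and 1 for the last-ranked vote, while "Wake" collects 2.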
# ── FeatureIndex ───────────────────────────────────────────────────────────────
class FeatureIndex:
"""Loads embeddings from disk; O(1) lookup by (user_id, sample_id)."""
def __init__(self, feature_root: str) -> None:
idx_p = os.path.join(feature_root, "index.json")
emb_p = os.path.join(feature_root, "embeddings.npy")
if not os.path.isfile(idx_p):
raise FileNotFoundError(f"index.json not found: {idx_p}")
if not os.path.isfile(emb_p):
raise FileNotFoundError(f"embeddings.npy not found: {emb_p}")
with open(idx_p, "r", encoding="utf-8") as f:
meta = json.load(f)
self.embedding_type: str = meta.get("embedding_type", "unknown")
self.embedding_dim: int = meta["embedding_dim"]
self.embeddings: np.ndarray = np.load(emb_p).astype(np.float32)
self.samples: List[Dict[str, Any]] = meta["samples"]
self._lookup: Dict[Tuple[str, int], int] = {
(str(e["user_id"]), int(e["sample_id"])): int(e["row"])
for e in self.samples
}
print(f"[FeatureIndex] {len(self.embeddings)} embeddings "
f"(dim={self.embedding_dim}, type={self.embedding_type})")
def get(self, user_id: str, sample_id: int) -> Optional[np.ndarray]:
row = self._lookup.get((str(user_id), int(sample_id)))
return self.embeddings[row] if row is not None else None
def get_user_embedding(self, user_id: str) -> Optional[np.ndarray]:
"""Mean embedding across all samples for a user."""
rows = [self.embeddings[e["row"]] for e in self.samples
if str(e["user_id"]) == str(user_id)]
return np.mean(rows, axis=0).astype(np.float32) if rows else None
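
A feature root matching the schema in the module docstring could be produced roughly as follows. This is a sketch under stated assumptions: `write_feature_root` is a hypothetical helper, the zero embeddings are placeholders, and the "Normal" label is made up.

```python
import json
import os
import tempfile

import numpy as np

def write_feature_root(root, samples, embeddings, embedding_type="sbert_metadata"):
    """Write index.json and embeddings.npy so that row i describes samples[i]."""
    embeddings = np.asarray(embeddings, dtype=np.float32)
    index = {
        "version": "1.0",
        "embedding_type": embedding_type,
        "embedding_dim": int(embeddings.shape[1]),
        "samples": [dict(s, row=i) for i, s in enumerate(samples)],
    }
    os.makedirs(root, exist_ok=True)
    with open(os.path.join(root, "index.json"), "w", encoding="utf-8") as f:
        json.dump(index, f)
    np.save(os.path.join(root, "embeddings.npy"), embeddings)

samples = [
    {"user_id": "001", "sample_id": 42, "label": "Hypertension", "session": "1"},
    {"user_id": "002", "sample_id": 7, "label": "Normal", "session": "1"},
]
with tempfile.TemporaryDirectory() as root:
    write_feature_root(root, samples, np.zeros((2, 384)))
    with open(os.path.join(root, "index.json"), encoding="utf-8") as f:
        meta = json.load(f)
    emb = np.load(os.path.join(root, "embeddings.npy"))
```

Pointing `FeatureIndex` at a directory written this way should satisfy both of its file-existence checks.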
# ── MABSelector ────────────────────────────────────────────────────────────────
class MABSelector:
"""
Contextual bandit for top-m arm identification.
Adapted from CASE / sample_case.py top_m_arm.
Key differences from original:
- Rewards injected via observe(arm, reward) — no LLM calls, no file IO
- update() does one rank-1 theta update per call
- No initialization() warm-up loop required
X: np.ndarray shape (embedding_dim, num_arms)
m: number of top arms to track (= queue_size)
"""
def __init__(
self, X: np.ndarray, m: int,
n_das: int = 5, delta: float = 0.05,
epsilon: float = 0.11, sigma: float = 0.5,
) -> None:
assert X.ndim == 2, "X must be (embedding_dim, num_arms)"
self.X = X.astype(np.float64)
self.N, self.K = X.shape
self.m = min(m, self.K)
self.n_das = n_das
self.delta = delta
self.epsilon = epsilon
self.sigma = sigma
self.arms = list(range(self.K))
self._it = -1
self._Bdi: Dict[Tuple[int, int], float] = {}
self._reset()
def _reset(self) -> None:
self.t = 0
self.rewards: List[float] = []
self.pulled_arms: List[int] = []
self.means = np.zeros(self.K)
self.na = np.zeros(self.K)
self.B_inv = np.eye(self.N, dtype=np.float64)
self.b = np.zeros(self.N, dtype=np.float64)
self.theta = np.random.normal(0, 1, size=(1, self.N))
self.J: List[int] = []
self.notJ: List[int] = list(range(self.K))
self.N_t: List[int] = []
self.best_arm: Optional[int] = None
self.worst_arm: Optional[int] = None
self.challenger: Optional[int] = None
self.c_t: Optional[int] = None
self.patience = 0
self.previous_J: List[int] = []
# ── math helpers ────────────────────────────────────────────────────────────
def _mn(self, x: np.ndarray, A: np.ndarray) -> float:
x = np.asarray(x).flatten()
return float(np.sqrt(max(float(x @ A @ x), 0.0)))
def _iu(self, A: np.ndarray, x: np.ndarray) -> np.ndarray:
"""Rank-1 inverse update: A - A x xT A / (1 + xT A x)."""
x = np.asarray(x).flatten()
d = 1.0 + self._mn(x, A) ** 2
return A - np.outer(A @ x, x @ A) / d
def _beta(self) -> float:
return math.log((math.log(max(self.t, 2)) + 1) / max(self.delta, 1e-12))
def _gap(self, i: int, j: Optional[int] = None) -> float:
xi = self.X[:, i] - (self.X[:, j] if j is not None else 0.0)
return float(self.theta.flatten() @ xi)
def _var(self, i: int, j: Optional[int] = None) -> float:
S = (self.sigma ** 2) * self.B_inv
xi = self.X[:, i] - (self.X[:, j] if j is not None else 0.0)
return self._mn(xi, S) * math.sqrt(2 * self._beta())
def _Bij(self, i: int, j: int) -> float:
t = max(self.t, 1)
if t != self._it:
self._Bdi = {}
self._it = t
k = (i, j)
if k not in self._Bdi:
self._Bdi[k] = self._gap(i, j) + self._var(i, j)
return self._Bdi[k]
def _randf(self, x: List[float], f) -> int:
arr = np.array(x, dtype=float)
val = f(arr)
c = np.argwhere(arr == val).flatten().tolist()
return random_sample(c, 1)[0]
def _mmax(self, x: List[float], m: int) -> List[int]:
x = list(x)
ids = []
for _ in range(min(m, len(x))):
idx = self._randf(x, np.max)
ids.append(idx)
x[idx] = -float("inf")
return ids
# ── sample step ─────────────────────────────────────────────────────────────
def _sel(self, nt: Optional[int]) -> Optional[int]:
if nt is None or self.c_t is None:
return None
if self.means[self.c_t] > self.means[nt]:
s = self.c_t
if s in self.N_t:
self.N_t.remove(s)
return s
return None
def _Jt(self) -> List[int]:
if self.t == 0:
return self._mmax(self.means.tolist(), self.m)
nt = self.worst_arm
sel = self._sel(nt)
self.previous_J = list(self.J)
if sel is not None:
self.J = [a for a in self.J if a != nt] + [sel]
if nt is not None and nt not in self.N_t:
self.N_t.append(nt)
order = np.argsort(self.means[self.J])[::-1].tolist()
self.J = [self.J[i] for i in order]
self.patience = 0 if self.J != self.previous_J else self.patience + 1
return self.J
def sample(self) -> List[int]:
"""Select next arm(s) to evaluate. Returns list of arm indices."""
self.J = self._Jt()
self.notJ = [a for a in self.arms if a not in self.J]
if self.t == 0 or not self.N_t:
n = min(self.n_das, len(self.notJ))
self.N_t = list(np.random.choice(self.notJ, n, replace=False)) if self.notJ else []
else:
Qt = list(np.random.choice(self.notJ, min(self.n_das, len(self.notJ)), replace=False)) if self.notJ else []
self.N_t = list({*self.N_t, *Qt} - set(self.J))
if self.N_t:
top_n = min(self.n_das, len(self.N_t))
tv = np.sort(self.means[self.N_t])[::-1][:top_n]
picked: List[int] = []
for mv in tv:
c = [a for a in self.N_t if self.means[a] == mv and a not in picked]
if c:
picked.append(random.choice(c))
self.N_t = picked
if self.N_t and self.J:
jm = self.means[self.J]
ntc = [a for a in self.J if self.means[a] == jm.min()]
self.worst_arm = random.choice(ntc)
bti = [self._Bij(a, self.worst_arm) for a in self.J]
self.best_arm = self.J[self._randf(bti, np.max)]
chi = [self._Bij(a, self.best_arm) for a in self.N_t]
self.challenger = self.N_t[self._randf(chi, np.max)]
nm = self.means[self.N_t]
c = [a for a in self.N_t if self.means[a] == nm.max()]
self.c_t = random.choice(c)
else:
self.worst_arm = self.J[0] if self.J else None
self.best_arm = self.challenger = self.c_t = None
# Greedy arm pull
if self.best_arm is not None and self.challenger is not None:
d = self.X[:, self.best_arm] - self.X[:, self.challenger]
u = [self._mn(d, self._iu(self.B_inv, self.X[:, i])) for i in self.arms]
return [self.arms[self._randf(u, np.min)]]
return [random.choice(self.arms)]
def observe(self, arm: int, reward: float) -> None:
"""Record reward for arm."""
self.rewards.append(reward)
self.pulled_arms.append(arm)
self.na[arm] += 1
self.t += 1
def update(self) -> None:
"""Rank-1 theta/B_inv update from the last observation."""
if not self.pulled_arms:
return
x = self.X[:, self.pulled_arms[-1]].flatten()
self.B_inv = self._iu(self.B_inv, x)
self.b += self.rewards[-1] * x
self.theta = (self.B_inv @ self.b).reshape(1, self.N)
self.means = (self.theta @ self.X).flatten()
def recommend(self, n: int) -> List[int]:
"""Return n arm indices with highest estimated mean rewards."""
return self._mmax(self.means.tolist(), min(n, self.K))
def stopping_rule(self) -> bool:
if None in (self.best_arm, self.worst_arm, self.challenger):
return False
return round(self._Bij(self.challenger, self.best_arm), 2) <= self.epsilon
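
The rank-1 inverse update in `_iu` is the Sherman-Morrison formula, and `update()` uses it to keep a ridge-regression estimate of theta. The sketch below checks both facts on synthetic data; `rank1_inverse_update`, `true_theta`, and the noiseless rewards are assumptions made for illustration.

```python
import numpy as np

def rank1_inverse_update(A_inv, x):
    """Sherman-Morrison: inverse of (A + x x^T) given the inverse of A."""
    x = np.asarray(x, dtype=np.float64).ravel()
    denom = 1.0 + x @ A_inv @ x
    return A_inv - np.outer(A_inv @ x, x @ A_inv) / denom

rng = np.random.default_rng(0)
dim = 4
B = np.eye(dim)        # explicit design matrix, kept only for the check
B_inv = np.eye(dim)    # maintained incrementally, as in the selector
b = np.zeros(dim)
true_theta = np.array([1.0, -2.0, 0.5, 0.0])

for _ in range(200):
    x = rng.normal(size=dim)
    reward = float(true_theta @ x)  # noiseless reward for the sketch
    B += np.outer(x, x)
    B_inv = rank1_inverse_update(B_inv, x)
    b += reward * x

theta_hat = B_inv @ b
print(np.allclose(B_inv, np.linalg.inv(B), atol=1e-10))
print(theta_hat.round(2))
```

Each update costs O(d^2) instead of the O(d^3) of a fresh matrix inversion, which is presumably why the selector maintains `B_inv` directly.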
# ── ArmPool ────────────────────────────────────────────────────────────────────
# TODO: Add a way to sample arms from the larger pool with replacement, instead of fixing the pool to num_arms arms.
class ArmPool:
"""
Manages candidate example sets (arms).
Each arm = list of sample dicts (N-way x k_shot).
Arm feature = mean embedding of its constituent samples.
"""
def __init__(
self,
source_samples: List[Dict[str, Any]],
classes: List[str],
k_shot: int,
feature_index: Optional[FeatureIndex],
num_arms: int,
rng: np.random.Generator,
) -> None:
self.classes = classes
self.k_shot = k_shot
self.feature_index = feature_index
self._rng = rng
self._by: Dict[str, List[Dict[str, Any]]] = {c: [] for c in classes}
for s in source_samples:
lbl = s.get("label")
if lbl in self._by:
self._by[lbl].append(s)
self.arms: List[List[Dict[str, Any]]] = []
self.arm_feats: List[Optional[np.ndarray]] = []
for _ in range(num_arms):
arm, feat = self._make()
self.arms.append(arm)
self.arm_feats.append(feat)
print(f"[ArmPool] {len(self.arms)} arms ({len(classes)}-way {k_shot}-shot)")
def _make(self) -> Tuple[List[Dict[str, Any]], Optional[np.ndarray]]:
s: List[Dict[str, Any]] = []
for c in self.classes:
pool = self._by.get(c, [])
if pool:
n = min(self.k_shot, len(pool))
idxs = self._rng.choice(len(pool), size=n, replace=False)
s.extend([pool[i] for i in idxs])
return s, self._feat(s)
def _feat(self, samples: List[Dict[str, Any]]) -> Optional[np.ndarray]:
if self.feature_index is None:
return None
vecs = []
for s in samples:
v = self.feature_index.get(
str(s.get("user_id", "")),
int(s.get("sample_id", s.get("idx", -1))),
)
if v is not None:
vecs.append(v)
return np.mean(vecs, axis=0).astype(np.float32) if vecs else None
def get_arm(self, i: int) -> List[Dict[str, Any]]:
return self.arms[i]
def build_X(self) -> np.ndarray:
"""Feature matrix X of shape (embedding_dim, num_arms). Falls back to identity."""
feats = [f for f in self.arm_feats if f is not None]
if not feats:
return np.eye(len(self.arms), dtype=np.float32)
X = np.zeros((feats[0].shape[0], len(self.arms)), dtype=np.float32)
for i, f in enumerate(self.arm_feats):
if f is not None:
X[:, i] = f
return X
# ── RecruiterAgent ────────────────────────────────────────────────────────────
class RecruiterAgent:
"""
MAB-driven recruiter for personalized ICL example set selection.
Usage:
recruiter = RecruiterAgent(source_samples, classes, ...)
initial = recruiter.recruit(queue_size) # cold start (random)
for test_sample in target_stream:
results = [(arm_idx, response_dict, score), ...]
recruiter.update(results) # MAB update lives here
new_sets = recruiter.recruit(L) # MAB-guided new sets
queue.update_with_recruiter(results, new_sets)
MAB update logic inside update():
1. reward = -score (negate: lower certainty = better = higher reward)
2. mab.observe(arm, reward)
3. mab.update() -> rank-1 theta/B_inv update
4. mab.sample() -> select next arms to explore (stored for next recruit())
"""
def __init__(
self,
source_samples: List[Dict[str, Any]],
classes: List[str],
feature_index: Optional[FeatureIndex] = None,
target_user_id: Optional[str] = None,
k_shot: int = 1,
queue_size: int = 5,
num_arms: int = 50,
n_das: int = 5,
delta: float = 0.05,
epsilon: float = 0.11,
sigma: float = 0.5,
seed: int = 42,
) -> None:
self.classes = classes
self.k_shot = k_shot
self.queue_size = queue_size
self.target_user_id = target_user_id
rng = np.random.default_rng(seed)
random.seed(seed)
np.random.seed(seed)
self.pool = ArmPool(source_samples, classes, k_shot, feature_index, num_arms, rng)
self.mab = MABSelector(
self.pool.build_X(), m=queue_size,
n_das=n_das, delta=delta, epsilon=epsilon, sigma=sigma,
)
self._next: Optional[List[int]] = None
self._initialized = False
print(f"[RecruiterAgent] {len(self.pool.arms)} arms, "
f"queue={queue_size}, classes={list(set(classes))}")
def recruit(self, n: int) -> List[Tuple[int, List[Dict[str, Any]]]]:
"""
Return n (arm_idx, example_set) pairs for insertion into the queue.
Cold start: random. Subsequent calls: MAB-recommended.
"""
if not self._initialized:
indices = random.sample(range(len(self.pool.arms)), min(n, len(self.pool.arms)))
self._initialized = True
elif self._next is not None:
indices = self._next[:n]
if len(indices) < n:
rest = [i for i in range(len(self.pool.arms)) if i not in indices]
indices += random.sample(rest, min(n - len(indices), len(rest)))
else:
indices = self.mab.recommend(n)
return [(i, self.pool.get_arm(i)) for i in indices]
def update(self, results: List[Tuple[int, Optional[Dict[str, Any]], float]]) -> None:
"""
Update MAB from (arm_idx, response_dict, self_certainty_score).
Steps:
1. reward = -score (lower certainty = better = higher reward)
2. mab.observe(arm, reward)
3. mab.update() -> rank-1 theta/B_inv update
4. mab.sample() -> selects next arms to explore
"""
if not results:
return
for arm_idx, _, score in results:
if 0 <= arm_idx < len(self.pool.arms):
self.mab.observe(arm_idx, -float(score))
self.mab.update()
try:
next_cands = self.mab.sample()
top = self.mab.recommend(self.queue_size)
self._next = list(dict.fromkeys(next_cands + top))
except Exception as e:
print(f"[RecruiterAgent] MAB sample error ({e}), using recommend()")
self._next = self.mab.recommend(self.queue_size)
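The docstrings above describe `mab.update()` as a "rank-1 theta/B_inv update". The internals of `MABSelector` are not visible in this part of the diff, so the following is only a minimal sketch of what such an update could look like, assuming a ridge-regularized linear reward model maintained with the Sherman-Morrison identity. The class name `RankOneLinearModel` and the `reg` parameter are hypothetical and are not part of this PR.

```python
import numpy as np


class RankOneLinearModel:
    """Hypothetical stand-in for a linear reward model (not the PR's MABSelector)."""

    def __init__(self, dim: int, reg: float = 1.0) -> None:
        self.B_inv = np.eye(dim) / reg   # inverse of B = reg * I
        self.b = np.zeros(dim)           # running sum of reward * features
        self.theta = np.zeros(dim)       # current estimate, theta = B_inv @ b

    def update(self, x: np.ndarray, reward: float) -> None:
        # Sherman-Morrison: (B + x x^T)^-1 = B^-1 - (B^-1 x)(B^-1 x)^T / (1 + x^T B^-1 x)
        Bx = self.B_inv @ x
        self.B_inv -= np.outer(Bx, Bx) / (1.0 + x @ Bx)
        self.b += reward * x
        self.theta = self.B_inv @ self.b


rng = np.random.default_rng(0)
model = RankOneLinearModel(dim=3)
xs = rng.normal(size=(20, 3))
scores = rng.uniform(0.0, 1.0, size=20)   # made-up self-certainty scores
for x, score in zip(xs, scores):
    model.update(x, -float(score))        # reward = -score, as in RecruiterAgent.update()

# Cross-check against the direct (non-incremental) ridge solution.
B = np.eye(3) + xs.T @ xs
theta_direct = np.linalg.solve(B, xs.T @ (-scores))
```

Each observation costs O(d^2) instead of a fresh matrix inversion, which is presumably why a rank-1 form is used when the queue is updated after every test sample.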

20
sc/logger.py Normal file

@@ -0,0 +1,20 @@
import os
import yaml
from datetime import datetime
class Logger:
    def __init__(self, log_path: str):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # log_path is a timestamped run directory; messages are appended to a
        # file inside it (the name "log.txt" is an assumed choice).
        self.log_path = f"{log_path}_{timestamp}"
        os.makedirs(self.log_path, exist_ok=True)
        self.log_file = os.path.join(self.log_path, "log.txt")

    def log(self, message: str):
        print(message)
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(message + "\n")

    def log_config(self, config: dict):
        dumped = yaml.dump(config, default_flow_style=False, sort_keys=False)
        msg = "Config:\n" + dumped
        self.log(msg)


@@ -1,31 +1,149 @@
"""
HuggingFace Causal LM wrapper for inference with logits.
Loads a model via transformers, generates text, and returns both
the generated text and full logits via a forward pass.
Message format: [{"role": "user"|"assistant"|"system", "content": str}, ...]
"""
from __future__ import annotations
import asyncio
from typing import Dict, List, Optional, Tuple
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
model_dir = "/mnt/sting/hjyoon/projects/llm/huggingface/gptoss20b"
device = "cuda" if torch.cuda.is_available() else "cpu"
# {"role": "user"|"assistant"|"system", "content": str}
ChatMessage = Dict[str, str]
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_dir,
torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
device_map="auto" if device == "cuda" else None,
).eval()
# -----------------------------------------------------------------------------
# Model
# -----------------------------------------------------------------------------
class Model:
"""HuggingFace causal LM. Returns generated text + full logits."""
def __init__(self, model_path: str, temperature: float = 0.0, max_new_tokens: int = 1024) -> None:
self._tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
self._model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="auto",
).eval()
self.temperature = temperature
self.max_new_tokens = max_new_tokens
def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
"""
Generate a reply and return (text, logits).
Returns:
text: Generated text string.
logits: Full logits tensor (1, seq_len, vocab_size) on CPU.
seq_len covers prompt + generated tokens.
"""
# Build prompt from chat messages
prompt = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True,
)
inputs = self._tokenizer(prompt, return_tensors="pt").to(self._model.device)
prompt_len = inputs["input_ids"].shape[1]
# Generate
with torch.no_grad():
gen_out = self._model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
do_sample=self.temperature > 0,
temperature=self.temperature if self.temperature > 0 else 1.0,
)
# Decode generated tokens
gen_ids = gen_out[0][prompt_len:]
text = self._tokenizer.decode(gen_ids, skip_special_tokens=True)
# Forward pass on full sequence to get logits
with torch.no_grad():
logits = self._model(gen_out).logits # (1, seq_len, vocab_size)
return text, logits.cpu()
# -----------------------------------------------------------------------------
# Async wrappers
# -----------------------------------------------------------------------------
class AsyncModel:
"""Runs Model.invoke in a thread executor for async usage."""
def __init__(self, model: Model) -> None:
self.model = model
async def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None, lambda: self.model.invoke(messages),
)
class AsyncModelPool:
"""Round-robin pool of async models."""
def __init__(self) -> None:
self.models: List[Model] = []
self._queue: Optional[asyncio.Queue[AsyncModel]] = None
def add_model(self, model: Model) -> None:
self.models.append(model)
def init_models(self) -> None:
print(f"Initializing {len(self.models)} model(s)...")
self._queue = asyncio.Queue()
for model in self.models:
self._queue.put_nowait(AsyncModel(model))
async def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
"""Acquire a model, invoke, release. Returns (text, logits)."""
if self._queue is None:
raise RuntimeError("Call init_models() first.")
async_model = await self._queue.get()
try:
return await async_model.invoke(messages)
finally:
self._queue.put_nowait(async_model)
# -----------------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------------
def load_models(
models: List[str],
temperature: float = 0.0,
max_new_tokens: int = 1024,
) -> AsyncModelPool:
"""Create and initialize a model pool from a list of model paths."""
pool = AsyncModelPool()
for path in models:
pool.add_model(
Model(path, temperature=temperature, max_new_tokens=max_new_tokens)
)
pool.init_models()
return pool
# test
model = Model("/mnt/sting/hjyoon/projects/llm/huggingface/gptoss20b")
messages = [
{"role": "user", "content": "Explain quantum mechanics clearly and concisely."},
]
# Convert chat messages -> a single prompt string using the model's chat template
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
out = model(**inputs)
logits = out.logits # shape: (batch=1, seq_len, vocab_size)
print("logits shape:", logits.shape)
text, logits = model.invoke(messages)
print(text)
print(logits.shape)
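`AsyncModelPool.invoke` checks a model out of an `asyncio.Queue`, runs the blocking call in a thread executor, and returns the model in a `finally` block. The sketch below isolates that acquire/invoke/release pattern with a dummy model so it can run without a GPU. `EchoModel` and `TinyPool` are hypothetical names used only for illustration.

```python
import asyncio
from typing import List


class EchoModel:
    """Hypothetical stand-in for Model: blocking invoke() that tags its reply."""

    def __init__(self, name: str) -> None:
        self.name = name

    def invoke(self, prompt: str) -> str:
        return f"{self.name}:{prompt}"


class TinyPool:
    """Checkout pool: each caller holds one model exclusively while invoking it."""

    def __init__(self, models: List[EchoModel]) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()
        for m in models:
            self._queue.put_nowait(m)

    async def invoke(self, prompt: str) -> str:
        model = await self._queue.get()          # wait until a model is free
        try:
            loop = asyncio.get_running_loop()
            # Run the blocking call in the default thread executor.
            return await loop.run_in_executor(None, model.invoke, prompt)
        finally:
            self._queue.put_nowait(model)        # always hand the model back


async def demo() -> List[str]:
    pool = TinyPool([EchoModel("m0"), EchoModel("m1")])
    return await asyncio.gather(*(pool.invoke(f"q{i}") for i in range(4)))


replies = asyncio.run(demo())
```

With two models and four concurrent requests, at most two calls are in flight at a time. `asyncio.gather` returns the replies in request order regardless of which model served each one.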


@@ -0,0 +1,99 @@
"""
Build unified feature index from PPGBP per-subject .npy embeddings.
Output: <output_dir>/index.json + <output_dir>/embeddings.npy
Usage:
python -m sc.preprocess.build_feature_index --embedding_root $ROOT --output_dir ./features/sbert_metadata
"""
from __future__ import annotations
import argparse
import json
import os
from typing import Any, Dict, List, Optional
import numpy as np
def _load_labels(base_dir: Optional[str]) -> Dict[int, str]:
"""Load subject_id -> label from xlsx. Prefers Hypertension column (PPGBP), else class/Class."""
if not base_dir:
return {}
try:
import pandas as pd
except ImportError:
return {}
xlsx = None
for f in os.listdir(base_dir):
if f.endswith(".xlsx"):
xlsx = os.path.join(base_dir, f)
break
if not xlsx or not os.path.isfile(xlsx):
return {}
df = pd.read_excel(xlsx, header=1)
if "subject_ID" not in df.columns:
return {}
label_col = None
for c in ("Hypertension", "class", "Class"):
if c in df.columns:
label_col = c
break
out = {}
for _, row in df.iterrows():
sid = int(row["subject_ID"])
out[sid] = str(row.get(label_col, "unknown")).strip() if label_col else "unknown"
return out
def build_index(
embedding_root: str,
output_dir: str,
embedding_type: str = "sbert_metadata",
base_dir: Optional[str] = None,
) -> None:
if not os.path.isdir(embedding_root):
raise FileNotFoundError(embedding_root)
subject_ids = []
for f in sorted(os.listdir(embedding_root)):
if not f.endswith(".npy"):
continue
try:
subject_ids.append(int(f.replace(".npy", "")))
except ValueError:
continue
subject_ids = sorted(subject_ids)
if not subject_ids:
raise ValueError("No .npy files in " + embedding_root)
labels = _load_labels(base_dir)
embeddings_list = []
samples: List[Dict[str, Any]] = []
for row, sid in enumerate(subject_ids):
emb = np.load(os.path.join(embedding_root, f"{sid}.npy")).astype(np.float32)
embeddings_list.append(emb)
samples.append({
"row": row,
"user_id": str(sid),
"sample_id": sid,
"label": labels.get(sid, "unknown"),
"session": "1",
})
embeddings = np.stack(embeddings_list, axis=0)
dim = int(embeddings.shape[1])
meta = {"version": "1.0", "embedding_type": embedding_type, "embedding_dim": dim, "samples": samples}
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, "index.json"), "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
np.save(os.path.join(output_dir, "embeddings.npy"), embeddings)
print(f"[build_feature_index] Wrote {len(samples)} samples, dim={dim}")
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--embedding_root", required=True, help="Directory of per-subject .npy files")
ap.add_argument("--output_dir", required=True, help="Output dir for index.json and embeddings.npy")
ap.add_argument("--embedding_type", default="sbert_metadata")
ap.add_argument("--base_dir", default=None, help="Optional PPGBP data dir (xlsx) for labels")
a = ap.parse_args()
build_index(a.embedding_root, a.output_dir, a.embedding_type, a.base_dir)
if __name__ == "__main__":
main()
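`build_index` writes `index.json` (with one `samples` entry per embedding row) next to `embeddings.npy`. How `FeatureIndex` consumes these files is not shown in this diff, so the following is only a sketch of one plausible way to read the pair back and look up the most similar subjects by cosine similarity. The helper names `load_index` and `nearest_subjects` are hypothetical, and the embedding values are made up.

```python
import json
import os
import tempfile

import numpy as np


def load_index(index_dir: str):
    """Read index.json + embeddings.npy in the layout written by build_index()."""
    with open(os.path.join(index_dir, "index.json"), "r", encoding="utf-8") as f:
        meta = json.load(f)
    embeddings = np.load(os.path.join(index_dir, "embeddings.npy"))
    return meta, embeddings


def nearest_subjects(meta, embeddings, query_row: int, k: int = 2):
    """Return user_ids of the k rows most cosine-similar to query_row (excluding itself)."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)
    sims = unit @ unit[query_row]
    sims[query_row] = -np.inf                      # never return the query itself
    top = np.argsort(-sims)[:k]
    return [meta["samples"][int(r)]["user_id"] for r in top]


# Tiny synthetic index in the same layout (values are made up for illustration).
with tempfile.TemporaryDirectory() as d:
    emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]], dtype=np.float32)
    samples = [{"row": i, "user_id": str(10 + i), "sample_id": 10 + i,
                "label": "unknown", "session": "1"} for i in range(3)]
    with open(os.path.join(d, "index.json"), "w", encoding="utf-8") as f:
        json.dump({"version": "1.0", "embedding_type": "demo",
                   "embedding_dim": 2, "samples": samples}, f)
    np.save(os.path.join(d, "embeddings.npy"), emb)
    meta, loaded = load_index(d)
    neighbours = nearest_subjects(meta, loaded, query_row=0, k=2)
```

The `row` field in each sample is what ties a metadata entry to its row in the embedding matrix, so the two files have to be regenerated together.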


@@ -0,0 +1,134 @@
"""
Build PPGBP embeddings under PPGBP_EMBEDDINGS_ROOT with two top-level dirs:
metadata_embs/<emb_type>/<subject_id>.npy
sig_embs/<emb_type>/<subject_id>.npy
For now only metadata embeddings are produced: for each subject, take the xlsx row,
vectorize (one-hot for Sex and Hypertension, numeric for the rest) and save.
Usage:
python -m sc.preprocess.build_ppgbp_embeddings \\
--embeddings_root /path/to/ppgbp_embeddings \\
--data_dir /path/to/ppgbp_data \\
--metadata_emb_type onehot
Then build the feature index (for downstream RecruiterAgent etc.):
python -m sc.preprocess.build_feature_index \\
--embedding_root $PPGBP_EMBEDDINGS_ROOT/metadata_embs/onehot \\
--output_dir $DATA_INDEX_ROOT/metadata_onehot \\
--embedding_type metadata_onehot \\
--base_dir $PPGBP_DATA_DIR
"""
from __future__ import annotations
import argparse
import os
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
# Xlsx column names (header=1)
SUBJECT_ID_COL = "subject_ID"
SEX_COL = "Sex(M/F)"
AGE_COL = "Age(year)"
HEIGHT_COL = "Height(cm)"
WEIGHT_COL = "Weight(kg)"
SBP_COL = "Systolic Blood Pressure(mmHg)"
DBP_COL = "Diastolic Blood Pressure(mmHg)"
HR_COL = "Heart Rate(b/m)"
BMI_COL = "BMI(kg/m^2)"
HYP_COL = "Hypertension"
NUMERIC_COLS = [AGE_COL, HEIGHT_COL, WEIGHT_COL, SBP_COL, DBP_COL, HR_COL, BMI_COL]
CATEGORICAL_COLS = [SEX_COL, HYP_COL]
def _find_xlsx(data_dir: str) -> str:
for f in os.listdir(data_dir):
if f.endswith(".xlsx"):
return os.path.join(data_dir, f)
raise FileNotFoundError(f"No .xlsx in {data_dir}")
def _onehot_and_numeric(df: pd.DataFrame) -> Tuple[np.ndarray, List[str], Dict[str, List[str]]]:
"""
One-hot encode categoricals (Sex, Hypertension) and stack with numeric columns.
Returns (matrix, column_names, categories_used).
"""
out_cols: List[str] = []
pieces: List[np.ndarray] = []
categories_used: Dict[str, List[str]] = {}
for col in CATEGORICAL_COLS:
if col not in df.columns:
continue
vals = df[col].fillna("unknown").astype(str).str.strip()
uniq = sorted(vals.unique())
categories_used[col] = uniq
for u in uniq:
out_cols.append(f"{col}__{u}")
onehot = (vals.values.reshape(-1, 1) == np.array(uniq)).astype(np.float32)
pieces.append(onehot)
for col in NUMERIC_COLS:
if col not in df.columns:
continue
out_cols.append(col)
vec = df[col].fillna(0).astype(np.float32).values.reshape(-1, 1)
pieces.append(vec)
if not pieces:
raise ValueError("No columns found; check xlsx has expected column names")
X = np.hstack(pieces)
return X, out_cols, categories_used
def create_embedding_dirs(embeddings_root: str, metadata_emb_type: str, sig_emb_type: str = "placeholder") -> None:
"""Create metadata_embs/<emb_type>/ and sig_embs/<emb_type>/."""
os.makedirs(embeddings_root, exist_ok=True)
meta_dir = os.path.join(embeddings_root, "metadata_embs", metadata_emb_type)
sig_dir = os.path.join(embeddings_root, "sig_embs", sig_emb_type)
os.makedirs(meta_dir, exist_ok=True)
os.makedirs(sig_dir, exist_ok=True)
def build_metadata_embeddings_onehot(
data_dir: str,
embeddings_root: str,
metadata_emb_type: str = "onehot",
) -> None:
"""
Read xlsx from data_dir, vectorize each subject row (one-hot + numeric), save as
metadata_embs/<metadata_emb_type>/<subject_id>.npy.
"""
xlsx_path = _find_xlsx(data_dir)
df = pd.read_excel(xlsx_path, header=1)
if SUBJECT_ID_COL not in df.columns:
raise ValueError(f"xlsx must have column {SUBJECT_ID_COL}")
X, _cols, _cats = _onehot_and_numeric(df)
subject_ids = df[SUBJECT_ID_COL].astype(int).values
out_dir = os.path.join(embeddings_root, "metadata_embs", metadata_emb_type)
os.makedirs(out_dir, exist_ok=True)
for i, sid in enumerate(subject_ids):
path = os.path.join(out_dir, f"{int(sid)}.npy")
np.save(path, X[i].astype(np.float32))
print(f"[build_ppgbp_embeddings] Wrote {len(subject_ids)} metadata embeddings to {out_dir} (dim={X.shape[1]})")
def main() -> None:
ap = argparse.ArgumentParser(description="Build PPGBP metadata (and optionally signal) embeddings under embeddings_root.")
ap.add_argument("--embeddings_root", required=True, help="Root for metadata_embs/ and sig_embs/ (e.g. PPGBP_EMBEDDINGS_ROOT)")
ap.add_argument("--data_dir", required=True, help="PPGBP data dir containing the xlsx file")
ap.add_argument("--metadata_emb_type", default="onehot", help="Subdir name under metadata_embs/ (e.g. onehot)")
ap.add_argument("--sig_emb_type", default="placeholder", help="Subdir under sig_embs/ (created empty for now)")
args = ap.parse_args()
create_embedding_dirs(args.embeddings_root, args.metadata_emb_type, args.sig_emb_type)
build_metadata_embeddings_onehot(args.data_dir, args.embeddings_root, args.metadata_emb_type)
if __name__ == "__main__":
main()
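The core of `_onehot_and_numeric` is a broadcast comparison that turns a column of category strings into a one-hot matrix, which is then stacked with the numeric columns. The sketch below reproduces that step with plain NumPy on toy rows. The values are made up and only two of the spreadsheet's columns are used.

```python
import numpy as np

# Toy metadata rows (made-up values, not from the PPGBP spreadsheet).
sex = np.array(["F", "M", "F"])
age = np.array([45.0, 30.0, 61.0], dtype=np.float32)

categories = sorted(set(sex.tolist()))                 # ["F", "M"]
# Broadcasting (n, 1) == (c,) -> (n, c) boolean one-hot matrix.
onehot = (sex.reshape(-1, 1) == np.array(categories)).astype(np.float32)
X = np.hstack([onehot, age.reshape(-1, 1)])
columns = [f"Sex(M/F)__{c}" for c in categories] + ["Age(year)"]
```

Note that the numeric columns are stored unscaled, so any distance computed directly on these vectors will tend to be dominated by the columns with the largest ranges. Whether that matters depends on how the downstream index uses them.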


@@ -0,0 +1,76 @@
"""
One-shot: build PPGBP one-hot metadata embeddings and then package them into
the feature index at DATA_INDEX_ROOT for downstream use (RecruiterAgent, etc.).
Step 1: Create metadata_embs/onehot/ and sig_embs/ under PPGBP_EMBEDDINGS_ROOT,
write one-hot metadata vectors per subject.
Step 2: Run build_feature_index to produce index.json + embeddings.npy at
DATA_INDEX_ROOT (compatible with FeatureIndex and RecruiterAgent).
Usage:
export PPGBP_EMBEDDINGS_ROOT=/scratch/.../embeddings_full/ppgbp
export PPGBP_DATA_DIR=/path/to/ppgbp_data # dir with xlsx and 0_subject/
# or: export PPGBP_DATA_ROOT=/path/to/data/tsllm (same as data dir if xlsx lives there)
export DATA_INDEX_ROOT=/path/to/feature_index_output
python -m sc.preprocess.package_ppgbp_embeddings
Or with explicit args:
python -m sc.preprocess.package_ppgbp_embeddings \\
--embeddings_root $PPGBP_EMBEDDINGS_ROOT \\
--data_dir $PPGBP_DATA_DIR \\
--index_output_dir $DATA_INDEX_ROOT \\
--metadata_emb_type onehot \\
--embedding_type metadata_onehot
"""
from __future__ import annotations
import argparse
import os
import sys
def main() -> None:
ap = argparse.ArgumentParser(description="Build PPGBP metadata embeddings and package feature index.")
ap.add_argument("--embeddings_root", default=None, help="Default: PPGBP_EMBEDDINGS_ROOT env")
ap.add_argument("--data_dir", default=None, help="PPGBP data dir (xlsx + 0_subject/); default: PPGBP_DATA_DIR or PPGBP_DATA_ROOT env")
ap.add_argument("--index_output_dir", default=None, help="Where to write index.json + embeddings.npy; default: DATA_INDEX_ROOT env")
ap.add_argument("--metadata_emb_type", default="onehot")
ap.add_argument("--embedding_type", default="metadata_onehot", help="embedding_type string in index.json")
args = ap.parse_args()
embeddings_root = args.embeddings_root or os.environ.get("PPGBP_EMBEDDINGS_ROOT")
data_dir = args.data_dir or os.environ.get("PPGBP_DATA_DIR") or os.environ.get("PPGBP_DATA_ROOT") or embeddings_root
index_output_dir = args.index_output_dir or os.environ.get("DATA_INDEX_ROOT")
if not embeddings_root:
print("Set PPGBP_EMBEDDINGS_ROOT or pass --embeddings_root", file=sys.stderr)
sys.exit(1)
if not os.path.isdir(data_dir):
print(f"Data dir not found: {data_dir}", file=sys.stderr)
sys.exit(1)
if not index_output_dir:
print("Set DATA_INDEX_ROOT or pass --index_output_dir", file=sys.stderr)
sys.exit(1)
from sc.preprocess.build_ppgbp_embeddings import (
create_embedding_dirs,
build_metadata_embeddings_onehot,
)
create_embedding_dirs(embeddings_root, args.metadata_emb_type)
build_metadata_embeddings_onehot(data_dir, embeddings_root, args.metadata_emb_type)
embedding_root_for_index = os.path.join(embeddings_root, "metadata_embs", args.metadata_emb_type)
from sc.preprocess.build_feature_index import build_index
build_index(
embedding_root=embedding_root_for_index,
output_dir=index_output_dir,
embedding_type=args.embedding_type,
base_dir=data_dir,
)
print(f"[package_ppgbp_embeddings] Feature index written to {index_output_dir}")
if __name__ == "__main__":
main()
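The packaging script resolves each path from a CLI argument first and then from one or more environment variables. The same precedence can be factored into a small helper, sketched below. `resolve_setting` and the `DEMO_*` variable names are hypothetical and do not appear in the PR.

```python
import os
from typing import Optional


def resolve_setting(cli_value: Optional[str], *env_names: str,
                    default: Optional[str] = None) -> Optional[str]:
    """Return the CLI value if set, else the first non-empty env var, else default."""
    if cli_value:
        return cli_value
    for name in env_names:
        value = os.environ.get(name)
        if value:
            return value
    return default


os.environ["DEMO_DATA_DIR"] = "/tmp/demo_data"       # hypothetical variable for the demo
os.environ.pop("DEMO_DATA_ROOT", None)
from_cli = resolve_setting("/explicit", "DEMO_DATA_DIR")
from_env = resolve_setting(None, "DEMO_DATA_ROOT", "DEMO_DATA_DIR")
missing = resolve_setting(None, "DEMO_DATA_ROOT")
```

Returning `None` for a missing setting lets the caller decide whether to exit with an error, as the script does for `embeddings_root` and `index_output_dir`.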


@@ -30,7 +30,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
from sc.core.scagent import SCAgent
from sc.core.model import load_models
from sc.core.queue import Queue
from sc.core.example_queue import Queue
from sc.core.agent_pool import AgentPool
from sc import debug_log
@@ -425,7 +425,7 @@ def main(
model_pool = load_models(
config["models"],
temperature=config.get("temperature", 0.0),
num_ctx=config.get("num_ctx", 15000),
max_new_tokens=config.get("max_new_tokens", 1024),
)
# Run experiment


@@ -32,7 +32,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
from sc.core.scagent import SCAgent
from sc.core.model import load_models
from sc.core.queue import Queue
from sc.core.example_queue import Queue
from sc.core.agent_pool import AgentPool
from sc import debug_log
@@ -435,7 +435,7 @@ def main(
model_pool = load_models(
config["models"],
temperature=config.get("temperature", 0.0),
num_ctx=config.get("num_ctx", 15000),
max_new_tokens=config.get("max_new_tokens", 1024),
)
# Run experiment

390
sc/run_recruiter.py Normal file

@@ -0,0 +1,390 @@
"""
Recruiter-based Self-Consistency runner for PPGBP.
Uses RecruiterAgent (MAB) to select example sets, self_certainty(logits) as reward,
and borda_vote for aggregation. Requires PPGBPLoader and optional feature index.
Usage:
python -m sc.run_recruiter sc/config/ppgbp_recruiter.yaml
"""
from __future__ import annotations
import asyncio
import json
import os
import sys
from collections import deque
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import yaml
from fire import Fire
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.core.data_loader import PPGBPLoader, get_loaders
from sc.core.example_queue import Queue
from sc.core.model import Model, load_models
from sc.core.recruiter_agent import RecruiterAgent, borda_vote, self_certainty
from sc.core.recruiter_agent import FeatureIndex
def _ppgbp_sample_to_dict(metadata: Dict, signal, subject_id: int, label_key: str = "label") -> Dict[str, Any]:
    """Turn (metadata, signal) into a sample dict with user_id, sample_id, label, features."""
    # Look up the label under label_key first, then fall back to the
    # "hypertension", "Hypertension", "class", and "Class" keys.
    label_keys = [label_key, "hypertension", "Hypertension", "class", "Class"]
    label = "unknown"
    for key in label_keys:
        val = metadata.get(key)
        if val is not None and str(val) != "unknown":
            label = str(val)
            break
    # Exclude every label-bearing key from the features so the ground truth
    # cannot leak into the prompt built from them.
    excluded = {k.lower() for k in label_keys} | {"label"}
    features = {k: str(v) for k, v in metadata.items() if k.lower() not in excluded}
    return {
        "user_id": str(subject_id),
        "sample_id": subject_id,
        "idx": subject_id,
        "label": label,
        "features": features,
    }
def _build_prompt(system_msg: str, task_info: str, classes_info: List[str], test_sample: Dict, examples: List[Dict]) -> List[Dict[str, str]]:
"""Build chat messages for the model."""
parts = [f"**Task**: {task_info}\n\n**Classes**: {', '.join(classes_info)}\n\n"]
parts.append("**Examples**\n")
for ex in examples:
parts.append(f"*Example of {ex['label']}*:\n")
for k, v in (ex.get("features") or {}).items():
parts.append(f" - {k}: {v}\n")
parts.append("\n**Current sample**\n")
for k, v in (test_sample.get("features") or {}).items():
parts.append(f" - {k}: {v}\n")
parts.append(f"\nRespond in JSON: {{\"REASON\": \"...\", \"CONFIDENCE\": 0.0-1.0, \"ANSWER\": \"<class>\"}}")
content = "".join(parts)
return [
{"role": "system", "content": system_msg},
{"role": "user", "content": content},
]
def _parse_answer(text: str, classes: List[str]) -> Optional[str]:
"""Extract ANSWER from JSON in text."""
try:
start = text.find("{")
end = text.rfind("}") + 1
if start >= 0 and end > start:
obj = json.loads(text[start:end])
ans = (obj.get("ANSWER") or "").strip()
if ans in classes:
return ans
except Exception:
pass
return None
def make_gt_scorer(
num_arms: int,
best_arm_id: int = 20,
sigma: float = 5.0,
) -> Callable[[int], float]:
"""
Deterministic ground-truth scorer for checking MAB convergence.
The reward is a Gaussian bump over the arm index, peaking at best_arm_id:
reward(i) = exp(-(i - best_arm_id)^2 / (2 * sigma^2)).
The returned score is -reward, so a lower score is better. No randomness is involved.
"""
best_arm_id = max(0, min(best_arm_id, num_arms - 1))
sigma = max(1e-6, float(sigma))
def scorer(arm_id: int) -> float:
if arm_id < 0 or arm_id >= num_arms:
return 0.0
reward = np.exp(-((arm_id - best_arm_id) ** 2) / (2 * sigma**2))
return -float(reward)
return scorer
async def run_single_user(
config: Dict[str, Any],
model_pool: Optional[Any],
test_loader: Optional[PPGBPLoader],
train_loader: PPGBPLoader,
seed: int,
) -> List[Dict[str, Any]]:
"""Run recruiter pipeline. If model_pool is None, use ground-truth scorer (arm_id -> score) for MAB convergence."""
queue_size = config.get("queue_size", 5)
recruit_size = config.get("recruit_size", 2)
n_way = config.get("n_way", 2)
k_shot = config.get("k_shot", 1)
task_info = config.get("task_info", "Classify the subject.")
classes_info = config.get("classes_info", ["Normal", "Hypertension"])
system_msg = config.get("system_message", "You are a helpful assistant. Respond in JSON with REASON, CONFIDENCE, ANSWER.")
use_gt_only = model_pool is None
np.random.seed(seed)
train_loader._indices = np.arange(len(train_loader.subject_ids))
if train_loader.shuffle:
train_loader.rng.shuffle(train_loader._indices)
example_dataset = []
for idx in train_loader._indices:
sid = train_loader.subject_ids[idx]
meta = train_loader._get_metadata_dict(sid)
signal = train_loader._load_signal(sid)
example_dataset.append(_ppgbp_sample_to_dict(meta, signal, int(sid)))
if len(example_dataset) < n_way * k_shot:
print("[run_recruiter] Not enough train samples")
return []
classes = list({s["label"] for s in example_dataset})
if not classes:
classes = classes_info
class_indices = {c: [i for i, s in enumerate(example_dataset) if s["label"] == c] for c in classes}
for c in classes:
if not class_indices[c]:
class_indices[c] = [0]
dataset_idx = {(str(s["user_id"]), s["sample_id"]): i for i, s in enumerate(example_dataset)}
feature_index = None
if config.get("feature_root") and os.path.isdir(config["feature_root"]):
try:
feature_index = FeatureIndex(config["feature_root"])
except Exception as e:
print(f"[run_recruiter] Feature index load failed: {e}")
recruiter = RecruiterAgent(
source_samples=example_dataset,
classes=classes,
feature_index=feature_index,
target_user_id=None,
k_shot=k_shot,
queue_size=queue_size,
num_arms=config.get("num_arms", 50),
n_das=config.get("n_das", 5),
delta=config.get("delta", 0.05),
epsilon=config.get("epsilon", 0.11),
seed=seed,
)
num_arms = len(recruiter.pool.arms)
if use_gt_only:
gt_steps = config.get("gt_steps", 200)
gt_best_arm_id = config.get("gt_best_arm_id", 20)
gt_sigma = config.get("gt_sigma", 5.0)
gt_scorer = make_gt_scorer(num_arms, best_arm_id=gt_best_arm_id, sigma=gt_sigma)
print(f"[run_recruiter] GT-only mode: {gt_steps} steps, best_arm_id={gt_best_arm_id}, sigma={gt_sigma}")
initial_sets = recruiter.recruit(queue_size)
initial_cases = []
slot_arms = []
for arm_idx, ex_set in initial_sets:
case = []
for d in ex_set:
key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
if key in dataset_idx:
case.append(dataset_idx[key])
else:
for i, s in enumerate(example_dataset):
if s.get("user_id") == str(d.get("user_id")) and s.get("sample_id") == d.get("sample_id"):
case.append(i)
break
if len(case) >= n_way:
initial_cases.append(case[:n_way])
slot_arms.append(arm_idx)
while len(initial_cases) < queue_size and len(initial_sets) > len(initial_cases):
arm_idx, ex_set = initial_sets[len(initial_cases)]
case = []
for d in ex_set:
key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
case.append(dataset_idx.get(key, 0))
if len(case) < n_way:
case.extend([0] * (n_way - len(case)))
initial_cases.append(case[:n_way])
slot_arms.append(arm_idx)
ex_queue = Queue(class_indices, queue_size)
ex_queue._queue = deque(initial_cases[:queue_size], maxlen=queue_size)
for case in ex_queue._queue:
ex_queue._register_stats(case)
slot_arms = slot_arms[:queue_size]
results = []
processed = 0
cumulative_correct = 0
if use_gt_only:
# Loop: query gt_scorer(arm_id) for each slot, update MAB, recruit, update queue
for step in range(gt_steps):
ex_queue.set_current_time(step)
queue_cases = list(ex_queue._queue)
run_results = []
for slot_i in range(len(slot_arms)):
arm_idx = slot_arms[slot_i]
score = gt_scorer(arm_idx)
response_dict = {"ANSWER": None, "REASON": "gt", "CONFIDENCE": 0.5}
run_results.append((arm_idx, response_dict, score))
recruiter.update(run_results)
new_sets = recruiter.recruit(recruit_size)
new_cases = []
new_arm_ids = []
for arm_idx, ex_set in new_sets:
case = []
for d in ex_set:
key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
case.append(dataset_idx.get(key, 0))
if len(case) < n_way:
case.extend([0] * (n_way - len(case)))
new_cases.append(case[:n_way])
new_arm_ids.append(arm_idx)
ex_queue.update_with_recruiter(run_results, new_cases, recruit_size)
scores_with_idx = [(i, run_results[i][2]) for i in range(len(run_results))]
scores_with_idx.sort(key=lambda x: x[1], reverse=True)
to_evict = {scores_with_idx[j][0] for j in range(min(recruit_size, len(scores_with_idx)))}
kept = [i for i in range(len(slot_arms)) if i not in to_evict]
slot_arms = [slot_arms[i] for i in kept] + new_arm_ids[:recruit_size]
results.append({
"sample_idx": step,
"answer": None,
"ground_truth": None,
"is_correct": None,
"gt_step": step,
"queue_arms": list(slot_arms),
"step_rewards": [r[2] for r in run_results],
})
print(f"[GT Summary] Final queue arms: {slot_arms}")
return results
# Model mode: iterate test loader, call LLM, self_certainty, borda_vote
for step, (metadata, signal) in enumerate(test_loader):
test_sample = _ppgbp_sample_to_dict(metadata, signal, int(metadata.get("subject_ID", step)))
ex_queue.set_current_time(processed)
queue_cases = list(ex_queue._queue)
run_results = []
for slot_i, case in enumerate(queue_cases):
examples = [example_dataset[i] for i in case if i < len(example_dataset)]
if len(examples) < n_way:
continue
messages = _build_prompt(system_msg, task_info, classes_info, test_sample, examples)
text, logits = await model_pool.invoke(messages)
score = self_certainty(logits)
parsed = _parse_answer(text, classes_info)
response_dict = {"ANSWER": parsed, "REASON": text[:200], "CONFIDENCE": 0.5}
arm_idx = slot_arms[slot_i] if slot_i < len(slot_arms) else slot_i
run_results.append((arm_idx, response_dict, score))
if not run_results:
processed += 1
continue
answer = borda_vote(run_results, valid_classes=classes_info)
gt = test_sample.get("label", "")
is_correct = answer == gt
cumulative_correct += 1 if is_correct else 0
acc = cumulative_correct / (processed + 1)
recruiter.update(run_results)
new_sets = recruiter.recruit(recruit_size)
new_cases = []
new_arm_ids = []
for arm_idx, ex_set in new_sets:
case = []
for d in ex_set:
key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
case.append(dataset_idx.get(key, 0))
if len(case) < n_way:
case.extend([0] * (n_way - len(case)))
new_cases.append(case[:n_way])
new_arm_ids.append(arm_idx)
ex_queue.update_with_recruiter(run_results, new_cases, recruit_size)
scores_with_idx = [(i, run_results[i][2]) for i in range(len(run_results))]
scores_with_idx.sort(key=lambda x: x[1], reverse=True)
to_evict = {scores_with_idx[j][0] for j in range(min(recruit_size, len(scores_with_idx)))}
kept = [i for i in range(len(slot_arms)) if i not in to_evict]
slot_arms = [slot_arms[i] for i in kept] + new_arm_ids[:recruit_size]
results.append({
"sample_idx": processed,
"answer": answer,
"ground_truth": gt,
"is_correct": is_correct,
"cumulative_accuracy": acc,
})
processed += 1
return results
def _expand_env_in_config(obj: Any) -> Any:
"""Recursively expand $VAR and ${VAR} in config strings."""
if isinstance(obj, dict):
return {k: _expand_env_in_config(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_expand_env_in_config(v) for v in obj]
if isinstance(obj, str):
return os.path.expandvars(obj)
return obj
def main(config_path: str) -> None:
with open(config_path, "r", encoding="utf-8") as f:
config = yaml.safe_load(f)
config = _expand_env_in_config(config)
if not config.get("model_path") and config.get("models"):
config["model_path"] = config["models"][0] if isinstance(config["models"], list) else config["models"]
model_path = config.get("model_path") or ""
if isinstance(model_path, str):
model_path = model_path.strip()
use_gt_only = not model_path
if use_gt_only:
config["use_gt_only"] = True
print("[run_recruiter] model_path is null: using ground-truth scorer (arm_id -> score) for MAB convergence")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
if "log_path" in config:
config["log_path"] = f"{config['log_path']}_{timestamp}"
pool = None
if not use_gt_only:
pool = load_models(
[model_path],
temperature=config.get("temperature", 0.7),
max_new_tokens=config.get("max_new_tokens", 256),
)
data_path = config["data_path"]
train_loader, test_loader = get_loaders(data_path, seed=config.get("seed", 42))
if use_gt_only:
test_loader = None # not used in GT loop
seeds = range(config.get("num_seeds", 1))
all_results = []
for seed in seeds:
res = asyncio.run(run_single_user(config, pool, test_loader, train_loader, seed))
all_results.extend(res)
if use_gt_only:
print(f"GT mode: completed {len(all_results)} steps")
else:
correct = sum(1 for r in all_results if r.get("is_correct"))
total = len(all_results)
acc = correct / total if total else 0
print(f"Accuracy: {correct}/{total} = {acc:.4f}")
log_path = config.get("log_path", "./logs/ppgbp_recruiter")
os.makedirs(log_path, exist_ok=True)
with open(os.path.join(log_path, "results.json"), "w") as f:
json.dump({
"use_gt_only": use_gt_only,
"accuracy": (sum(1 for r in all_results if r.get("is_correct")) / len(all_results)) if all_results and not use_gt_only else None,
"correct": sum(1 for r in all_results if r.get("is_correct")) if not use_gt_only else None,
"total": len(all_results),
"results": all_results,
}, f, indent=2)
if __name__ == "__main__":
Fire(main)
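In GT-only mode, each step scores the arms currently in the queue, evicts the `recruit_size` slots with the highest (worst) score, and appends the newly recruited arms. The sketch below replays one such step with the same Gaussian scorer and the same sort-and-evict rule. `gt_score` and `evict_and_refill` are hypothetical helper names, and the arm ids are chosen only for illustration.

```python
import math
from typing import List


def gt_score(arm_id: int, best_arm_id: int = 20, sigma: float = 5.0) -> float:
    """Score = -Gaussian bump over the arm index; lower is better."""
    return -math.exp(-((arm_id - best_arm_id) ** 2) / (2 * sigma ** 2))


def evict_and_refill(slot_arms: List[int], scores: List[float],
                     new_arms: List[int], recruit_size: int) -> List[int]:
    """Drop the recruit_size slots with the highest (worst) score, append new arms."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    to_evict = set(order[:recruit_size])
    kept = [arm for i, arm in enumerate(slot_arms) if i not in to_evict]
    return kept + new_arms[:recruit_size]


slot_arms = [3, 19, 40, 21, 10]
scores = [gt_score(a) for a in slot_arms]
next_slots = evict_and_refill(slot_arms, scores, new_arms=[20, 22], recruit_size=2)
```

Arms 40 and 3 are farthest from the peak at arm 20, so they have the scores closest to zero and are the ones evicted. If the recruiter keeps proposing arms near the peak, the queue should converge toward them over the configured number of steps.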


@@ -29,7 +29,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.core.data_loader import DataLoader
from sc.core.scagent import SCAgent
from sc.core.model import load_models
from sc.core.queue import Queue
from sc.core.example_queue import Queue
from sc.core.agent_pool import AgentPool
from sc import debug_log
@@ -426,7 +426,7 @@ def main(config_path: str) -> None:
    model_pool = load_models(
        config["models"],
        temperature=config.get("temperature", 0.0),
        num_ctx=config.get("num_ctx", 15000),
        max_new_tokens=config.get("max_new_tokens", 1024),
    )
    # Run experiments


@@ -518,7 +518,7 @@ def main(
    model_pool = load_models(
        config["models"],
        temperature=config.get("temperature", 0.0),
        num_ctx=config.get("num_ctx", 15000),
        max_new_tokens=config.get("max_new_tokens", 1024),
    )
    # Run experiment


@@ -30,10 +30,12 @@ from fire import Fire
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
from sc.core.data_loader import DataLoader
from sc.core.scagent import SCAgent
from sc.core.model import load_models
from sc.core.queue import Queue
from sc.core.example_queue import Queue
from sc.core.agent_pool import AgentPool
from sc.logger import Logger
from sc import debug_log
log = debug_log.log
@@ -48,33 +50,33 @@ async def run_consistency_experiment(
) -> List[Dict[str, Any]]:
    """
    Run consistency-based queue policy experiment.
    Queue Update Policy:
    - After each inference, calculate per-agent consistency
    - Consistency = how many other agents agree with this agent's answer
    - Rank queue elements by consistency score
    - Evict the lowest consistency element
    - Add a new random ICL example set
    Args:
        dataloader: ShuffledDataLoader with pre-shuffled test data
        model_pool: Async model pool for LLM inference
        config: Experiment configuration
        user_id: User ID being processed
        shuffle_seed: Shuffle seed for reproducibility
    Returns:
        List of result dictionaries
    """
    # Set seeds for reproducibility within experiment
    random.seed(shuffle_seed)
    np.random.seed(shuffle_seed)
    example_dataset = dataloader.get_examples()
    if len(example_dataset) == 0:
        log(f"[ERROR] No examples found for user {user_id}")
        return []
    # Build class_indices for Queue initialization
    class_indices = {}
    for idx, example in enumerate(example_dataset):
@@ -82,11 +84,11 @@ async def run_consistency_experiment(
        if label not in class_indices:
            class_indices[label] = []
        class_indices[label].append(idx)
    # Initialize Queue
    queue_size = config.get("queue_size", 5)
    ex_queue = Queue(class_indices, queue_size)
    # Tracking variables
    results = []
    cumulative_correct = 0
@@ -94,10 +96,10 @@ async def run_consistency_experiment(
    recent_results = []
    confidence_history = []
    consistency_history = []
    all_predictions = []
    all_ground_truths = []
    debug_log.log_policy_experiment_header(
        "CONSISTENCY-BASED",
        user_id,
@@ -105,20 +107,22 @@ async def run_consistency_experiment(
        len(dataloader),
        queue_size,
    )
    for processed_count, sample in enumerate(dataloader):
        ex_queue.set_current_time(processed_count)
        # Log queue state before processing
        debug_log.log_policy_queue_state("CONSISTENCY", processed_count, user_id, ex_queue)
        debug_log.log_policy_queue_state(
            "CONSISTENCY", processed_count, user_id, ex_queue
        )
        # Create agent pool
        agent_pool = AgentPool(log_path=config["log_path"])
        try:
            for queue_idx, ex_idcs in enumerate(ex_queue):
                examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]
                agent = SCAgent(
                    name="EEG sensing",
                    index=queue_idx,
@@ -134,46 +138,50 @@ async def run_consistency_experiment(
        except Exception as e:
            log(f"[ERROR] Failed to create agents: {e}")
            import traceback
            traceback.print_exc()
            continue
        if len(agent_pool.agents) == 0:
            log(f"[WARN] No agents created for sample {processed_count}")
            continue
        # Run parallel interpretation
        try:
            interpretation_result = await agent_pool.run_parallel_interpretation()
        except Exception as e:
            log(f"[ERROR] Interpretation failed: {e}")
            import traceback
            traceback.print_exc()
            continue
        if interpretation_result is None:
            log(f"[WARN] Interpretation failed for sample {processed_count}")
            continue
        answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
        answer, queue_idcs, avg_confidence, consistency, responses = (
            interpretation_result
        )
        # Evaluate result
        ground_truth = sample["label"]
        is_correct = answer == ground_truth
        cumulative_correct += 1 if is_correct else 0
        cumulative_accuracy = cumulative_correct / (processed_count + 1)
        recent_results.append(1 if is_correct else 0)
        if len(recent_results) > window_size:
            recent_results.pop(0)
        window_accuracy = sum(recent_results) / len(recent_results)
        confidence_history.append(avg_confidence)
        consistency_history.append(consistency)
        all_predictions.append(answer)
        all_ground_truths.append(ground_truth)
        # Performance logging
        debug_log.log_policy_result(
            "CONSISTENCY",
@@ -188,28 +196,32 @@ async def run_consistency_experiment(
            window_accuracy,
            recent_results,
        )
        # CONSISTENCY-BASED Queue Update
        if responses:
            # Calculate per-agent consistency: how many other agents agree
            all_answers = [r.get("ANSWER") for r in responses.values()]
            consistency_map = {}
            for idx, response in responses.items():
                agent_answer = response.get("ANSWER")
                # Consistency = ratio of agents (including self) that agree
                agreement_count = all_answers.count(agent_answer)
                agent_consistency = agreement_count / len(all_answers) if all_answers else 0
                agent_consistency = (
                    agreement_count / len(all_answers) if all_answers else 0
                )
                consistency_map[idx] = agent_consistency
            debug_log.log_consistency_map(responses, consistency_map)
            # Update queue by consistency (reuses confidence method with consistency scores)
            ex_queue.update_by_confidence(consistency_map)
            ex_queue.increment_usage(list(responses.keys()))
            debug_log.log_policy_queue_state_after("Consistency", processed_count, ex_queue)
            debug_log.log_policy_queue_state_after(
                "Consistency", processed_count, ex_queue
            )
        # Store result
        result = {
            "sample_idx": processed_count,
@@ -227,34 +239,36 @@ async def run_consistency_experiment(
            "shuffle_seed": shuffle_seed,
        }
        results.append(result)
    # Final statistics
    survival_summary = ex_queue.get_survival_summary()
    debug_log.log_final_policy_survival("usc", user_id, shuffle_seed, survival_summary)
    if results:
        results[-1]["queue_survival_stats"] = survival_summary
        results[-1]["queue_survival_details"] = ex_queue.get_survival_stats()
    return results
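
The per-agent consistency score and the eviction choice described in the docstring can be sketched on their own. This is a minimal illustration under stated assumptions: `consistency_scores` and `lowest_consistency_slot` are hypothetical names, ties are broken by the smallest slot index, and the actual eviction logic in `Queue.update_by_confidence` is not shown in this diff and may differ.

```python
from typing import Dict, Hashable, Optional


def consistency_scores(answers: Dict[int, Optional[Hashable]]) -> Dict[int, float]:
    """Map each queue slot to the fraction of agents (itself included)
    that gave the same answer."""
    all_answers = list(answers.values())
    if not all_answers:
        return {}
    return {
        idx: all_answers.count(answer) / len(all_answers)
        for idx, answer in answers.items()
    }


def lowest_consistency_slot(scores: Dict[int, float]) -> Optional[int]:
    """Pick the slot to evict: lowest score first, then smallest index.

    The tie-break rule is an assumption made for this sketch.
    """
    if not scores:
        return None
    return min(scores, key=lambda idx: (scores[idx], idx))


answers = {0: "N2", 1: "N2", 2: "W", 3: "N2", 4: "REM"}
scores = consistency_scores(answers)
evict = lowest_consistency_slot(scores)
```

With five agents and three votes for `"N2"`, the agreeing slots score 0.6 and the two dissenting slots score 0.2, so one of the dissenting slots is the eviction candidate.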
def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict[str, Any]:
def compute_statistics(
    results: List[Dict[str, Any]], stages: List[str]
) -> Dict[str, Any]:
    """Compute comprehensive experiment statistics."""
    if not results:
        return {}
    # Overall metrics
    correct = sum(1 for r in results if r.get("is_correct", False))
    total = len(results)
    accuracy = correct / total if total > 0 else 0
    # Confidence and consistency averages
    confidences = [r.get("confidence", 0) for r in results]
    consistencies = [r.get("consistency", 0) for r in results]
    avg_confidence = np.mean(confidences) if confidences else 0
    avg_consistency = np.mean(consistencies) if consistencies else 0
    # Per-stage accuracy
    stage_correct = {}
    stage_total = {}
@@ -263,47 +277,59 @@ def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict
        stage_total[gt] = stage_total.get(gt, 0) + 1
        if r.get("is_correct", False):
            stage_correct[gt] = stage_correct.get(gt, 0) + 1
    stage_accuracy = {}
    for stage in stages:
        if stage in stage_total:
            stage_accuracy[stage] = stage_correct.get(stage, 0) / stage_total[stage]
        else:
            stage_accuracy[stage] = 0.0
    # F1 Score and Macro metrics
    predictions = [r.get("answer", "") for r in results]
    ground_truths = [r.get("ground_truth", "") for r in results]
    try:
        macro_f1 = f1_score(ground_truths, predictions, average='macro', zero_division=0)
        macro_precision = precision_score(ground_truths, predictions, average='macro', zero_division=0)
        macro_recall = recall_score(ground_truths, predictions, average='macro', zero_division=0)
        macro_f1 = f1_score(
            ground_truths, predictions, average="macro", zero_division=0
        )
        macro_precision = precision_score(
            ground_truths, predictions, average="macro", zero_division=0
        )
        macro_recall = recall_score(
            ground_truths, predictions, average="macro", zero_division=0
        )
    except Exception:
        macro_f1 = macro_precision = macro_recall = 0.0
    # Temporal analysis
    mid_point = len(results) // 2
    if mid_point > 0:
        first_half = results[:mid_point]
        second_half = results[mid_point:]
        first_half_acc = sum(1 for r in first_half if r.get("is_correct", False)) / len(first_half)
        second_half_acc = sum(1 for r in second_half if r.get("is_correct", False)) / len(second_half)
        first_half_acc = sum(1 for r in first_half if r.get("is_correct", False)) / len(
            first_half
        )
        second_half_acc = sum(
            1 for r in second_half if r.get("is_correct", False)
        ) / len(second_half)
        improvement = second_half_acc - first_half_acc
    else:
        first_half_acc = second_half_acc = accuracy
        improvement = 0
    # Learning curve (every 10 samples)
    learning_curve = []
    for i in range(0, len(results), 10):
        chunk = results[:i+10]
        chunk = results[: i + 10]
        chunk_acc = sum(1 for r in chunk if r.get("is_correct", False)) / len(chunk)
        learning_curve.append({
            "sample_idx": min(i+10, len(results)),
            "cumulative_accuracy": chunk_acc,
        })
        learning_curve.append(
            {
                "sample_idx": min(i + 10, len(results)),
                "cumulative_accuracy": chunk_acc,
            }
        )
    # Convergence speed (90% of final accuracy)
    final_accuracy = accuracy
    convergence_threshold = 0.9 * final_accuracy
@@ -315,14 +341,14 @@ def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict
        if running_acc >= convergence_threshold:
            convergence_idx = i
            break
    # Queue statistics
    queue_stats = {}
    for r in results:
        if "queue_survival_stats" in r:
            queue_stats = r["queue_survival_stats"]
            break
    return {
        "experiment_type": "consistency",
        "user_id": results[0].get("user_id") if results else None,
@@ -348,125 +374,38 @@ def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict
    }
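
The learning-curve and convergence computations in `compute_statistics` can be tried on a plain list of correctness flags. The sketch below is illustrative only: `learning_curve` and `convergence_index` are hypothetical standalone helpers, and because the body of the convergence loop is elided by the hunk boundary, the running-accuracy loop here (starting at the first sample) is an assumption.

```python
from typing import Dict, List, Optional, Union


def learning_curve(flags: List[bool], step: int = 10) -> List[Dict[str, Union[int, float]]]:
    """Cumulative accuracy sampled every `step` results."""
    curve = []
    for i in range(0, len(flags), step):
        chunk = flags[: i + step]  # everything seen so far, not just one window
        curve.append(
            {
                "sample_idx": min(i + step, len(flags)),
                "cumulative_accuracy": sum(chunk) / len(chunk),
            }
        )
    return curve


def convergence_index(flags: List[bool], ratio: float = 0.9) -> Optional[int]:
    """First index where running accuracy reaches `ratio` of final accuracy."""
    if not flags:
        return None
    threshold = ratio * (sum(flags) / len(flags))
    correct = 0
    for i, ok in enumerate(flags):
        correct += ok
        if correct / (i + 1) >= threshold:
            return i
    return None


flags = [False, False, True, True, True, True, True, True, True, True]
curve = learning_curve(flags, step=5)
conv = convergence_index(flags)
```

Under this reading, a run whose very first prediction is correct converges at index 0, so a minimum number of samples before the check may be worth considering.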
def save_results(
    results: List[Dict[str, Any]],
    stats: Dict[str, Any],
    config: Dict[str, Any],
    user_id: int,
    shuffle_seed: int,
) -> None:
    """Save experiment results and statistics."""
    log_path = config["log_path"]
    # Create user/seed specific directory
    output_dir = os.path.join(log_path, f"user{user_id:02d}_seed{shuffle_seed}")
    os.makedirs(output_dir, exist_ok=True)
    # Save statistics
    stats_path = os.path.join(output_dir, "statistics.json")
    with open(stats_path, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2, ensure_ascii=False)
    log(f"[SAVE] Statistics: {stats_path}")
    # Save results
    results_path = os.path.join(output_dir, "results.json")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    log(f"[SAVE] Results: {results_path}")
    # Save config
    config_path = os.path.join(output_dir, "config.yaml")
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
    log(f"[SAVE] Config: {config_path}")
def main(
    config_path: str = "sc/config/experiment_consistency.yaml",
    user_id: int = 5,
    shuffle_seed: int = 42,
) -> None:
def main(config_path: str) -> None:
    """
    Run Consistency-based Queue Policy experiment.
    Args:
        config_path: Path to YAML configuration file
        user_id: User ID to process (5 or 10)
        shuffle_seed: Shuffle seed for data order (42, 123, or 456)
    Example:
        python -m sc.run_consistency --user_id=5 --shuffle_seed=42
        python -m sc.run_consistency --user_id=10 --shuffle_seed=123
    """
    log(f"[MAIN] Loading config: {config_path}")
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)
    # Override with CLI arguments
    config["user_id"] = user_id
    config["shuffle_seed"] = shuffle_seed
    # Create unique log path
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    config["log_path"] = f"{config['log_path']}/user{user_id:02d}_seed{shuffle_seed}_{timestamp}"
    os.makedirs(config["log_path"], exist_ok=True)
    # Print experiment info
    debug_log.log_policy_main_header(
        "CONSISTENCY-BASED",
        user_id,
        shuffle_seed,
        config.get("queue_size", 5),
        len(config.get("models", [])),
        config["log_path"],
    logger = Logger(config["log_path"])
    logger.log_config(config)
    dataloader = DataLoader(
        config["data_path"],
        config["user_id"],
        shuffle=config["shuffle"],
        seed=config["seed"],
    )
    # Load shuffled data
    debug_log.log_policy_loading_data(user_id, shuffle_seed)
    dataloader = ShuffledDataLoader(
        data_path=config["data_path"],
        user_id=user_id,
        seed=shuffle_seed,
        example_pool=config.get("example_pool", "out"),
    )
    # Load models
    debug_log.log_policy_loading_models()
    model_pool = load_models(
        config["models"],
        temperature=config.get("temperature", 0.0),
        num_ctx=config.get("num_ctx", 15000),
        temperature=config["temperature"],
        max_new_tokens=config.get("max_new_tokens", 1024),
    )
    # Run experiment
    debug_log.log_policy_start(label="experiment")
    results = asyncio.run(run_consistency_experiment(
        dataloader=dataloader,
        model_pool=model_pool,
        config=config,
        user_id=user_id,
        shuffle_seed=shuffle_seed,
    ))
    # Compute statistics
    stages = config.get("stages", ["W", "N1", "N2", "N3", "REM"])
    stats = compute_statistics(results, stages)
    # Print final summary
    temporal = stats.get("temporal_analysis", {})
    debug_log.log_policy_complete_summary(
        "CONSISTENCY",
        user_id,
        shuffle_seed,
        stats,
        stats.get("stage_accuracy", {}),
        stats.get("stage_sample_counts", {}),
        temporal,
    results = asyncio.run(
        run_consistency_experiment(
            dataloader=dataloader,
            model_pool=model_pool,
            config=config,
            user_id=user_id,
            shuffle_seed=shuffle_seed,
        )
    )
    # Save results
    save_results(results, stats, config, user_id, shuffle_seed)
    log(f"\n[MAIN] Results saved to: {config['log_path']}")
if __name__ == "__main__":