Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | 216bd2ecc3 | | |
| | df5e6db8b6 | | |
| | 31f4af7106 | | |
| | 8a26346b5e | | |
| | a7c8e43f89 | | |
198
.claude/CLAUDE.md
Normal file
@@ -0,0 +1,198 @@
# Project Memory: tsllm_personalization_icl

## Project Overview

Research project on **personalized time-series classification using LLMs with In-Context Learning (ICL)**. The core idea is to use LLMs (via Ollama or HuggingFace) to classify sensor/physiological time-series data (e.g., EEG sleep stages, PPG blood pressure), where ICL examples are selected via different strategies (random vs. similarity-based) to study personalization.

Current branch: `case` (main branch: `main`)

---

## Repository Structure

```
tsllm_personalization_icl/
├── run.py  # Main entry point (ICL experiment runner)
├── config/
│   └── sleepedf.yaml  # Config: data path, models, selection criteria, log path
├── core/  # Base pipeline (ICL experiments)
│   ├── model.py  # Model/AsyncModelPool (Ollama or LangChain init_chat_model)
│   ├── agent.py  # Base Agent class (memory, logging, JSON parsing)
│   ├── sensing_agent.py  # SensingAgent: ICL classify + evaluate + reflect
│   ├── data_loader.py  # DataLoader: test/example splits + similarity-based selection
│   └── embedding_index.py  # EmbeddingIndex: cosine similarity over Chronos-2 embeddings
├── sc/  # Self-Consistency (SC) variant pipeline
│   ├── run_sc.py  # SC experiment runner (main entry: `python -m sc.run_sc config.yaml`)
│   ├── run_confidence_based.py
│   ├── run_consistency_based.py
│   ├── run_sc_queue_random.py
│   ├── run_usc.py
│   ├── core/
│   │   ├── model.py  # HuggingFace CausalLM wrapper (returns text + logits)
│   │   ├── agent.py  # SC base agent
│   │   ├── scagent.py  # SCAgent: single-pass interpret() with REASON/CONFIDENCE/ANSWER
│   │   ├── agent_pool.py  # AgentPool: parallel interpret + majority voting
│   │   ├── example_queue.py  # Queue: priority queue of example sets, updated by confidence
│   │   ├── data_loader.py  # DataLoader + InMemoryDataLoader + PPGBPLoader
│   │   ├── judge_agent.py  # JudgeAgent
│   │   ├── majority_voting.py  # MajorityVoting utilities
│   │   └── model_utils.py
│   ├── analysis/
│   │   └── analyze_sc_results.py
│   ├── preprocess/
│   │   └── shuffle_data.py
│   ├── debug_log.py  # Structured debug logging helpers
│   ├── logger.py
│   ├── hf_api.py
│   └── ollama_test.py
├── analysis/
│   ├── ppgbp_loader.py
│   ├── user_similarity/  # Embedding analysis scripts
│   │   ├── chronos2/  # Chronos-2 embeddings for SleepEDF
│   │   ├── chronos2_ppgbp/  # Chronos-2 embeddings for PPGBP
│   │   ├── labram/  # LaBraM embeddings
│   │   ├── sbert/
│   │   ├── sbert_metadata/
│   │   └── sbert_metadata_ppgbp/  # SBERT metadata embeddings for PPGBP
│   ├── analyze_data.ipynb
│   └── analyze_preliminary.ipynb
├── preprocess/
│   ├── dhedfreader.py
│   └── preprocess_SleepEDF.py
├── utils/
│   ├── kill_ollamas.sh
│   └── launch_ollamas.sh
└── requirements.txt
```

---

## Key Concepts

### Datasets
- **SleepEDF**: EEG-based sleep stage classification (W, N1, N2, N3, REM).
- **PPGBP**: PPG-based blood pressure prediction.
- Data format: `<data_path>/<user_id>/1/` (train/examples, HuggingFace dataset on disk) and `<data_path>/<user_id>/2/` (test split). Metadata in `info.json` (keys: `task`, `class`, `feature`).

### ICL Selection Strategies (core/data_loader.py)
- `out_random`: Random examples from OTHER users (cross-user, random)
- `in_random`: Random examples from SAME user (personalized, random)
- `out_similar`: Most similar examples from OTHER users (Chronos-2 embedding similarity)
- `in_similar`: Most similar examples from SAME user (embedding similarity)

Similarity-based selection ranks candidates by cosine similarity over Chronos-2 embeddings and picks one example per class (balanced).

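The four strategies differ only in which candidates are eligible (same user or other users) and in how they are ranked (randomly or by similarity). A minimal sketch of the similarity-based case follows. The function name, arguments, and array layout are illustrative assumptions and are not taken from `core/data_loader.py`.

```python
import numpy as np


def select_similar_examples(query, embeddings, labels, users, query_user, scope="out"):
    """Pick, for each class, the index of the candidate most similar to `query`.

    Hypothetical helper: scope="out" keeps candidates from other users,
    scope="in" keeps candidates from the same user.
    """
    # L2-normalize so that a plain dot product equals cosine similarity.
    q = query / np.linalg.norm(query)
    e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sims = e @ q

    labels = np.asarray(labels)
    same_user = np.asarray(users) == query_user
    mask = same_user if scope == "in" else ~same_user

    chosen = {}
    for cls in sorted(set(labels.tolist())):
        idx = np.flatnonzero(mask & (labels == cls))
        if idx.size:  # a class with no eligible candidate is skipped
            chosen[cls] = int(idx[np.argmax(sims[idx])])
    return chosen
```

The sketch assumes the query itself is not among the candidates, which holds when examples come from session `1` and test samples from session `2`.
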
### Models
- **core/model.py**: Supports Ollama (local/remote via `ollama:url:<host>/<model>`) and any LangChain-supported model (OpenAI, Together, etc.)
- **sc/core/model.py**: HuggingFace CausalLM loaded via `transformers`, returns text + logits. Used for SC experiments.

### Self-Consistency (sc/) Pipeline
The SC pipeline uses a **priority queue of example sets**:
- `Queue` (sc/core/example_queue.py): holds `capacity` example-sets (one set = one example per class). Updated each step by agent confidence scores: the highest-confidence sets are kept, and the lowest are evicted and replaced with a random new set.
- `AgentPool` (sc/core/agent_pool.py): runs one `SCAgent` per queue slot in parallel, aggregates via majority vote, tracks confidence and consistency.
- `SCAgent` (sc/core/scagent.py): single LLM call, returns `{REASON, CONFIDENCE, ANSWER}` JSON.

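One step of this loop can be sketched as a majority vote over the agents' answers followed by a queue refresh. This is an illustration under assumptions, not the code in `sc/core/`: the function names, the `keep` parameter, and the tie-breaking rule are hypothetical.

```python
from collections import Counter


def majority_vote(answers):
    """Return the most common answer; ties go to the answer seen first."""
    return Counter(answers).most_common(1)[0][0]


def update_queue(queue, confidences, sample_new_set, keep):
    """Keep the `keep` highest-confidence example sets and refill the rest.

    `sample_new_set` is a zero-argument callable that draws a random new set.
    """
    ranked = sorted(zip(confidences, range(len(queue))), reverse=True)
    survivors = [queue[i] for _, i in ranked[:keep]]
    refills = [sample_new_set() for _ in range(len(queue) - keep)]
    return survivors + refills
```

For example, with three slots and `keep=2`, the slot with the lowest confidence is replaced and the other two carry over to the next test sample.
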
### Base Agent (core/agent.py)
- Three memory tiers: `long_term_memory`, `short_term_memory`, `volatile_memory`
- `invoke()`: async, appends messages to memory, calls model pool
- `safe_parse_json()` / `safe_parse_json_list()`: robust JSON parsing with cleanup
- Token counting via `tiktoken` (gpt-3.5-turbo encoding as proxy)

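The kind of cleanup a parser like `safe_parse_json()` needs can be sketched as follows. This is a guess at the general technique (strip Markdown fences, keep the outermost braces, fall back to `None`), not the implementation in `core/agent.py`, whose exact behavior is not shown here.

```python
import json
import re


def safe_parse_json(text):
    """Best-effort extraction of one JSON object from an LLM reply.

    Returns the parsed dict, or None when no valid object is found.
    """
    # Remove Markdown fence markers (three backticks, optionally followed by "json").
    cleaned = re.sub(r"`{3}(?:json)?", "", text)
    # Keep only the outermost {...} span and ignore any chatter around it.
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        return json.loads(cleaned[start:end + 1])
    except json.JSONDecodeError:
        return None
```

Returning `None` instead of raising lets the caller decide whether to retry the model call or skip the sample.
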
### SensingAgent (core/sensing_agent.py)
Extends Agent. Methods:
- `solve()`: classify a sample with ICL examples → `{REASON, ANSWER}`
- `interpret()`: same as solve but without logging ground truth
- `evaluate()`: evaluate another agent's answer (for multi-agent debate)
- `reflect()`: refine answer based on peer evaluations

---

## Running Experiments

### Basic ICL run
```bash
python run.py run config/sleepedf.yaml
```

### Compare multiple selection criteria
```bash
python run.py compare config/sleepedf.yaml \
  --criteria_list="out_random,in_random,out_similar,in_similar" \
  --embedding_path="./embeddings_full"
```

### Self-Consistency run
```bash
python -m sc.run_sc sc/config/sleepedf_sc.yaml
```

---

## Config (sleepedf.yaml) Key Fields
- `data_path`: root of dataset (contains `info.json` and per-user dirs)
- `log_path`: output directory for results and logs
- `num_seeds`: number of random seeds to run
- `num_examples`: ICL examples per class (default 1)
- `selection_criteria`: `out_random` | `in_random` | `out_similar` | `in_similar`
- `embedding_path`: path to pre-computed Chronos-2 embeddings (required for `*_similar`)
- `models`: list of model specs (Ollama URLs or model IDs)

In SC configs, additional fields:
- `queue_size`: number of example sets to maintain in the priority queue
- `temperature`: LLM sampling temperature
- `max_new_tokens`: max tokens for generation
- `example_pool`: `"out"` (other users) or `"in"` (same user)
- `continuous`: whether queue persists across test samples

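The dependency between `selection_criteria` and `embedding_path` is easy to get wrong, so a small pre-flight check can help. The sketch below works on an already-loaded config dict. The function name is hypothetical, and treating `data_path`, `log_path`, and `models` as mandatory is an assumption rather than something `run.py` is known to enforce.

```python
VALID_CRITERIA = {"out_random", "in_random", "out_similar", "in_similar"}


def check_config(cfg):
    """Return a list of problems found in a config dict (empty list means OK)."""
    problems = []
    # Assumed to be mandatory for any run.
    for key in ("data_path", "log_path", "models"):
        if not cfg.get(key):
            problems.append(f"missing required field: {key}")
    criteria = cfg.get("selection_criteria")
    if criteria not in VALID_CRITERIA:
        problems.append(f"unknown selection_criteria: {criteria!r}")
    elif criteria.endswith("_similar") and not cfg.get("embedding_path"):
        problems.append("embedding_path is required for *_similar criteria")
    return problems
```
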
---

## Data Format Details

Each HuggingFace dataset sample:
```python
{
    "user_id": str,
    "session_id": str,   # "1" = train, "2" = test
    "idx": int,
    "label": str,        # class name from info.json["class"]
    "features": dict,    # str -> float/str (sensor features formatted for prompt)
    "data": dict,        # optional raw data
}
```

`info.json`:
```json
{
  "task": "Sleep stage classification from EEG",
  "class": {"W": "Wake", "N1": "...", ...},
  "feature": "EEG channel description for the LLM..."
}
```

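To make the schema concrete, here is one way a sample could be rendered as text for an ICL prompt, with the label checked against the classes in `info.json`. The function name, the `name: value` layout, and the feature name used in the usage note are illustrative assumptions. The project's actual prompt format is not shown in this file.

```python
def format_sample(sample, info, with_label=True):
    """Render one dataset sample as plain text for an ICL prompt."""
    if sample["label"] not in info["class"]:
        raise ValueError(f"unknown label: {sample['label']!r}")
    # One line per feature, in the order stored in the sample.
    lines = [f"{name}: {value}" for name, value in sample["features"].items()]
    if with_label:
        # ICL examples carry their label; the test sample is rendered without it.
        lines.append(f"ANSWER: {sample['label']}")
    return "\n".join(lines)
```

With a hypothetical feature `delta_power`, an example sample renders as two lines (`delta_power: 0.42`, then `ANSWER: W`), and the test sample is rendered with `with_label=False`.
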
---

## PPGBP Dataset (sc/core/data_loader.py)
`PPGBPLoader`: reads xlsx metadata + signal `.txt` files from `0_subject/`. 80/20 train/test split at subject level. Metadata embeddings (SBERT) auto-generated and cached per subject as `.npy` files under `PPGBP_METADATA_EMBEDDINGS_ROOT` (set via env var or `.env`).

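A subject-level split keeps every segment of a subject on the same side of the train/test boundary. The sketch below mirrors the logic of `analysis/ppgbp_loader.py` from this same change (seeded permutation, `round(0.8 * n)` training subjects, at least one). The loader in `sc/core/data_loader.py` is not shown here and may differ. The function name is hypothetical.

```python
import numpy as np


def split_subjects(subject_ids, seed=42, train_frac=0.8):
    """Split subject IDs into (train, test) arrays, reproducibly for a given seed."""
    ids = np.asarray(sorted(subject_ids))
    if len(ids) == 0:
        return ids, ids
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(ids))
    # Round to the nearest subject count, but always keep at least one for training.
    n_train = max(1, int(round(train_frac * len(ids))))
    return ids[perm[:n_train]], ids[perm[n_train:]]
```

Because the permutation depends only on the seed and the sorted ID list, two loaders built with the same seed agree on which subjects are held out.
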
---

## Dependencies (requirements.txt)
- `langchain`, `langchain_ollama`, `langchain_openai`, `langchain_together`: LLM backends
- `datasets`: HuggingFace datasets (on-disk storage)
- `chronos`: Chronos-2 time-series embeddings
- `sentence_transformers`: SBERT for metadata embeddings
- `transformers`, `torch`: HuggingFace model loading (SC pipeline)
- `tiktoken`: token counting
- `fire`: CLI argument parsing
- `mne`, `neurokit2`: EEG/biosignal preprocessing
- `numpy`, `pandas`, `scikit_learn`, `scipy`, `matplotlib`

---

## Notes / Patterns
- Experiments sample every 10th test item (`if idx % 10 != 0: continue` in `run.py`).
- The SC runner is currently hardcoded to the first user only, for testing (`users[:1]`).
- Log structure: `<log_path>/<user_id>/<sample_idx>/<seed>/` with `summary.txt`, `log.txt`, `tokens.txt`.
- Embedding index uses cosine similarity (L2-normalized dot product), filtered by user and session.
- `DataLoader` in `sc/` is simpler (no similarity selection); similarity lives in `core/`.
- `InMemoryDataLoader` and `prepare_dataset_for_sc()` allow integrating new datasets without writing to disk.
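
The subsampling rule and the log layout from the notes above can be expressed as two small helpers. Both names are hypothetical and the code is a sketch of the described conventions, not an excerpt from `run.py`.

```python
from pathlib import Path


def sampled_indices(num_test_items, step=10):
    """Indices that survive the `idx % 10 != 0: continue` filter."""
    return [idx for idx in range(num_test_items) if idx % step == 0]


def log_dir(log_path, user_id, sample_idx, seed):
    """Directory for one (user, sample, seed) run: <log_path>/<user_id>/<sample_idx>/<seed>/."""
    return Path(log_path) / str(user_id) / str(sample_idx) / str(seed)
```
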
0
.claude/plan.md
Normal file
235
analysis/ppgbp_loader.py
Normal file
@@ -0,0 +1,235 @@
"""
PPGBP Dataset Loader

A data loader/iterator for the PPG-BP dataset that reads xlsx metadata
and corresponding signal text files, with optional batching support.
"""

import os
import numpy as np
import pandas as pd
from typing import Iterator, Tuple, Dict, List, Optional, Union


class PPGBPLoader:
    """
    Data loader for the PPG-BP dataset.

    Reads metadata from xlsx file and corresponding PPG signal text files.
    Supports iteration over single samples or batches.
    Subject list is built from 0_subject/ and matched with metadata;
    train/test is 80/20 at subject level (random, fixed by seed).
    """

    COLUMN_MAPPING = {
        "Sex(M/F)": "sex",
        "Age(year)": "age",
        "Systolic Blood Pressure(mmHg)": "sysbp",
        "Diastolic Blood Pressure(mmHg)": "diasbp",
        "Heart Rate(b/m)": "hr",
        "BMI(kg/m^2)": "bmi"
    }

    def __init__(
        self,
        base_dir: str,
        split: str = 'all',
        batch_size: Optional[int] = None,
        shuffle: bool = False,
        num_segments: int = 3,
        seed: int = 42
    ):
        self.base_dir = base_dir
        self.split = split
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_segments = num_segments
        self.seed = seed
        self.rng = np.random.default_rng(seed)

        self.signal_dir = os.path.join(base_dir, "0_subject")
        self.xlsx_path = self._find_xlsx_file()

        self.metadata = self._load_metadata()
        self._all_subject_ids = self._get_unified_subject_ids()
        self._train_ids, self._test_ids = self._get_train_test_split(self._all_subject_ids)
        self.subject_ids = self._get_split_ids()
        self.metadata = self.metadata[self.metadata['subject_ID'].isin(self.subject_ids)]

        self._current_idx = 0
        self._indices = np.arange(len(self.subject_ids))

        if self.shuffle:
            self.rng.shuffle(self._indices)

    def _find_xlsx_file(self) -> str:
        for f in os.listdir(self.base_dir):
            if f.endswith('.xlsx'):
                return os.path.join(self.base_dir, f)
        raise FileNotFoundError(f"No xlsx file found in {self.base_dir}")

    def _load_metadata(self) -> pd.DataFrame:
        # Use header=1: xlsx has a title row, then row with column names (subject_ID, etc.)
        df = pd.read_excel(self.xlsx_path, header=1)
        df = df.rename(columns=self.COLUMN_MAPPING)
        df = df.fillna(0)
        return df

    def _get_unified_subject_ids(self) -> np.ndarray:
        """Subject IDs present in both 0_subject/ (from listdir) and metadata."""
        if not os.path.isdir(self.signal_dir):
            return np.array([], dtype=np.int64)
        # Files are named {subject_id}_{segment}.txt
        ids_from_files = set()
        for f in os.listdir(self.signal_dir):
            if f.endswith(".txt"):
                try:
                    sid = int(f.replace(".txt", "").split("_")[0])
                    ids_from_files.add(sid)
                except (ValueError, IndexError):
                    continue
        meta_ids = set(self.metadata["subject_ID"].astype(int).values)
        unified = sorted(ids_from_files & meta_ids)
        return np.array(unified)

    def _get_train_test_split(self, ids: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """80% train, 20% test at subject level; shuffle fixed by seed."""
        if len(ids) == 0:
            return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
        perm = self.rng.permutation(len(ids))
        ids_perm = ids[perm]
        n_train = max(1, int(round(0.8 * len(ids))))
        train_ids = ids_perm[:n_train]
        test_ids = ids_perm[n_train:]
        return train_ids, test_ids

    def _get_split_ids(self) -> np.ndarray:
        if self.split == "train":
            return self._train_ids
        elif self.split == "test":
            return self._test_ids
        elif self.split == "all":
            return self._all_subject_ids
        else:
            raise ValueError(f"Unknown split: {self.split}. Use 'train', 'test', or 'all'.")

    @property
    def train_ids(self) -> np.ndarray:
        """Subject IDs in the train split (80%)."""
        return self._train_ids

    @property
    def test_ids(self) -> np.ndarray:
        """Subject IDs in the test split (20%)."""
        return self._test_ids

    def _load_signal(self, subject_id: int) -> np.ndarray:
        """Load all segments of one subject as an object array of 1D signals."""
        segments = []
        for s in range(1, self.num_segments + 1):
            filepath = os.path.join(self.signal_dir, f"{subject_id}_{s}.txt")
            signal = pd.read_csv(filepath, sep='\t', header=None)
            signal = signal.values.squeeze()
            # Drop the last value of each segment
            if len(signal) > 1:
                signal = signal[:-1]
            segments.append(signal)
        return np.array(segments, dtype=object)

    def _get_metadata_dict(self, subject_id: int) -> Dict:
        row = self.metadata[self.metadata['subject_ID'] == subject_id].iloc[0]
        return row.to_dict()

    def __len__(self) -> int:
        return len(self.subject_ids)

    def __iter__(self) -> 'PPGBPLoader':
        self._current_idx = 0
        if self.shuffle:
            self.rng.shuffle(self._indices)
        return self

    def __next__(self) -> Union[Tuple[Dict, np.ndarray], Tuple[List[Dict], List[np.ndarray]]]:
        if self._current_idx >= len(self.subject_ids):
            raise StopIteration

        if self.batch_size is None:
            idx = self._indices[self._current_idx]
            subject_id = self.subject_ids[idx]
            metadata = self._get_metadata_dict(subject_id)
            signal = self._load_signal(subject_id)
            self._current_idx += 1
            return metadata, signal
        else:
            end_idx = min(self._current_idx + self.batch_size, len(self.subject_ids))
            batch_indices = self._indices[self._current_idx:end_idx]

            metadata_list = []
            signal_list = []

            for idx in batch_indices:
                subject_id = self.subject_ids[idx]
                metadata_list.append(self._get_metadata_dict(subject_id))
                signal_list.append(self._load_signal(subject_id))

            self._current_idx = end_idx
            return metadata_list, signal_list

    def __getitem__(self, idx: int) -> Tuple[Dict, np.ndarray]:
        if idx < 0 or idx >= len(self.subject_ids):
            raise IndexError(f"Index {idx} out of range")
        subject_id = self.subject_ids[idx]
        metadata = self._get_metadata_dict(subject_id)
        signal = self._load_signal(subject_id)
        return metadata, signal

    def get_by_subject_id(self, subject_id: int) -> Tuple[Dict, np.ndarray]:
        if subject_id not in self.subject_ids:
            raise ValueError(f"Subject ID {subject_id} not found in this split.")
        metadata = self._get_metadata_dict(subject_id)
        signal = self._load_signal(subject_id)
        return metadata, signal

    def reset(self):
        self._current_idx = 0
        if self.shuffle:
            self.rng.shuffle(self._indices)

    def iter_batches(self, batch_size: int) -> Iterator[Tuple[List[Dict], List[np.ndarray]]]:
        indices = np.arange(len(self.subject_ids))
        if self.shuffle:
            self.rng.shuffle(indices)

        for start_idx in range(0, len(self.subject_ids), batch_size):
            end_idx = min(start_idx + batch_size, len(self.subject_ids))
            batch_indices = indices[start_idx:end_idx]

            metadata_list = []
            signal_list = []

            for idx in batch_indices:
                subject_id = self.subject_ids[idx]
                metadata_list.append(self._get_metadata_dict(subject_id))
                signal_list.append(self._load_signal(subject_id))

            yield metadata_list, signal_list


def get_loaders(
    base_dir: str,
    batch_size: Optional[int] = None,
    shuffle_train: bool = True,
    seed: int = 42
) -> Tuple[PPGBPLoader, PPGBPLoader]:
    """Convenience function to get train and test loaders (80/20 split)."""
    train_loader = PPGBPLoader(
        base_dir=base_dir, split="train",
        batch_size=batch_size, shuffle=shuffle_train, seed=seed
    )
    test_loader = PPGBPLoader(
        base_dir=base_dir, split="test",
        batch_size=batch_size, shuffle=False, seed=seed
    )
    return train_loader, test_loader


if __name__ == "__main__":
    print("PPGBPLoader ready to use.")
982
analysis/user_similarity/chronos2_ppgbp/gen_plot.py
Normal file
@@ -0,0 +1,982 @@
"""
Chronos-2 PPG-BP Embedding Extraction and Visualization Pipeline

This module provides functionality to:
1. Extract embeddings from univariate PPG signals (PPG-BP dataset) using Chronos-2
2. Visualize embeddings using t-SNE colored by various metadata categories

Chronos-2 Overview:
    Chronos-2 is a time series foundation model developed by Amazon. It is an
    encoder-only transformer that splits each series into patches and produces
    probabilistic (quantile) forecasts.

Embedding Strategy:
    We use Chronos-2's internal encoder hidden states as embeddings. For univariate
    PPG signals, each signal is treated as a single-variate time series, producing
    a 512-dimensional embedding after pooling.

Processing Pipeline:
    1. Load PPG signals via PPGBPLoader (3 segments per subject)
    2. Each segment is embedded independently through Chronos-2 encoder
    3. Encoder hidden states are pooled to produce a fixed-size embedding
    4. Metadata from the dataset xlsx is attached for visualization

Usage:
    # Extract embeddings
    python gen_plot.py extract --base_dir /path/to/ppgbp/DataFile --out_dir ./embeddings

    # Visualize with t-SNE
    python gen_plot.py plot --emb_dir ./embeddings --out_dir ./plots

    # LazyPredict: predict hypertension (or other target) from Chronos embeddings
    python gen_plot.py lazy_eval --emb_dir ./embeddings --target hypertension --out_csv ./lazy_hypertension.csv
"""

import os
import sys
from typing import Dict, Any, List, Tuple, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from datasets import Dataset, load_from_disk
from chronos import Chronos2Pipeline
from fire import Fire
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder

try:
    from lazypredict.Supervised import LazyClassifier, LazyRegressor
    HAS_LAZYPREDICT = True
except ImportError:
    HAS_LAZYPREDICT = False

# Add parent directories to path for importing PPGBPLoader
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from ppgbp_loader import PPGBPLoader

# Target column options for lazy_eval: classification vs regression
LAZY_EVAL_CLASSIFICATION_TARGETS = [
    "hypertension",
    "sex",
    "age_category",
    "height_category",
    "weight_category",
    "systolic_category",
    "diastolic_category",
    "heart_rate_category",
    "bmi_category",
]
LAZY_EVAL_REGRESSION_TARGETS = [
    "age",
    "height",
    "weight",
    "systolic",
    "diastolic",
    "heart_rate",
    "bmi",
]
LAZY_EVAL_ALL_TARGETS = LAZY_EVAL_CLASSIFICATION_TARGETS + LAZY_EVAL_REGRESSION_TARGETS


# =============================================================================
# Metadata Categorization Helpers
# =============================================================================

def _categorize_age(age: Optional[int]) -> Optional[str]:
    if age is None:
        return None
    if age <= 35:
        return "Young (21-35)"
    elif age <= 55:
        return "Middle (36-55)"
    elif age <= 75:
        return "Senior (56-75)"
    else:
        return "Old (76+)"


def _categorize_height(height: Optional[int]) -> Optional[str]:
    if height is None:
        return None
    if height <= 157:
        return "Short (145-157)"
    elif height <= 170:
        return "Below Avg (158-170)"
    elif height <= 183:
        return "Above Avg (171-183)"
    else:
        return "Tall (184+)"


def _categorize_weight(weight: Optional[int]) -> Optional[str]:
    if weight is None:
        return None
    if weight <= 52:
        return "Light (36-52)"
    elif weight <= 69:
        return "Below Avg (53-69)"
    elif weight <= 86:
        return "Above Avg (70-86)"
    else:
        return "Heavy (87+)"


def _categorize_systolic(sbp: Optional[int]) -> Optional[str]:
    if sbp is None:
        return None
    if sbp < 90:
        return "Low (<90)"
    elif sbp < 120:
        return "Normal (90-119)"
    elif sbp < 140:
        return "Elevated (120-139)"
    else:
        return "High (140+)"


def _categorize_diastolic(dbp: Optional[int]) -> Optional[str]:
    if dbp is None:
        return None
    if dbp < 60:
        return "Low (<60)"
    elif dbp < 80:
        return "Normal (60-79)"
    elif dbp < 90:
        return "Elevated (80-89)"
    else:
        return "High (90+)"


def _categorize_heart_rate(hr: Optional[int]) -> Optional[str]:
    if hr is None:
        return None
    if hr <= 65:
        return "Low (52-65)"
    elif hr <= 79:
        return "Normal (66-79)"
    elif hr <= 93:
        return "Elevated (80-93)"
    else:
        return "High (94+)"


def _categorize_bmi(bmi: Optional[float]) -> Optional[str]:
    if bmi is None:
        return None
    if bmi < 18.5:
        return "Underweight (<18.5)"
    elif bmi < 25:
        return "Normal (18.5-24.9)"
    elif bmi < 30:
        return "Overweight (25-29.9)"
    else:
        return "Obese (30+)"


def _safe_int(val, sentinel=0):
    """Convert a value to int, returning None if it equals the sentinel or is invalid."""
    if val is None or (isinstance(val, float) and pd.isna(val)):
        return None
    try:
        v = int(val)
        return None if v == sentinel else v
    except (ValueError, TypeError):
        return None


def _safe_float(val, sentinel=0.0):
    """Convert a value to float, returning None if it equals the sentinel or is invalid."""
    if val is None or (isinstance(val, float) and pd.isna(val)):
        return None
    try:
        v = float(val)
        return None if v == sentinel else v
    except (ValueError, TypeError):
        return None


def _extract_metadata_fields(metadata: Dict) -> Dict[str, Any]:
    """
    Extract and sanitize metadata fields from a PPGBPLoader metadata dict.

    PPGBPLoader renames some columns (sex, age, sysbp, diasbp, hr, bmi) but
    leaves others with original names (Height(cm), Weight(kg), Hypertension).
    fillna(0) is applied, so 0 may represent missing values for some fields.

    Returns a dict with standardized keys.
    """
    # Sex: "M"/"F" string, or 0 if missing (from fillna)
    sex = metadata.get("sex", None)
    if sex is not None and sex != 0 and str(sex).strip() in ("M", "F"):
        sex = str(sex).strip()
    else:
        sex = None

    age = _safe_int(metadata.get("age", None))
    height = _safe_int(metadata.get("Height(cm)", None))
    weight = _safe_int(metadata.get("Weight(kg)", None))
    sysbp = _safe_int(metadata.get("sysbp", None))
    diasbp = _safe_int(metadata.get("diasbp", None))
    hr = _safe_int(metadata.get("hr", None))
    bmi = _safe_float(metadata.get("bmi", None))

    # Hypertension: may be string (e.g. "Stage 2 hypertension") or numeric; keep as string
    hypertension = metadata.get("Hypertension", None)
    if hypertension is not None and not (isinstance(hypertension, float) and pd.isna(hypertension)):
        hypertension = str(hypertension).strip()
    else:
        hypertension = None

    return {
        "sex": sex,
        "age": age,
        "height": height,
        "weight": weight,
        "sysbp": sysbp,
        "diasbp": diasbp,
        "hr": hr,
        "bmi": bmi,
        "hypertension": hypertension,
        "age_category": _categorize_age(age),
        "height_category": _categorize_height(height),
        "weight_category": _categorize_weight(weight),
        "systolic_category": _categorize_systolic(sysbp),
        "diastolic_category": _categorize_diastolic(diasbp),
        "heart_rate_category": _categorize_heart_rate(hr),
        "bmi_category": _categorize_bmi(bmi),
    }


# =============================================================================
# Embedding Extractor Class
# =============================================================================

class Chronos_2_Embedder:
    """
    Extracts fixed-dimensional embeddings from univariate PPG signals using Chronos-2.

    Architecture:
        Input PPG Signal -> Patching -> Encoder (6 layers) -> Hidden States -> Pooling -> Embedding

    For univariate PPG:
        Input shape: (1, 1, L) -- batch=1, variates=1, length=L
        Output shape: (512,) -- d_model=512 after pooling

    Pooling Strategies:
        - mean: Average across all patch tokens (excluding special tokens)
        - cls: Use the [REG] token embedding (similar to BERT's [CLS])
    """

    def __init__(
        self,
        model_name: str = "amazon/chronos-2",
        device_map: Optional[str] = None,
        pooling_strategy: str = "mean",
    ):
        self.pooling_strategy = pooling_strategy

        if device_map is None:
            device_map = "cuda" if torch.cuda.is_available() else "cpu"
        self.device_map = device_map
        self.device = torch.device(
            "cuda" if device_map.startswith("cuda") and torch.cuda.is_available() else "cpu"
        )

        print(f"[INFO] Loading Chronos-2 model: {model_name}")
        print(f"[INFO] Device: {device_map}")
        print(f"[INFO] Pooling strategy: {pooling_strategy}")
        self.pipeline: Chronos2Pipeline = Chronos2Pipeline.from_pretrained(
            model_name,
            device_map=device_map,
        )

    @torch.no_grad()
    def compute_embedding(self, signal: np.ndarray) -> np.ndarray:
        """
        Compute embedding for a single univariate PPG signal.

        Args:
            signal: 1D numpy array of PPG signal values (variable length OK)

        Returns:
            Embedding vector of shape (d_model,) = (512,)
        """
        # Reshape to (B=1, V=1, L) for Chronos-2 input
        x_input = signal.reshape(1, 1, -1).astype(np.float32)

        # Get encoder hidden states
        # Returns: list of tensors (one per batch item), each (V, num_patches+2, d_model)
        embeddings_list, _ = self.pipeline.embed(x_input)
        emb = embeddings_list[0]  # (1, num_patches+2, 512) for V=1

        if self.pooling_strategy == "cls":
            # Use the [REG] token (first token)
            pooled = emb[0, 0, :]  # (d_model,)
        else:  # "mean" pooling (default)
            # Average across all patch tokens, skip [REG] (first) and masked patch (last)
            pooled = emb[0, 1:-1, :].mean(dim=0)  # (d_model,)

        return pooled.cpu().numpy().astype(np.float32)

    def extract_embeddings(
        self,
        base_dir: str,
        split: str = "all",
    ) -> Dataset:
        """
        Extract embeddings from all subjects/segments in the PPG-BP dataset.

        Each subject has 3 PPG segments. Each segment is embedded independently
        via Chronos-2, resulting in 3 data points per subject sharing the same metadata.

        Args:
            base_dir: Root directory of PPG-BP dataset (containing 0_subject/ and .xlsx)
            split: Data split ('train', 'test', 'all')

        Returns:
            HuggingFace Dataset with columns:
            - subject_id, segment_id (identifiers)
            - split ('train' or 'test', from the loader's subject-level split)
            - sex, age, height, weight, systolic, diastolic, heart_rate, bmi, hypertension
            - age_category, height_category, weight_category, systolic_category, etc.
            - embedding (512-dim vector)
        """
        loader = PPGBPLoader(base_dir=base_dir, split="all")
        n_subjects = len(loader)
        train_id_set = set(int(x) for x in loader.train_ids)
        test_id_set = set(int(x) for x in loader.test_ids)
        print(f"[INFO] Loaded {n_subjects} subjects (train={len(train_id_set)}, test={len(test_id_set)})")

        # Accumulator lists
        all_embeddings = []
        all_subject_ids = []
        all_segment_ids = []
        all_splits = []
        all_sexes = []
        all_ages = []
        all_heights = []
        all_weights = []
        all_sysbps = []
        all_diasbps = []
        all_hrs = []
        all_bmis = []
        all_hypertensions = []
        all_age_categories = []
        all_height_categories = []
        all_weight_categories = []
        all_systolic_categories = []
        all_diastolic_categories = []
        all_hr_categories = []
        all_bmi_categories = []

        total_segments = 0

        for subj_idx in range(n_subjects):
            metadata, signals = loader[subj_idx]
            subject_id_int = int(metadata.get("subject_ID", 0))
            subject_id = str(subject_id_int)
            split_label = "train" if subject_id_int in train_id_set else "test"

            # Extract and categorize metadata once per subject
            fields = _extract_metadata_fields(metadata)

            # Process each PPG segment
            for seg_idx, signal in enumerate(signals):
                emb = self.compute_embedding(signal)

                all_embeddings.append(emb.tolist())
                all_subject_ids.append(subject_id)
                all_segment_ids.append(seg_idx + 1)
                all_splits.append(split_label)
                all_sexes.append(fields["sex"])
                all_ages.append(fields["age"])
                all_heights.append(fields["height"])
                all_weights.append(fields["weight"])
                all_sysbps.append(fields["sysbp"])
                all_diasbps.append(fields["diasbp"])
                all_hrs.append(fields["hr"])
                all_bmis.append(fields["bmi"])
                all_hypertensions.append(fields["hypertension"])
                all_age_categories.append(fields["age_category"])
                all_height_categories.append(fields["height_category"])
                all_weight_categories.append(fields["weight_category"])
                all_systolic_categories.append(fields["systolic_category"])
                all_diastolic_categories.append(fields["diastolic_category"])
                all_hr_categories.append(fields["heart_rate_category"])
                all_bmi_categories.append(fields["bmi_category"])

                total_segments += 1

            if (subj_idx + 1) % 20 == 0 or subj_idx == n_subjects - 1:
                print(
                    f"[INFO] Processed {subj_idx + 1}/{n_subjects} subjects "
                    f"({total_segments} segments)"
                )

# Create HuggingFace Dataset
|
||||
result_dataset = Dataset.from_dict({
|
||||
"subject_id": all_subject_ids,
|
||||
"segment_id": all_segment_ids,
|
||||
"split": all_splits,
|
||||
"sex": all_sexes,
|
||||
"age": all_ages,
|
||||
"age_category": all_age_categories,
|
||||
"height": all_heights,
|
||||
"height_category": all_height_categories,
|
||||
"weight": all_weights,
|
||||
"weight_category": all_weight_categories,
|
||||
"systolic": all_sysbps,
|
||||
"systolic_category": all_systolic_categories,
|
||||
"diastolic": all_diasbps,
|
||||
"diastolic_category": all_diastolic_categories,
|
||||
"heart_rate": all_hrs,
|
||||
"heart_rate_category": all_hr_categories,
|
||||
"bmi": all_bmis,
|
||||
"bmi_category": all_bmi_categories,
|
||||
"hypertension": all_hypertensions,
|
||||
"embedding": all_embeddings,
|
||||
})
|
||||
|
||||
print(f"[INFO] Total segments: {len(result_dataset)}")
|
||||
if len(result_dataset) > 0:
|
||||
print(f"[INFO] Embedding dim: {len(result_dataset[0]['embedding'])}")
|
||||
|
||||
return result_dataset

    def save_embeddings(self, dataset: Dataset, output_dir: str) -> None:
        """Save embeddings dataset to disk in HuggingFace format."""
        os.makedirs(output_dir, exist_ok=True)
        dataset.save_to_disk(output_dir)
        print(f"[DONE] Saved embeddings dataset: {output_dir}")
        if len(dataset) > 0:
            print(
                f"[DONE] Total samples: {len(dataset)}, "
                f"Embedding dim: {len(dataset[0]['embedding'])}"
            )

    @staticmethod
    def load_embeddings(embedding_dir: str) -> Dataset:
        """Load saved embeddings dataset from disk."""
        dataset = load_from_disk(embedding_dir)
        print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
        return dataset


# =============================================================================
# Visualization Functions
# =============================================================================

def reduce_to_2d_tsne(
    embeddings: np.ndarray,
    perplexity: float = 30.0,
) -> np.ndarray:
    """
    Reduce high-dimensional embeddings to 2D using t-SNE.

    Args:
        embeddings: Array of shape (num_samples, embedding_dim)
        perplexity: t-SNE perplexity (typically 5-50). Rule of thumb: ~ sqrt(N)

    Returns:
        2D coordinates of shape (num_samples, 2)
    """
    print(f"[INFO] Running t-SNE with perplexity={perplexity}...")

    tsne = TSNE(
        n_components=2,
        random_state=0,
        perplexity=perplexity,
        max_iter=1000,
        init="random",
        learning_rate="auto",
    )

    return tsne.fit_transform(embeddings)
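
The `~ sqrt(N)` rule of thumb from the docstring above can be turned into a small helper. This is only a sketch: `suggest_perplexity` and its `low`/`high` bounds are hypothetical names that do not exist in this script. It also keeps the value below the sample count, on the assumption that the installed scikit-learn rejects a perplexity that is not smaller than the number of samples (recent versions do), which can matter when `plot` is restricted to a handful of subjects.

```python
import math


def suggest_perplexity(n_samples: int, low: float = 5.0, high: float = 50.0) -> float:
    """Hypothetical helper: pick a t-SNE perplexity near sqrt(N)."""
    if n_samples < 2:
        raise ValueError("t-SNE needs at least 2 samples")
    # Rule of thumb from the docstring: perplexity ~ sqrt(N), kept within [low, high].
    guess = min(max(math.sqrt(n_samples), low), high)
    # Keep the value strictly below the number of samples.
    return float(min(guess, n_samples - 1))
```

A caller could pass `suggest_perplexity(len(embeddings))` to `reduce_to_2d_tsne` instead of the fixed default of 30.0.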


def create_scatter_plot_continuous(
    coordinates: np.ndarray,
    values: np.ndarray,
    title: str,
    output_path: str,
    label: str = "Value",
    cmap: str = "viridis",
) -> None:
    """
    Create a 2D scatter plot colored by a continuous variable.

    Args:
        coordinates: 2D array of shape (num_points, 2)
        values: Values for each point (may contain None)
        title: Plot title
        output_path: File path to save the plot (PDF)
        label: Colorbar label
        cmap: Matplotlib colormap name
    """
    values_float = np.array([float(v) if v is not None else np.nan for v in values])
    valid_mask = ~np.isnan(values_float)
    valid_coords = coordinates[valid_mask]
    valid_values = values_float[valid_mask]

    if len(valid_values) == 0:
        print(f"[WARN] No valid {label} data found. Skipping plot: {output_path}")
        return

    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(
        valid_coords[:, 0],
        valid_coords[:, 1],
        c=valid_values,
        cmap=cmap,
        alpha=0.7,
        s=50,
    )
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label(label)
    ax.set_title(title)
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")
    plt.tight_layout()
    plt.savefig(output_path, format="pdf", bbox_inches="tight")
    plt.close()

    print(f"[DONE] Saved plot: {output_path}")
    print(f"[INFO] {label} range: {valid_values.min():.1f} - {valid_values.max():.1f}")


def create_scatter_plot_categorical(
    coordinates: np.ndarray,
    categories: np.ndarray,
    title: str,
    output_path: str,
    label: str = "Category",
) -> None:
    """
    Create a 2D scatter plot colored by a categorical variable.

    Args:
        coordinates: 2D array of shape (num_points, 2)
        categories: Category values for each point (may contain None)
        title: Plot title
        output_path: File path to save the plot (PDF)
        label: Legend title
    """
    valid_mask = np.array([c is not None and str(c) != "None" for c in categories])
    valid_coords = coordinates[valid_mask]
    valid_cats = np.array(categories)[valid_mask]

    if len(valid_cats) == 0:
        print(f"[WARN] No valid {label} data found. Skipping plot: {output_path}")
        return

    unique_cats = sorted(set(str(c) for c in valid_cats))
    colormap = plt.cm.tab10 if len(unique_cats) <= 10 else plt.cm.tab20
    colors = {cat: colormap(i % 20) for i, cat in enumerate(unique_cats)}

    fig, ax = plt.subplots(figsize=(10, 8))

    for cat in unique_cats:
        mask = np.array([str(c) == cat for c in valid_cats])
        ax.scatter(
            valid_coords[mask, 0],
            valid_coords[mask, 1],
            c=[colors[cat]],
            label=cat,
            alpha=0.7,
            s=50,
        )

    ax.set_title(title)
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")
    ax.legend(title=label, loc="best", markerscale=2)
    plt.tight_layout()
    plt.savefig(output_path, format="pdf", bbox_inches="tight")
    plt.close()

    print(f"[DONE] Saved plot: {output_path}")
    print(f"[INFO] {label} categories: {unique_cats}")


# =============================================================================
# LazyPredict multi-seed helper
# =============================================================================

def _run_lazy_one_split(
    X: np.ndarray,
    y_raw: List[Any],
    subject_ids_valid: List[int],
    train_id_set: set,
    test_id_set: set,
    is_classification: bool,
    verbose: int = 0,
    ignore_warnings: bool = True,
) -> pd.DataFrame:
    """Run LazyClassifier or LazyRegressor for one train/test split; return metrics DataFrame."""
    train_mask = np.array([int(s) in train_id_set for s in subject_ids_valid])
    test_mask = np.array([int(s) in test_id_set for s in subject_ids_valid])
    if train_mask.sum() == 0 or test_mask.sum() == 0:
        raise ValueError("Empty train or test set for this split.")
    X_train, X_test = X[train_mask], X[test_mask]
    y_train_raw = [y_raw[i] for i in range(len(y_raw)) if train_mask[i]]
    y_test_raw = [y_raw[i] for i in range(len(y_raw)) if test_mask[i]]

    if is_classification:
        le = LabelEncoder()
        y_train = le.fit_transform([str(v) for v in y_train_raw])
        y_test = le.transform([str(v) for v in y_test_raw])
        clf = LazyClassifier(verbose=verbose, ignore_warnings=ignore_warnings, custom_metric=None)
        models, _ = clf.fit(X_train, X_test, y_train, y_test)
    else:
        y_train = np.array([float(v) for v in y_train_raw], dtype=np.float64)
        y_test = np.array([float(v) for v in y_test_raw], dtype=np.float64)
        reg = LazyRegressor(verbose=verbose, ignore_warnings=ignore_warnings, custom_metric=None)
        models, _ = reg.fit(X_train, X_test, y_train, y_test)
    # Ensure model name is index for aggregation across seeds
    if "Model" in models.columns:
        models = models.set_index("Model")
    return models
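
`_run_lazy_one_split` assigns each segment to train or test by its subject ID, so all segments of one subject land on the same side. The masking step can be sketched in plain Python as below. The function name `subject_split_masks` and the explicit overlap check are illustrative additions under the assumption that the two ID sets are meant to be disjoint; neither is part of this file.

```python
from typing import List, Sequence, Set, Tuple


def subject_split_masks(
    subject_ids: Sequence[str],
    train_ids: Set[int],
    test_ids: Set[int],
) -> Tuple[List[bool], List[bool]]:
    """Hypothetical helper: per-segment train/test masks from subject-level ID sets."""
    overlap = train_ids & test_ids
    if overlap:
        # A subject on both sides would leak its segments into the test set.
        raise ValueError(f"Subjects in both train and test: {sorted(overlap)}")
    train_mask = [int(s) in train_ids for s in subject_ids]
    test_mask = [int(s) in test_ids for s in subject_ids]
    return train_mask, test_mask
```

For example, with segments from subjects `["2", "2", "6"]`, `train_ids={2}` and `test_ids={6}`, the masks are `[True, True, False]` and `[False, False, True]`.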


# =============================================================================
# Command Line Interface
# =============================================================================

class CLI:
    """
    CLI for Chronos-2 PPG-BP embedding extraction and t-SNE visualization.

    Commands:
        extract              - Compute Chronos-2 embeddings from PPG signals
        plot                 - Generate t-SNE scatter plots colored by metadata
        lazy_eval            - LazyPredict: X=embedding, y=target (default hypertension); single split
        lazy_eval_multi_seed - LazyPredict over 10 seeds for all targets; save {target}.csv (mean +/- std)
    """

    def extract(
        self,
        base_dir: str,
        out_dir: str,
        model: str = "amazon/chronos-2",
        pooling: str = "mean",
        split: str = "all",
    ) -> None:
        """
        Extract Chronos-2 embeddings from PPG-BP signals.

        Args:
            base_dir: PPG-BP dataset root (contains 0_subject/ folder and .xlsx)
            out_dir: Output directory for the HuggingFace embeddings dataset
            model: Chronos-2 model name (default: amazon/chronos-2)
            pooling: Pooling strategy - 'mean' or 'cls' (default: mean)
            split: Data split - 'train', 'val', 'test', 'all' (default: all)
        """
        if pooling not in ("mean", "cls"):
            raise ValueError(f"Invalid pooling: {pooling}. Use 'mean' or 'cls'.")

        embedder = Chronos_2_Embedder(
            model_name=model,
            pooling_strategy=pooling,
        )
        # Forward the CLI argument instead of a hard-coded "all"
        dataset = embedder.extract_embeddings(base_dir, split)
        embedder.save_embeddings(dataset, out_dir)

    def plot(
        self,
        emb_dir: str,
        out_dir: str,
        perplexity: float = 30.0,
        subjects: str = None,
        num_subjects: int = 0,
    ) -> None:
        """
        Visualize embeddings with t-SNE, colored by each metadata column.
        Generates plots for (1) train, (2) test, (3) train+test, each with all categories.

        Args:
            emb_dir: Directory containing the saved HuggingFace embeddings dataset
            out_dir: Output directory for PDF plots (subdirs: train/, test/, all/)
            perplexity: t-SNE perplexity parameter (default: 30.0)
            subjects: Comma-separated subject IDs to include (e.g., '2,6,8')
            num_subjects: Include only first N subjects, 0 = all (default: 0)
        """
        # Load saved embeddings (must have "split" column)
        full_dataset = Chronos_2_Embedder.load_embeddings(emb_dir)

        if "split" not in full_dataset.column_names:
            raise ValueError(
                "Embeddings dataset missing 'split' column. Re-run extract with updated PPGBPLoader."
            )

        # Optional subject filtering
        if subjects:
            # Fire may parse '2,6,8' as a tuple of ints (or '2' as an int), so normalize to strings
            if isinstance(subjects, (list, tuple)):
                subject_list = [str(s).strip() for s in subjects]
            else:
                subject_list = [s.strip() for s in str(subjects).split(",")]
            full_dataset = full_dataset.filter(lambda x: x["subject_id"] in subject_list)
            print(f"[INFO] Filtered to subjects: {subject_list}")
        elif num_subjects > 0:
            # subject_id is stored as a string; sort numerically so "10" does not precede "2"
            all_subjects = sorted(set(full_dataset["subject_id"]), key=int)
            selected = all_subjects[:num_subjects]
            full_dataset = full_dataset.filter(lambda x: x["subject_id"] in selected)
            print(f"[INFO] Selected first {num_subjects} subjects: {selected}")

        continuous_vars = [
            ("age", "Age (years)", "viridis"),
            ("height", "Height (cm)", "plasma"),
            ("weight", "Weight (kg)", "cividis"),
            ("systolic", "Systolic BP (mmHg)", "Reds"),
            ("diastolic", "Diastolic BP (mmHg)", "Blues"),
            ("heart_rate", "Heart Rate (bpm)", "Oranges"),
            ("bmi", "BMI (kg/m²)", "Greens"),
        ]
        categorical_vars = [
            ("subject_id", "Subject ID"),
            ("sex", "Sex"),
            ("hypertension", "Hypertension"),
            ("age_category", "Age Category"),
            ("height_category", "Height Category"),
            ("weight_category", "Weight Category"),
            ("systolic_category", "Systolic BP Category"),
            ("diastolic_category", "Diastolic BP Category"),
            ("heart_rate_category", "Heart Rate Category"),
            ("bmi_category", "BMI Category"),
        ]

        for split_name in ["train", "test", "all"]:
            if split_name == "all":
                dataset = full_dataset
            else:
                dataset = full_dataset.filter(lambda x: x["split"] == split_name)
            n = len(dataset)
            if n == 0:
                print(f"[WARN] No samples for split '{split_name}', skipping.")
                continue
            split_out_dir = os.path.join(out_dir, split_name)
            os.makedirs(split_out_dir, exist_ok=True)
            print(f"[INFO] t-SNE for split '{split_name}': {n} samples -> {split_out_dir}")

            embeddings = np.array(dataset["embedding"])
            coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)

            for var_name, var_label, cmap in continuous_vars:
                if var_name in dataset.column_names:
                    values = dataset[var_name]
                    create_scatter_plot_continuous(
                        coordinates_2d,
                        values,
                        f"Chronos-2 t-SNE [{split_name}] (Colored by {var_label})",
                        os.path.join(split_out_dir, f"tsne_by_{var_name}.pdf"),
                        label=var_label,
                        cmap=cmap,
                    )

            for var_name, var_label in categorical_vars:
                if var_name in dataset.column_names:
                    categories = dataset[var_name]
                    create_scatter_plot_categorical(
                        coordinates_2d,
                        categories,
                        f"Chronos-2 t-SNE [{split_name}] (Colored by {var_label})",
                        os.path.join(split_out_dir, f"tsne_by_{var_name}.pdf"),
                        label=var_label,
                    )

        print(f"\n[DONE] All plots saved under: {out_dir} (train/, test/, all/)")

    def lazy_eval(
        self,
        emb_dir: str,
        target: str = "hypertension",
        out_csv: str = None,
        verbose: int = 0,
        ignore_warnings: bool = True,
    ) -> None:
        """
        Run LazyPredict (many classifiers or regressors) on Chronos embeddings (X) vs target (y).
        Uses the same train/test split as PPGBPLoader (split column in the embeddings dataset).

        Args:
            emb_dir: Directory containing the saved HuggingFace embeddings dataset
            target: Target column for y. Default 'hypertension'.
                Classification: hypertension, sex, age_category, height_category,
                weight_category, systolic_category, diastolic_category,
                heart_rate_category, bmi_category
                Regression: age, height, weight, systolic, diastolic, heart_rate, bmi
            out_csv: If set, save the models comparison DataFrame to this path (e.g. lazy_hypertension.csv)
            verbose: LazyPredict verbose (0 or 1)
            ignore_warnings: LazyPredict ignore_warnings
        """
        if not HAS_LAZYPREDICT:
            raise ImportError("lazypredict is required. Install with: pip install lazypredict")

        dataset = Chronos_2_Embedder.load_embeddings(emb_dir)
        if "split" not in dataset.column_names:
            raise ValueError(
                "Embeddings dataset missing 'split' column. Re-run extract with updated PPGBPLoader."
            )
        if target not in dataset.column_names:
            raise ValueError(
                f"Target '{target}' not in dataset. Available: {dataset.column_names}"
            )

        X = np.array(dataset["embedding"], dtype=np.float64)
        y_raw = list(dataset[target])

        # Drop rows where target is missing
        valid = [
            i for i, v in enumerate(y_raw)
            if v is not None and not (isinstance(v, float) and pd.isna(v)) and str(v).strip() != ""
        ]
        if len(valid) == 0:
            raise ValueError(f"No valid non-missing values for target '{target}'.")
        X = X[valid]
        y_raw = [y_raw[i] for i in valid]
        splits = [dataset["split"][i] for i in valid]

        train_mask = np.array([s == "train" for s in splits])
        test_mask = np.array([s == "test" for s in splits])
        X_train, X_test = X[train_mask], X[test_mask]
        y_train_raw = [y_raw[i] for i in range(len(y_raw)) if splits[i] == "train"]
        y_test_raw = [y_raw[i] for i in range(len(y_raw)) if splits[i] == "test"]

        # Known classification targets, and any target not listed as regression, are classified
        is_classification = (
            target in LAZY_EVAL_CLASSIFICATION_TARGETS
            or target not in LAZY_EVAL_REGRESSION_TARGETS
        )

        if is_classification:
            le = LabelEncoder()
            y_train = le.fit_transform([str(v) for v in y_train_raw])
            y_test = le.transform([str(v) for v in y_test_raw])
            print(f"[INFO] LazyClassifier: X=embedding, y={target} (classes: {list(le.classes_)})")
            print(f"[INFO] Train: {len(y_train)}, Test: {len(y_test)}")
            clf = LazyClassifier(verbose=verbose, ignore_warnings=ignore_warnings, custom_metric=None)
            models, predictions = clf.fit(X_train, X_test, y_train, y_test)
        else:
            y_train = np.array([float(v) for v in y_train_raw], dtype=np.float64)
            y_test = np.array([float(v) for v in y_test_raw], dtype=np.float64)
            print(f"[INFO] LazyRegressor: X=embedding, y={target}")
            print(f"[INFO] Train: {len(y_train)}, Test: {len(y_test)}")
            reg = LazyRegressor(verbose=verbose, ignore_warnings=ignore_warnings, custom_metric=None)
            models, predictions = reg.fit(X_train, X_test, y_train, y_test)

        print("\n[RESULTS] Models (sorted by score):")
        print(models)
        if out_csv:
            d = os.path.dirname(out_csv)
            if d:
                os.makedirs(d, exist_ok=True)
            models.to_csv(out_csv)
            print(f"\n[DONE] Saved model comparison to {out_csv}")
        return None

    def lazy_eval_multi_seed(
        self,
        emb_dir: str,
        base_dir: str,
        out_dir: str = "./lazy_multi_seed",
        n_seeds: int = 10,
        verbose: int = 0,
        ignore_warnings: bool = True,
        targets: str = "all",
    ) -> None:
        """
        Run LazyPredict with n_seeds different train/test splits (from PPGBPLoader seeds).
        For each target (classification + regression), saves {target}.csv with mean and std
        for all metrics across seeds.

        This can run for a long time (10 seeds x 16 targets x many models). For a quick test
        use e.g. --n_seeds 2 --targets hypertension
        """
        if not HAS_LAZYPREDICT:
            raise ImportError("lazypredict is required. Install with: pip install lazypredict")

        dataset = Chronos_2_Embedder.load_embeddings(emb_dir)
        if "subject_id" not in dataset.column_names:
            raise ValueError("Embeddings dataset missing 'subject_id' column.")

        if targets == "all":
            target_list = LAZY_EVAL_ALL_TARGETS
        elif isinstance(targets, (list, tuple)):
            # Fire may parse a comma-separated value as a tuple
            target_list = [str(t).strip() for t in targets]
        else:
            target_list = [t.strip() for t in str(targets).split(",")]
        os.makedirs(out_dir, exist_ok=True)
        n_targets = len(target_list)
        print(f"[INFO] lazy_eval_multi_seed: {n_targets} targets, {n_seeds} seeds each (can take a long time).")

        for ti, target in enumerate(target_list):
            if target not in dataset.column_names:
                print(f"[WARN] Skipping target '{target}' (not in dataset).")
                continue
            print(f"[INFO] Target '{target}' ({ti + 1}/{n_targets}) ...")

            y_raw = list(dataset[target])
            valid = [
                i
                for i, v in enumerate(y_raw)
                if v is not None
                and not (isinstance(v, float) and pd.isna(v))
                and str(v).strip() != ""
            ]
            if len(valid) == 0:
                print(f"[WARN] Skipping target '{target}' (no valid values).")
                continue

            X = np.array(dataset["embedding"], dtype=np.float64)[valid]
            y_valid = [y_raw[i] for i in valid]
            subject_ids_valid = [dataset["subject_id"][i] for i in valid]

            # Known classification targets, and any target not listed as regression, are classified
            is_classification = (
                target in LAZY_EVAL_CLASSIFICATION_TARGETS
                or target not in LAZY_EVAL_REGRESSION_TARGETS
            )

            run_dfs = []
            for seed in range(n_seeds):
                print(f"[INFO] seed {seed + 1}/{n_seeds} ...", flush=True)
                loader = PPGBPLoader(base_dir=base_dir, split="all", seed=seed)
                train_id_set = set(int(x) for x in loader.train_ids)
                test_id_set = set(int(x) for x in loader.test_ids)
                try:
                    df = _run_lazy_one_split(
                        X,
                        y_valid,
                        subject_ids_valid,
                        train_id_set,
                        test_id_set,
                        is_classification,
                        verbose=verbose,
                        ignore_warnings=ignore_warnings,
                    )
                    run_dfs.append(df)
                except Exception as e:
                    print(f"[WARN] Seed {seed} for target '{target}' failed: {e}")
                    continue

            if len(run_dfs) == 0:
                print(f"[WARN] No successful runs for target '{target}', skipping.")
                continue

            combined = pd.concat(run_dfs, axis=0)
            metric_cols = [
                c for c in combined.columns
                if c != "Model" and np.issubdtype(combined[c].dtype, np.number)
            ]
            # Aggregate numeric metric columns only, so non-numeric columns cannot break mean()/std()
            agg_mean = combined[metric_cols].groupby(level=0).mean()
            agg_std = combined[metric_cols].groupby(level=0).std().fillna(0)

            out = pd.DataFrame(index=agg_mean.index)
            out.index.name = "Model"
            for col in metric_cols:
                out[col + "_mean"] = agg_mean[col].values
                out[col + "_std"] = agg_std[col].values
                m = agg_mean[col].round(4).astype(str)
                s = agg_std[col].round(4).astype(str)
                out[col + "_mean_pm_std"] = m + " +/- " + s
            out_path = os.path.join(out_dir, target + ".csv")
            out.to_csv(out_path)
            print(f"[DONE] {target}: {len(run_dfs)} seeds -> {out_path}")
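
The per-model `mean +/- std` strings written to `{target}.csv` can be illustrated without pandas. The sketch below is a simplified stand-in, not the aggregation code itself: `summarize_scores` is a hypothetical name, and it handles a single metric per model. It mirrors two details of the code above: the standard deviation is the sample standard deviation, and a model with only one successful seed gets a spread of 0 (the pandas version reaches this through `fillna(0)`).

```python
from statistics import mean, stdev
from typing import Dict, List


def summarize_scores(scores_by_model: Dict[str, List[float]]) -> Dict[str, str]:
    """Hypothetical helper: format one metric as 'mean +/- std' per model."""
    summary = {}
    for model, scores in scores_by_model.items():
        m = mean(scores)
        # Sample standard deviation is undefined for a single run; report 0.0 instead.
        s = stdev(scores) if len(scores) > 1 else 0.0
        summary[model] = f"{round(m, 4)} +/- {round(s, 4)}"
    return summary
```

For example, `summarize_scores({"A": [0.5, 1.0], "B": [0.9]})` gives `"0.75 +/- 0.3536"` for `A` and `"0.9 +/- 0.0"` for `B`.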


if __name__ == "__main__":
    Fire(CLI)
55
analysis/user_similarity/chronos2_ppgbp/run.sh
Executable file
@@ -0,0 +1,55 @@
#!/bin/bash
# =============================================================================
# Chronos-2 PPG-BP Embedding Extraction & t-SNE Visualization
# =============================================================================

set -euo pipefail

# ---- Configuration (edit these variables) ------------------------------------
BASE_DIR="/scratch/10608/aadharsh_aadhithya/repos/data/tsllm"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

MODEL="amazon/chronos-2"
POOLING="mean"
SPLIT="all"
PERPLEXITY=30.0

EMB_DIR="${BASE_DIR}/chronos2_embeddings"
PLOT_DIR="${BASE_DIR}/chronos2_plots"
# ------------------------------------------------------------------------------

echo "============================================="
echo "Chronos-2 PPG-BP Pipeline"
echo "============================================="
echo "Base dir: ${BASE_DIR}"
echo "Model: ${MODEL}"
echo "Pooling: ${POOLING}"
echo "Split: ${SPLIT}"
echo "Embeddings dir: ${EMB_DIR}"
echo "Plots dir: ${PLOT_DIR}"
echo "============================================="

# Step 1: Extract Chronos-2 embeddings from PPG signals
echo ""
echo "[Step 1/2] Extracting Chronos-2 embeddings..."
python "${SCRIPT_DIR}/gen_plot.py" extract \
    --base_dir "${BASE_DIR}" \
    --out_dir "${EMB_DIR}" \
    --model "${MODEL}" \
    --pooling "${POOLING}" \
    --split "${SPLIT}"

# Step 2: Generate t-SNE plots colored by metadata
echo ""
echo "[Step 2/2] Generating t-SNE plots..."
python "${SCRIPT_DIR}/gen_plot.py" plot \
    --emb_dir "${EMB_DIR}" \
    --out_dir "${PLOT_DIR}" \
    --perplexity "${PERPLEXITY}"

echo ""
echo "============================================="
echo "Done! Outputs:"
echo " Embeddings: ${EMB_DIR}"
echo " Plots: ${PLOT_DIR}"
echo "============================================="
@@ -60,7 +60,7 @@ from sentence_transformers import SentenceTransformer
|
||||
# Constants
|
||||
# =============================================================================
|
||||
|
||||
SUBJECT_PATH = "LLM_Health/tsllm_personalization_icl/analysis/user_similarity/sbert_metadata_ppgbp/PPGBP_metadata.xlsx"
|
||||
SUBJECT_PATH = "/scratch/10608/aadharsh_aadhithya/repos/data/tsllm/PPGBP_metadata.xlsx"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -149,6 +149,45 @@ class SBERT_Metadata:
|
||||
"""
|
||||
self.model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
|
||||
    def discover_user_files(self, data_root: str) -> List[Tuple[str, str, str]]:
        """
        Discover all user data files under the data root.

        Expected structure: data_root/0_subject/{user_id}_{session_id}.txt

        Args:
            data_root: Root directory containing data files

        Returns:
            List of (user_id, session_id, file_path) tuples
        """
        discovered_files = []

        # Data is in 0_subject folder
        subject_root = os.path.join(data_root, "0_subject")

        print(f"[DEBUG] Looking for user files in: {subject_root}")
        print(f"[DEBUG] Directory exists: {os.path.exists(subject_root)}")

        # Use glob to find all .txt files
        for file_path in sorted(glob(os.path.join(subject_root, "*.txt"))):
            # Extract user_id and session_id from filename
            # Pattern: {user_id}_{session_id}.txt (e.g., 100_1.txt)
            filename = os.path.basename(file_path)
            name_without_ext = os.path.splitext(filename)[0]

            parts = name_without_ext.split("_")
            if len(parts) >= 2:
                user_id = parts[0]
                session_id = parts[1]
                discovered_files.append((user_id, session_id, file_path))

        print(f"[DEBUG] Found {len(discovered_files)} data files")
        if discovered_files:
            print(f"[DEBUG] First few files: {[(u, s) for u, s, _ in discovered_files[:5]]}")

        return discovered_files
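
The `{user_id}_{session_id}.txt` naming pattern that `discover_user_files` relies on can be checked in isolation. The sketch below is a standalone illustration: `parse_user_session` is a hypothetical helper, not a function in this module, and the example path is made up.

```python
import os
from typing import Optional, Tuple


def parse_user_session(file_path: str) -> Optional[Tuple[str, str]]:
    """Hypothetical helper: split '{user_id}_{session_id}.txt' into its two IDs."""
    stem = os.path.splitext(os.path.basename(file_path))[0]
    parts = stem.split("_")
    if len(parts) >= 2:
        # Extra underscore-separated parts are ignored, as in discover_user_files.
        return parts[0], parts[1]
    # Names without an underscore do not match the pattern.
    return None
```

`parse_user_session("/data/0_subject/100_1.txt")` returns `("100", "1")`, while a file such as `notes.txt` returns `None` and would be skipped.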
|
||||
|
||||
def textualize_metadata_ppg_bp(self,
|
||||
sex: Optional[str],
|
||||
age: Optional[int],
|
||||
@@ -236,9 +275,9 @@ class SBERT_Metadata:
|
||||
|
||||
# Create sentence from metadata
|
||||
if age is not None:
|
||||
sentence = f"This is the information of the user, sex: {sex_text}, age: {age_text}, height: {height_text}, weight: {weight_text}, sbp: {sbp_text}, dbp: {dbp_text}, heart rate: {heart_rate_text} bpm, bmi: {bmi_text}, hypertension: {hypertension_text}."
|
||||
sentence = f"This is the information of the user, sex: {sex_text}, age: {age_text}, height: {height_text}, weight: {weight_text}, systolic blood pressure: {sbp_text}, diastolic blood pressure: {dbp_text}, heart rate: {heart_rate_text} bpm, bmi: {bmi_text}, hypertension: {hypertension_text}."
|
||||
else:
|
||||
sentence = f"This is the information of the user, sex: {sex_text}, age: unknown, height: {height_text}, weight: {weight_text}, sbp: {sbp_text}, dbp: {dbp_text}, heart rate: {heart_rate_text} bpm, bmi: {bmi_text}, hypertension: {hypertension_text}."
|
||||
sentence = f"This is the information of the user, sex: {sex_text}, age: unknown, height: {height_text}, weight: {weight_text}, systolic blood pressure: {sbp_text}, diastolic blood pressure: {dbp_text}, heart rate: {heart_rate_text} bpm, bmi: {bmi_text}, hypertension: {hypertension_text}."
|
||||
|
||||
return sentence
|
||||
|
||||
@@ -321,11 +360,15 @@ class SBERT_Metadata:
|
||||
print(f"[INFO] Loading subject metadata from: {subject_path}")
|
||||
subject_metadata = load_subject_metadata(subject_path)
|
||||
print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")
|
||||
print(f"[DEBUG] Sample metadata keys: {list(subject_metadata.keys())[:10]}")
|
||||
|
||||
# Discover user data files
|
||||
user_files = self.discover_user_files(data_root)
|
||||
print(f"[INFO] Discovered {len(user_files)} data files")
|
||||
|
||||
all_embeddings = []
|
||||
all_user_ids = []
|
||||
# all_session_ids = []
|
||||
all_session_ids = []
|
||||
all_idxs = []
|
||||
all_labels = []
|
||||
all_sexes = []
|
||||
@@ -337,26 +380,24 @@ class SBERT_Metadata:
|
||||
all_heart_rates = []
|
||||
all_bmis = []
|
||||
all_hypertensions = []
|
||||
all_age_categories = []
|
||||
all_height_categories = []
|
||||
all_weight_categories = []
|
||||
all_systolic_categories = []
|
||||
all_diastolic_categories = []
|
||||
all_heart_rate_categories = []
|
||||
all_bmi_categories = []
|
||||
all_file_paths = []
|
||||
|
||||
# Collect metadata for all samples first
|
||||
for user_id, session_path in session_paths:
|
||||
# Load HuggingFace dataset from disk
|
||||
dataset = load_from_disk(session_path)
|
||||
# Shuffle dataset for randomness
|
||||
dataset = dataset.shuffle(seed=0)
|
||||
# Filter by sleep stage label if specified
|
||||
if label is not None and label != "all":
|
||||
dataset = dataset.filter(lambda x: x["label"] == label)
|
||||
num_samples = len(dataset)
|
||||
|
||||
if num_samples == 0:
|
||||
continue
|
||||
|
||||
print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")
|
||||
# Collect metadata for all files
|
||||
matched_count = 0
|
||||
unmatched_users = set()
|
||||
for idx, (user_id, session_id, file_path) in enumerate(user_files):
|
||||
|
||||
# Get metadata for this user
|
||||
# Convert user_id to string format that matches metadata keys
|
||||
user_id_str = str(int(user_id))
|
||||
if idx == 0:
|
||||
print(f"[DEBUG] First file user_id: '{user_id}' -> lookup key: '{user_id_str}'")
|
||||
try:
|
||||
sex = subject_metadata[user_id_str]["sex"]
|
||||
age = subject_metadata[user_id_str]["age"]
|
||||
@@ -364,11 +405,11 @@ class SBERT_Metadata:
|
||||
weight = subject_metadata[user_id_str]["weight"]
|
||||
sbp = subject_metadata[user_id_str]["sbp"]
|
||||
dbp = subject_metadata[user_id_str]["dbp"]
|
||||
heart_rate = subject_metadata[user_id_str]["heart_rate"]
|
||||
heart_rate = subject_metadata[user_id_str]["hr"]
|
||||
bmi = subject_metadata[user_id_str]["bmi"]
|
||||
hypertension = subject_metadata[user_id_str]["hypertension"]
|
||||
|
||||
|
||||
matched_count += 1
|
||||
except KeyError:
|
||||
sex = None
|
||||
age = None
|
||||
@@ -379,16 +420,126 @@ class SBERT_Metadata:
|
||||
heart_rate = None
|
||||
bmi = None
|
||||
hypertension = None
|
||||
print(f"[WARN] No metadata found for user_id: {user_id_str}")
|
||||
unmatched_users.add(user_id_str)
|
||||
|
||||
# Collect all samples for this user/session
|
||||
for i in range(num_samples):
|
||||
all_user_ids.append(str(dataset["user_id"][i]))
|
||||
all_session_ids.append(str(dataset["session_id"][i]))
|
||||
all_idxs.append(int(dataset["idx"][i]))
|
||||
all_labels.append(str(dataset["label"][i]))
|
||||
all_ages.append(age)
|
||||
all_sexes.append(sex)
|
||||
# Collect data for this file
|
||||
all_user_ids.append(user_id)
|
||||
all_session_ids.append(session_id)
|
||||
all_idxs.append(idx)
|
||||
all_labels.append(label if label else "all")
|
||||
all_ages.append(age)
|
||||
all_sexes.append(sex)
|
||||
all_heights.append(height)
|
||||
all_weights.append(weight)
|
||||
all_systolics.append(sbp)
|
||||
all_diastolics.append(dbp)
|
||||
all_heart_rates.append(heart_rate)
|
||||
all_bmis.append(bmi)
|
||||
all_hypertensions.append(hypertension)
|
||||
# Categorize age
|
||||
if age is not None:
|
||||
if age <= 35:
|
||||
age_category = "Young (21-35)"
|
||||
elif age <= 55:
|
||||
age_category = "Middle (36-55)"
|
||||
elif age <= 75:
|
||||
age_category = "Senior (56-75)"
|
||||
else:
|
||||
age_category = "Old (76+)"
|
||||
else:
|
||||
age_category = None
|
||||
all_age_categories.append(age_category)
|
||||
|
||||
# Categorize height (cm): based on quartiles 145-196
|
||||
if height is not None:
|
||||
if height <= 157:
|
||||
height_category = "Short (145-157)"
|
||||
elif height <= 170:
|
||||
height_category = "Below Avg (158-170)"
|
||||
elif height <= 183:
|
||||
height_category = "Above Avg (171-183)"
|
||||
else:
|
||||
height_category = "Tall (184+)"
|
||||
else:
|
||||
height_category = None
|
||||
all_height_categories.append(height_category)
|
||||
|
||||
# Categorize weight (kg): based on quartiles 36-103
|
||||
if weight is not None:
|
||||
if weight <= 52:
|
||||
weight_category = "Light (36-52)"
|
||||
elif weight <= 69:
|
||||
weight_category = "Below Avg (53-69)"
|
||||
elif weight <= 86:
|
||||
weight_category = "Above Avg (70-86)"
|
||||
else:
|
||||
weight_category = "Heavy (87+)"
|
||||
else:
|
||||
weight_category = None
|
||||
all_weight_categories.append(weight_category)
|
||||
|
||||
# Categorize systolic BP (mmHg): clinical categories
|
||||
if sbp is not None:
|
||||
if sbp < 90:
|
||||
systolic_category = "Low (<90)"
|
||||
elif sbp < 120:
|
||||
systolic_category = "Normal (90-119)"
|
||||
elif sbp < 140:
|
||||
systolic_category = "Elevated (120-139)"
|
||||
else:
|
||||
systolic_category = "High (140+)"
|
||||
else:
|
||||
systolic_category = None
|
||||
all_systolic_categories.append(systolic_category)
|
||||
|
||||
# Categorize diastolic BP (mmHg): clinical categories
|
||||
if dbp is not None:
|
||||
if dbp < 60:
|
||||
diastolic_category = "Low (<60)"
|
||||
elif dbp < 80:
|
||||
diastolic_category = "Normal (60-79)"
|
||||
elif dbp < 90:
|
||||
diastolic_category = "Elevated (80-89)"
|
||||
else:
|
||||
diastolic_category = "High (90+)"
|
||||
else:
|
||||
diastolic_category = None
|
||||
all_diastolic_categories.append(diastolic_category)
|
||||
|
||||
# Categorize heart rate (bpm): based on quartiles 52-106
|
||||
if heart_rate is not None:
|
||||
if heart_rate <= 65:
|
||||
hr_category = "Low (52-65)"
|
||||
elif heart_rate <= 79:
|
||||
hr_category = "Normal (66-79)"
|
||||
elif heart_rate <= 93:
|
||||
hr_category = "Elevated (80-93)"
|
||||
else:
|
||||
hr_category = "High (94+)"
|
||||
else:
|
||||
hr_category = None
|
||||
all_heart_rate_categories.append(hr_category)
|
||||
|
||||
# Categorize BMI (kg/m²): clinical categories
|
||||
if bmi is not None:
|
||||
if bmi < 18.5:
|
||||
bmi_category = "Underweight (<18.5)"
|
||||
elif bmi < 25:
|
||||
bmi_category = "Normal (18.5-24.9)"
|
||||
elif bmi < 30:
|
||||
bmi_category = "Overweight (25-29.9)"
|
||||
else:
|
||||
bmi_category = "Obese (30+)"
|
||||
else:
|
||||
bmi_category = None
|
||||
all_bmi_categories.append(bmi_category)
|
||||
|
||||
all_file_paths.append(file_path)
|
||||
|
||||
print(f"[INFO] Collected metadata for {len(all_user_ids)} samples")
|
||||
print(f"[INFO] Matched {matched_count} files with metadata, {len(unmatched_users)} unique users unmatched")
|
||||
if unmatched_users:
|
||||
print(f"[DEBUG] Unmatched user IDs: {sorted(unmatched_users)[:10]}{'...' if len(unmatched_users) > 10 else ''}")
|
||||
|
||||
# Generate embeddings from metadata in batches
|
||||
print(f"[INFO] Generating embeddings from metadata for {len(all_ages)} samples...")
|
||||
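The categorisation chains in this hunk (age, height, weight, systolic, diastolic, heart rate, BMI) all follow the same threshold pattern. A minimal sketch of a table-driven equivalent follows, using thresholds and labels copied from the age and systolic chains. The names `categorize`, `AGE_EDGES`, `AGE_LABELS`, `SBP_EDGES` and `SBP_LABELS` are hypothetical and not part of this change.

```python
from bisect import bisect_left, bisect_right
from typing import Optional, Sequence

# Age chain uses inclusive upper bounds (`age <= 35`, `age <= 55`, `age <= 75`).
AGE_EDGES = [35, 55, 75]
AGE_LABELS = ["Young (21-35)", "Middle (36-55)", "Senior (56-75)", "Old (76+)"]

# Systolic chain uses exclusive upper bounds (`sbp < 90`, `sbp < 120`, `sbp < 140`).
SBP_EDGES = [90, 120, 140]
SBP_LABELS = ["Low (<90)", "Normal (90-119)", "Elevated (120-139)", "High (140+)"]


def categorize(
    value: Optional[float],
    edges: Sequence[float],
    labels: Sequence[str],
    upper_inclusive: bool = True,
) -> Optional[str]:
    """Map a value to a label; None stays None, as in the if/elif chains."""
    if value is None:
        return None
    # bisect_left puts value == edge in the lower bin (<=),
    # bisect_right puts value == edge in the upper bin (<).
    pick = bisect_left if upper_inclusive else bisect_right
    return labels[pick(edges, value)]
```

With this shape, each attribute needs only an edge list, a label list, and a flag saying whether its comparisons are `<=` or `<`.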
@@ -397,9 +548,20 @@ class SBERT_Metadata:
|
||||
|
||||
batch_sexes = all_sexes[batch_start:batch_end]
|
||||
batch_ages = all_ages[batch_start:batch_end]
|
||||
batch_heights = all_heights[batch_start:batch_end]
|
||||
batch_weights = all_weights[batch_start:batch_end]
|
||||
batch_systolics = all_systolics[batch_start:batch_end]
|
||||
batch_diastolics = all_diastolics[batch_start:batch_end]
|
||||
batch_heart_rates = all_heart_rates[batch_start:batch_end]
|
||||
batch_bmis = all_bmis[batch_start:batch_end]
|
||||
batch_hypertensions = all_hypertensions[batch_start:batch_end]
|
||||
|
||||
# Compute embeddings from metadata
|
||||
embeddings = self.compute_embedding_from_metadata(batch_ages, batch_sexes)
|
||||
embeddings = self.compute_embedding_from_metadata(
|
||||
batch_sexes, batch_ages, batch_heights, batch_weights,
|
||||
batch_systolics, batch_diastolics, batch_heart_rates,
|
||||
batch_bmis, batch_hypertensions
|
||||
)
|
||||
|
||||
# Collect embeddings
|
||||
for i in range(embeddings.shape[0]):
|
||||
@@ -408,18 +570,26 @@ class SBERT_Metadata:
|
||||
# Create HuggingFace Dataset
|
||||
result_dataset = Dataset.from_dict({
|
||||
"user_id": all_user_ids,
|
||||
# "session_id": all_session_ids,
|
||||
"session_id": all_session_ids,
|
||||
"idx": all_idxs,
|
||||
"label": all_labels,
|
||||
"sex": all_sexes,
|
||||
"age": all_ages,
|
||||
"age_category": all_age_categories,
|
||||
"height": all_heights,
|
||||
"height_category": all_height_categories,
|
||||
"weight": all_weights,
|
||||
"weight_category": all_weight_categories,
|
||||
"systolic": all_systolics,
|
||||
"systolic_category": all_systolic_categories,
|
||||
"diastolic": all_diastolics,
|
||||
"diastolic_category": all_diastolic_categories,
|
||||
"heart_rate": all_heart_rates,
|
||||
"heart_rate_category": all_heart_rate_categories,
|
||||
"bmi": all_bmis,
|
||||
"bmi_category": all_bmi_categories,
|
||||
"hypertension": all_hypertensions,
|
||||
"file_path": all_file_paths,
|
||||
"embedding": all_embeddings,
|
||||
})
|
||||
|
||||
@@ -451,16 +621,19 @@ class SBERT_Metadata:
|
||||
dataset.save_to_disk(output_dir)
|
||||
|
||||
print(f"[DONE] Saved embeddings dataset: {output_dir}")
|
||||
print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
|
||||
if len(dataset) > 0:
|
||||
print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
|
||||
else:
|
||||
print(f"[WARN] Dataset is empty - no samples were processed")
|
||||
|
||||
@staticmethod
|
||||
def load_embeddings(embedding_dir: str) -> Dataset:
|
||||
"""
|
||||
Load saved embeddings dataset from disk.
|
||||
|
||||
|
||||
Args:
|
||||
embedding_dir: Directory path containing the saved HuggingFace dataset
|
||||
|
||||
|
||||
Returns:
|
||||
HuggingFace Dataset with embeddings and metadata
|
||||
"""
|
||||
@@ -469,6 +642,94 @@ class SBERT_Metadata:
|
||||
return dataset
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Save metadata embeddings by user ID (for PPGBP dataloader)
|
||||
# =============================================================================
|
||||
|
||||
def _normalize_sex_for_sbert(sex: Any) -> Optional[int]:
    """Convert metadata sex to int for SBERT textualize (1=Female, 0=Male)."""
    if sex is None or (isinstance(sex, float) and np.isnan(sex)):
        return None
    s = str(sex).strip().upper()
    if s in ("F", "1"):
        return 1
    if s in ("M", "0"):
        return 0
    return None


def _normalize_hypertension_for_sbert(ht: Any) -> Optional[int]:
    """Convert metadata hypertension to int if numeric."""
    if ht is None or (isinstance(ht, float) and np.isnan(ht)):
        return None
    try:
        return int(float(ht))
    except (ValueError, TypeError):
        return None
|
||||
|
||||
|
||||
def save_metadata_embeddings_by_userid(
    data_root: str,
    subject_path: str = SUBJECT_PATH,
    embeddings_root: Optional[str] = None,
) -> None:
    """
    Compute SBERT metadata embeddings for each user and save one file per user
    under embeddings_root as {user_id}.npy (e.g. 100.npy).

    Uses PPGBP_METADATA_EMBEDDINGS_ROOT from environment if embeddings_root
    is not provided. Load .env from project root if present so the env var is set.

    Args:
        data_root: Root directory containing 0_subject/ with PPG-BP data.
        subject_path: Path to PPGBP metadata xlsx (subject_ID, age, sex, etc.).
        embeddings_root: Directory to write {user_id}.npy files. Defaults to
            os.environ["PPGBP_METADATA_EMBEDDINGS_ROOT"].
    """
    try:
        from dotenv import load_dotenv
        # Load .env from repo root (parent of analysis/)
        _repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
        load_dotenv(os.path.join(_repo_root, ".env"))
    except ImportError:
        pass
    root = embeddings_root or os.environ.get("PPGBP_METADATA_EMBEDDINGS_ROOT")
    if not root:
        raise ValueError(
            "embeddings_root not provided and PPGBP_METADATA_EMBEDDINGS_ROOT not set. "
            "Set the env var or pass embeddings_root."
        )
    os.makedirs(root, exist_ok=True)

    subject_metadata = load_subject_metadata(subject_path)
    embedder = SBERT_Metadata()

    user_ids = sorted(subject_metadata.keys())
    if not user_ids:
        raise ValueError(f"No subjects found in {subject_path}")

    for user_id in user_ids:
        meta = subject_metadata[user_id]
        sex = _normalize_sex_for_sbert(meta.get("sex"))
        age = meta.get("age")
        height = meta.get("height")
        weight = meta.get("weight")
        sbp = meta.get("sbp")
        dbp = meta.get("dbp")
        hr = meta.get("hr")
        bmi = meta.get("bmi")
        hypertension = _normalize_hypertension_for_sbert(meta.get("hypertension"))
        if age is not None and isinstance(age, float) and np.isnan(age):
            age = None

        emb = embedder.compute_embedding_from_metadata(
            [sex], [age], [height], [weight], [sbp], [dbp], [hr], [bmi], [hypertension]
        )
        path = os.path.join(root, f"{user_id}.npy")
        np.save(path, emb[0])
    print(f"[DONE] Saved {len(user_ids)} metadata embeddings under {root}")
|
||||
|
||||
|
||||
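The output location in `save_metadata_embeddings_by_userid` is resolved from an explicit argument first and the `PPGBP_METADATA_EMBEDDINGS_ROOT` environment variable second, and each user gets a `{user_id}.npy` file. A minimal sketch of that resolution in isolation is given below. The helpers `resolve_embeddings_root` and `embedding_path` are hypothetical names introduced for illustration, and the `environ` parameter is an assumption added to make the fallback easy to exercise without touching the real environment.

```python
import os
from typing import Mapping, Optional

ENV_VAR = "PPGBP_METADATA_EMBEDDINGS_ROOT"


def resolve_embeddings_root(
    embeddings_root: Optional[str] = None,
    environ: Optional[Mapping[str, str]] = None,
) -> str:
    """Prefer the explicit argument, then the environment variable."""
    env = os.environ if environ is None else environ
    root = embeddings_root or env.get(ENV_VAR)
    if not root:
        raise ValueError(
            f"embeddings_root not provided and {ENV_VAR} not set."
        )
    return root


def embedding_path(root: str, user_id: str) -> str:
    """Per-user file name, following the {user_id}.npy convention."""
    return os.path.join(root, f"{user_id}.npy")
```

A reader of these files could call `embedding_path(resolve_embeddings_root(), "100")` to locate the embedding for user 100, assuming the same environment variable is set on the reading side.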
# =============================================================================
|
||||
# Data Loading Utilities
|
||||
# =============================================================================
|
||||
@@ -625,11 +886,242 @@ def create_scatter_plot_by_age(
|
||||
plt.savefig(output_path, format='pdf', bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
print(f"[DONE] Saved plot: {output_path}")
|
||||
print(f"[DONE] Saved plot: {d}")
|
||||
print(f"[INFO] Age range: {valid_ages.min():.0f} - {valid_ages.max():.0f} years")
|
||||
print(f"[INFO] Points with valid age: {len(valid_ages)}/{len(ages)}")
|
||||
|
||||
|
||||
def create_scatter_plot_continuous(
|
||||
coordinates: np.ndarray,
|
||||
values: np.ndarray,
|
||||
title: str,
|
||||
output_path: str,
|
||||
label: str = "Value",
|
||||
cmap: str = "viridis"
|
||||
) -> None:
|
||||
"""
|
||||
Create and save a 2D scatter plot colored by a continuous variable.
|
||||
|
||||
Args:
|
||||
coordinates: 2D array of shape (num_points, 2)
|
||||
values: Values for each point (array, may contain None values)
|
||||
title: Plot title
|
||||
output_path: File path to save the plot
|
||||
label: Label for the colorbar
|
||||
cmap: Colormap name
|
||||
"""
|
||||
# Filter out points with missing data
|
||||
values_float = np.array([float(v) if v is not None else np.nan for v in values])
|
||||
valid_mask = ~np.isnan(values_float)
|
||||
valid_coords = coordinates[valid_mask]
|
||||
valid_values = values_float[valid_mask]
|
||||
|
||||
if len(valid_values) == 0:
|
||||
print(f"[WARN] No valid {label} data found. Skipping plot: {output_path}")
|
||||
return
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
scatter = ax.scatter(
|
||||
valid_coords[:, 0],
|
||||
valid_coords[:, 1],
|
||||
c=valid_values,
|
||||
cmap=cmap,
|
||||
alpha=0.7,
|
||||
s=50
|
||||
)
|
||||
cbar = plt.colorbar(scatter, ax=ax)
|
||||
cbar.set_label(label)
|
||||
ax.set_title(title)
|
||||
ax.set_xlabel("Dimension 1")
|
||||
ax.set_ylabel("Dimension 2")
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, format='pdf', bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
print(f"[DONE] Saved plot: {output_path}")
|
||||
print(f"[INFO] {label} range: {valid_values.min():.1f} - {valid_values.max():.1f}")
|
||||
|
||||
|
||||
def create_scatter_plot_categorical(
|
||||
coordinates: np.ndarray,
|
||||
categories: np.ndarray,
|
||||
title: str,
|
||||
output_path: str,
|
||||
label: str = "Category"
|
||||
) -> None:
|
||||
"""
|
||||
Create and save a 2D scatter plot colored by a categorical variable.
|
||||
|
||||
Args:
|
||||
coordinates: 2D array of shape (num_points, 2)
|
||||
categories: Category values for each point (array, may contain None values)
|
||||
title: Plot title
|
||||
output_path: File path to save the plot
|
||||
label: Label for the legend
|
||||
"""
|
||||
# Filter out points with missing data
|
||||
valid_mask = np.array([c is not None and str(c) != 'None' for c in categories])
|
||||
valid_coords = coordinates[valid_mask]
|
||||
valid_cats = np.array(categories)[valid_mask]
|
||||
|
||||
if len(valid_cats) == 0:
|
||||
print(f"[WARN] No valid {label} data found. Skipping plot: {output_path}")
|
||||
return
|
||||
|
||||
# Get unique categories and assign colors
|
||||
unique_cats = sorted(set(str(c) for c in valid_cats))
|
||||
colors = plt.cm.tab10(np.linspace(0, 1, len(unique_cats)))
|
||||
cat_to_color = {cat: colors[i] for i, cat in enumerate(unique_cats)}
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
for cat in unique_cats:
|
||||
mask = np.array([str(c) == cat for c in valid_cats])
|
||||
ax.scatter(
|
||||
valid_coords[mask, 0],
|
||||
valid_coords[mask, 1],
|
||||
c=[cat_to_color[cat]],
|
||||
label=f"{cat}",
|
||||
alpha=0.7,
|
||||
s=50
|
||||
)
|
||||
|
||||
ax.set_title(title)
|
||||
ax.set_xlabel("Dimension 1")
|
||||
ax.set_ylabel("Dimension 2")
|
||||
ax.legend(title=label, loc='best')
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, format='pdf', bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
print(f"[DONE] Saved plot: {output_path}")
|
||||
print(f"[INFO] {label} categories: {unique_cats}")
|
||||
|
||||
|
||||
def gen_2d_plot(
|
||||
subject_path: str = SUBJECT_PATH,
|
||||
out_dir: str = "./plots",
|
||||
add_jitter: bool = True,
|
||||
jitter_amount: float = 0.15
|
||||
) -> None:
|
||||
"""
|
||||
Generate a simple 2D scatter plot with hypertension category on x-axis
|
||||
and systolic blood pressure on y-axis.
|
||||
|
||||
This function reads metadata directly from the Excel file and creates
|
||||
a scatter plot without any embedding or dimensionality reduction.
|
||||
|
||||
Args:
|
||||
subject_path: Path to the PPGBP_metadata.xlsx file
|
||||
out_dir: Output directory for the plot
|
||||
add_jitter: Whether to add horizontal jitter for better visibility (default: True)
|
||||
jitter_amount: Amount of jitter to add (default: 0.15)
|
||||
"""
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# Load metadata
|
||||
print(f"[INFO] Loading metadata from: {subject_path}")
|
||||
subject_metadata = load_subject_metadata(subject_path)
|
||||
print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")
|
||||
|
||||
# Extract hypertension and systolic BP values
|
||||
hypertensions = []
|
||||
systolics = []
|
||||
|
||||
for subject_id, metadata in subject_metadata.items():
|
||||
hypertension = metadata.get("hypertension")
|
||||
sbp = metadata.get("sbp")
|
||||
|
||||
# Only include if both values are present
|
||||
if hypertension is not None and sbp is not None:
|
||||
hypertensions.append(str(hypertension))
|
||||
systolics.append(sbp)
|
||||
|
||||
if len(hypertensions) == 0:
|
||||
print("[WARN] No valid data found with both hypertension and systolic BP values.")
|
||||
return
|
||||
|
||||
print(f"[INFO] Found {len(hypertensions)} subjects with valid data")
|
||||
|
||||
# Map hypertension categories to numeric x positions
|
||||
# Expected categories: "0" (Normal), "1" (Prehypertension), "2" (Stage 1), "3" (Stage 2)
|
||||
unique_cats = sorted(set(hypertensions))
|
||||
cat_to_x = {cat: i for i, cat in enumerate(unique_cats)}
|
||||
|
||||
# Create x positions (with optional jitter for visibility)
|
||||
x_positions = np.array([cat_to_x[h] for h in hypertensions], dtype=float)
|
||||
if add_jitter:
|
||||
x_positions += np.random.uniform(-jitter_amount, jitter_amount, size=len(x_positions))
|
||||
|
||||
y_values = np.array(systolics)
|
||||
|
||||
# Assign colors by category
|
||||
colors = plt.cm.tab10(np.linspace(0, 1, len(unique_cats)))
|
||||
cat_to_color = {cat: colors[i] for i, cat in enumerate(unique_cats)}
|
||||
point_colors = [cat_to_color[h] for h in hypertensions]
|
||||
|
||||
# Create hypertension labels for legend
|
||||
hypertension_labels = {
|
||||
"0": "Normal",
|
||||
"1": "Prehypertension",
|
||||
"2": "Stage 1 Hypertension",
|
||||
"3": "Stage 2 Hypertension"
|
||||
}
|
||||
|
||||
# Create the plot
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
# Plot each category separately for legend
|
||||
for cat in unique_cats:
|
||||
mask = np.array([h == cat for h in hypertensions])
|
||||
cat_label = hypertension_labels.get(cat, cat)
|
||||
ax.scatter(
|
||||
x_positions[mask],
|
||||
y_values[mask],
|
||||
c=[cat_to_color[cat]],
|
||||
label=cat_label,
|
||||
alpha=0.7,
|
||||
s=60,
|
||||
edgecolors='white',
|
||||
linewidth=0.5
|
||||
)
|
||||
|
||||
# Configure axes
|
||||
ax.set_xticks(range(len(unique_cats)))
|
||||
ax.set_xticklabels([hypertension_labels.get(cat, cat) for cat in unique_cats], rotation=15, ha='right')
|
||||
ax.set_xlabel("Hypertension Category")
|
||||
ax.set_ylabel("Systolic Blood Pressure (mmHg)")
|
||||
ax.set_title("Systolic Blood Pressure by Hypertension Category")
|
||||
|
||||
# Add horizontal grid lines
|
||||
ax.yaxis.grid(True, linestyle='--', alpha=0.7)
|
||||
ax.set_axisbelow(True)
|
||||
|
||||
# Add legend
|
||||
ax.legend(title="Hypertension", loc='upper left')
|
||||
|
||||
# Save the plot
|
||||
output_path = os.path.join(out_dir, "hypertension_vs_systolic.pdf")
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, format='pdf', bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
print(f"[DONE] Saved plot: {output_path}")
|
||||
print(f"[INFO] Hypertension categories: {unique_cats}")
|
||||
print(f"[INFO] Systolic BP range: {y_values.min():.0f} - {y_values.max():.0f} mmHg")
|
||||
|
||||
# Print category statistics
|
||||
print("\n[INFO] Category statistics:")
|
||||
for cat in unique_cats:
|
||||
cat_mask = np.array([h == cat for h in hypertensions])
|
||||
cat_systolics = y_values[cat_mask]
|
||||
cat_label = hypertension_labels.get(cat, cat)
|
||||
print(f" {cat_label}: n={len(cat_systolics)}, "
|
||||
f"mean={cat_systolics.mean():.1f}, "
|
||||
f"std={cat_systolics.std():.1f}, "
|
||||
f"range=[{cat_systolics.min():.0f}-{cat_systolics.max():.0f}]")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Command Line Interface
|
||||
# =============================================================================
|
||||
@@ -729,19 +1221,81 @@ class CLI:
|
||||
# Extract embeddings as numpy array for t-SNE
|
||||
embeddings = np.array(dataset["embedding"])
|
||||
|
||||
# Extract ages for coloring
|
||||
ages = np.array(dataset["age"])
|
||||
|
||||
# Reduce to 2D with t-SNE
|
||||
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
|
||||
|
||||
# Generate visualization colored by age
|
||||
create_scatter_plot_by_age(
|
||||
coordinates_2d,
|
||||
ages,
|
||||
"t-SNE Visualization (Colored by Age)",
|
||||
os.path.join(out_dir, "tsne_by_age.pdf")
|
||||
)
|
||||
# Generate visualizations for continuous variables
|
||||
continuous_vars = [
|
||||
("age", "Age (years)", "viridis"),
|
||||
("height", "Height (cm)", "plasma"),
|
||||
("weight", "Weight (kg)", "cividis"),
|
||||
("systolic", "Systolic BP (mmHg)", "Reds"),
|
||||
("diastolic", "Diastolic BP (mmHg)", "Blues"),
|
||||
("heart_rate", "Heart Rate (bpm)", "Oranges"),
|
||||
("bmi", "BMI (kg/m²)", "Greens"),
|
||||
]
|
||||
|
||||
for var_name, var_label, cmap in continuous_vars:
|
||||
if var_name in dataset.column_names:
|
||||
values = np.array(dataset[var_name])
|
||||
create_scatter_plot_continuous(
|
||||
coordinates_2d,
|
||||
values,
|
||||
f"t-SNE Visualization (Colored by {var_label})",
|
||||
os.path.join(out_dir, f"tsne_by_{var_name}.pdf"),
|
||||
label=var_label,
|
||||
cmap=cmap
|
||||
)
|
||||
|
||||
# Generate visualizations for categorical variables
|
||||
categorical_vars = [
|
||||
("sex", "Sex"),
|
||||
("hypertension", "Hypertension"),
|
||||
("age_category", "Age Category"),
|
||||
("height_category", "Height Category"),
|
||||
("weight_category", "Weight Category"),
|
||||
("systolic_category", "Systolic BP Category"),
|
||||
("diastolic_category", "Diastolic BP Category"),
|
||||
("heart_rate_category", "Heart Rate Category"),
|
||||
("bmi_category", "BMI Category"),
|
||||
]
|
||||
|
||||
for var_name, var_label in categorical_vars:
|
||||
if var_name in dataset.column_names:
|
||||
categories = np.array(dataset[var_name])
|
||||
create_scatter_plot_categorical(
|
||||
coordinates_2d,
|
||||
categories,
|
||||
f"t-SNE Visualization (Colored by {var_label})",
|
||||
os.path.join(out_dir, f"tsne_by_{var_name}.pdf"),
|
||||
label=var_label
|
||||
)
|
||||
|
||||
def plot_2d(
|
||||
self,
|
||||
subject_path: str = SUBJECT_PATH,
|
||||
out_dir: str = "./plots",
|
||||
add_jitter: bool = True,
|
||||
jitter_amount: float = 0.15
|
||||
) -> None:
|
||||
"""
|
||||
Generate a simple 2D scatter plot with hypertension category on x-axis
|
||||
and systolic blood pressure on y-axis.
|
||||
|
||||
This reads metadata directly from the Excel file and creates a simple
|
||||
scatter plot without embeddings or dimensionality reduction.
|
||||
|
||||
Args:
|
||||
subject_path: Path to the PPGBP_metadata.xlsx file
|
||||
out_dir: Output directory for the plot
|
||||
add_jitter: Whether to add horizontal jitter for better visibility (default: True)
|
||||
jitter_amount: Amount of jitter to add (default: 0.15)
|
||||
|
||||
Usage:
|
||||
python gen_plot.py plot_2d --out_dir ./plots
|
||||
python gen_plot.py plot_2d --subject_path /path/to/metadata.xlsx --out_dir ./plots
|
||||
"""
|
||||
gen_2d_plot(subject_path, out_dir, add_jitter, jitter_amount)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
1029
sample_case.py
Normal file
File diff suppressed because it is too large
32
sc/config/ppgbp_recruiter.yaml
Normal file
@@ -0,0 +1,32 @@
# PPGBP Recruiter (MAB-based example set selection) config
# Usage: python -m sc.run_recruiter sc/config/ppgbp_recruiter.yaml
# When model_path is null: use ground-truth scorer (arm_id -> score), no LLM; run gt_steps to test MAB convergence.


data_path: $PPGBP_DATA_ROOT
feature_root: $DATA_INDEX_ROOT/metadata_onehot # index.json + embeddings.npy
log_path: ./logs/ppgbp_recruiter
model_path: null # set to a HF model path to use real LLM + self_certainty

queue_size: 5
recruit_size: 2
n_way: 2
k_shot: 3
num_arms: 50

# When model_path is null: deterministic reward = Gaussian over arm index, peak at gt_best_arm_id
gt_steps: 200
gt_best_arm_id: 20
gt_sigma: 5.0

n_das: 5
delta: 0.05
epsilon: 0.11

temperature: 0.7
max_new_tokens: 256
num_seeds: 1
seed: 42

task_info: "Classify the subject from PPG-BP metadata."
classes_info: ["Normal", "Prehypertension", "Stage 1 hypertension", "Stage 2 hypertension"]
|
||||
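The config comment describes the ground-truth scorer as a deterministic Gaussian over the arm index that peaks at `gt_best_arm_id`. The exact formula is not shown in this diff, so the sketch below assumes an unnormalised Gaussian, $\exp(-(a - a^*)^2 / (2\sigma^2))$, with the defaults taken from `gt_best_arm_id: 20`, `gt_sigma: 5.0` and `num_arms: 50`. The function name `gt_reward` is hypothetical.

```python
import math


def gt_reward(arm_id: int, best_arm_id: int = 20, sigma: float = 5.0) -> float:
    """Assumed reward shape: 1.0 at the best arm, decaying symmetrically."""
    return math.exp(-((arm_id - best_arm_id) ** 2) / (2.0 * sigma ** 2))


# Score every arm once; a bandit that converges should settle on `best`.
NUM_ARMS = 50
rewards = [gt_reward(a) for a in range(NUM_ARMS)]
best = max(range(NUM_ARMS), key=rewards.__getitem__)
```

Under this assumption, arms one `gt_sigma` away from the peak (15 and 25) score `exp(-0.5)`, which gives the bandit a smooth gradient toward the best arm instead of a single spike.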
@@ -65,14 +65,9 @@ temperature: 0.0
|
||||
# ------------------------------------------------------------------------------
|
||||
# Multiple Ollama instances provide model diversity even with T=0
|
||||
models:
|
||||
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
|
||||
- hf:openai/gpt-oss-20b
|
||||
# Add more entries for model pool (same model, multiple instances for self-consistency)
|
||||
# - hf:/path/to/gpt-oss-20b
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Output Configuration
|
||||
|
||||
207
sc/core/agent.py
@@ -1,147 +1,142 @@
|
||||
"""
|
||||
Base agent for single-turn LLM inference with structured JSON parsing.
|
||||
|
||||
Each invoke call is stateless: system prompt + user prompt -> (text, logits).
|
||||
Message format: {"role": "user"|"assistant"|"system", "content": str}
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import json
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
|
||||
import torch
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from core.model import AsyncModelPool, ChatMessage
|
||||
|
||||
|
||||
class Agent:
|
||||
"""Wraps an AsyncModelPool for single-turn inference."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
model_pool,
|
||||
log_path,
|
||||
):
|
||||
name: str,
|
||||
model_pool: AsyncModelPool,
|
||||
log_path: str,
|
||||
system_message: str = "",
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.model_pool = model_pool
|
||||
self.log_path = log_path
|
||||
self.system_message = system_message
|
||||
|
||||
# Logging directories
|
||||
self.root_log_path = log_path
|
||||
os.makedirs(log_path, exist_ok=True)
|
||||
self.agent_log_path = os.path.join(log_path, name)
|
||||
os.makedirs(self.agent_log_path, exist_ok=True)
|
||||
|
||||
self.long_term_memory = []
|
||||
self.short_term_memory = []
|
||||
self.volatile_memory = []
|
||||
# -----------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
def log(self, message, local=True):
|
||||
path = os.path.join(self.root_log_path, "log.txt")
|
||||
with open(path, "a", encoding="utf-8") as f:
|
||||
message_type = "UNKNOWN"
|
||||
if isinstance(message, SystemMessage):
|
||||
message_type = "SYSTEM"
|
||||
if isinstance(message, HumanMessage):
|
||||
message_type = "HUMAN"
|
||||
if isinstance(message, AIMessage):
|
||||
message_type = "AI"
|
||||
content = message.content.strip()
|
||||
name = self.name
|
||||
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
|
||||
if local:
|
||||
local_path = os.path.join(self.agent_log_path, "log.txt")
|
||||
with open(local_path, "a", encoding="utf-8") as f:
|
||||
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
|
||||
def log(self, message: ChatMessage) -> None:
|
||||
"""Append a message to the global and agent-local log."""
|
||||
role = message.get("role", "unknown")
|
||||
content = message.get("content", "").strip()
|
||||
entry = f"[{self.name}] [{role}]\n{content}\n\n\n"
|
||||
|
||||
def update_memory(self):
|
||||
self.long_term_memory.extend(self.short_term_memory)
|
||||
self.clean_short_term_memory()
|
||||
self.clean_volatile_memory()
|
||||
for path in (
|
||||
os.path.join(self.root_log_path, "log.txt"),
|
||||
os.path.join(self.agent_log_path, "log.txt"),
|
||||
):
|
||||
with open(path, "a", encoding="utf-8") as f:
|
||||
f.write(entry)
|
||||
|
||||
def clean_short_term_memory(self):
|
||||
self.short_term_memory = []
|
||||
# -----------------------------------------------------------------
|
||||
# JSON helpers
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
def clean_volatile_memory(self):
|
||||
self.volatile_memory = []
|
||||
|
||||
def clean_long_term_memory(self):
|
||||
self.long_term_memory = []
|
||||
|
||||
def clean_json_text(self, text):
|
||||
@staticmethod
|
||||
def _clean_json_text(text: str) -> str:
|
||||
"""Normalise common LLM quirks so the string is valid JSON."""
|
||||
text = text.strip()
|
||||
text = text.replace("‘", "'").replace("’", "'")
|
||||
text = text.replace("“", "'").replace("”", "'")
|
||||
text = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", text)
|
||||
text = re.sub(r",\s*}", "}", text)
|
||||
text = re.sub(r",\s*]", "]", text)
|
||||
text = "".join(ch for ch in text if ch.isprintable())
|
||||
text = text.replace("][", ",")
|
||||
text = text.replace("\u2018", "'").replace("\u2019", "'") # smart single quotes
|
||||
text = text.replace("\u201c", "'").replace("\u201d", "'") # smart double quotes
|
||||
text = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", text) # escape stray backslashes
|
||||
text = re.sub(r",\s*}", "}", text) # trailing comma in object
|
||||
text = re.sub(r",\s*]", "]", text) # trailing comma in array
|
||||
text = "".join(ch for ch in text if ch.isprintable()) # strip control chars
|
||||
text = text.replace("][", ",") # merge adjacent arrays
|
||||
return text
|
||||
|
||||
def safe_parse_json(self, text):
|
||||
def safe_parse_json(self, text: str) -> Optional[dict]:
|
||||
"""Best-effort extraction of a JSON object from an LLM response."""
|
||||
if not text:
|
||||
return None
|
||||
text = text.strip()
|
||||
|
||||
match = re.search(r"\{.*\}", text, re.DOTALL)
|
||||
if not match and not text.endswith("}"):
|
||||
match = re.search(r"\{.*\}", text + "}", re.DOTALL)
|
||||
if match:
|
||||
text = match.group(0)
|
||||
text = self.clean_json_text(text)
|
||||
try:
|
||||
return json.loads(text)
|
||||
return json.loads(self._clean_json_text(match.group(0)))
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"[!] JSON parse failed: {e}")
|
||||
return None
|
||||
elif not text.endswith("}"):
|
||||
text += "}"
|
||||
match = re.search(r"\{.*\}", text, re.DOTALL)
|
||||
if match:
|
||||
text = match.group(0)
|
||||
text = self.clean_json_text(text)
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"[!] JSON parse failed: {e}")
|
||||
return None
|
||||
print("[!] JSON parse failed")
|
||||
|
||||
print("[!] JSON parse failed: no object found")
|
||||
return None
|
||||
|
||||
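The JSON helpers above pull the outermost `{...}` span out of an LLM reply, repair common quirks, and fall back to appending a closing brace when the reply was cut off. A simplified standalone sketch of that flow is shown here. It keeps only the trailing-comma repairs and omits the quote, backslash and control-character normalisation, and the name `extract_json` is hypothetical.

```python
import json
import re
from typing import Optional


def extract_json(text: str) -> Optional[dict]:
    """Best-effort extraction of one JSON object from free-form text."""
    if not text:
        return None
    text = text.strip()

    # Greedy match: from the first "{" to the last "}" in the reply.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if not match and not text.endswith("}"):
        # Tolerate a reply truncated just before the closing brace.
        match = re.search(r"\{.*\}", text + "}", re.DOTALL)
    if not match:
        return None

    candidate = match.group(0)
    candidate = re.sub(r",\s*}", "}", candidate)  # trailing comma in object
    candidate = re.sub(r",\s*]", "]", candidate)  # trailing comma in array
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        return None
```

Because the match is greedy, surrounding chatter before and after the object is discarded, but a reply containing two separate objects would be treated as one span and fail to parse.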
async def validate_response(self, response, fields, volatile=False):
|
||||
if (
|
||||
not response
|
||||
or not isinstance(response, dict)
|
||||
or not all(field in response for field in fields)
|
||||
):
|
||||
print("[!] The JSON failed to be parsed. Trying again.")
|
||||
content = (
|
||||
"Failed to parse the JSON from the previous response. Please try again."
|
||||
# -----------------------------------------------------------------
|
||||
# Response validation
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
async def validate_response(
|
||||
self, response: Optional[dict], fields: List[str]
|
||||
) -> Optional[dict]:
|
||||
"""Validate that *response* is a dict containing all *fields*; retry once if not."""
|
||||
if not isinstance(response, dict) or not all(f in response for f in fields):
|
||||
print("[!] JSON validation failed. Retrying...")
|
||||
text, _ = await self.invoke(
|
||||
"Failed to parse the JSON from the previous response. Please try again.",
|
||||
)
|
||||
response = await self.invoke(content, volatile=volatile)
|
||||
response = self.safe_parse_json(response)
|
||||
if (
|
||||
not response
|
||||
or not isinstance(response, dict)
|
||||
or not all(field in response for field in fields)
|
||||
):
|
||||
response = self.safe_parse_json(text)
|
||||
|
||||
if not isinstance(response, dict) or not all(f in response for f in fields):
|
||||
print("[!] Retry failed.")
|
||||
return None
|
||||
return response
|
||||
|
||||
def get_last_response(self):
|
||||
if len(self.long_term_memory) >= 2:
|
||||
last_msg = self.long_term_memory[-1]
|
||||
if isinstance(last_msg, AIMessage):
|
||||
return self.safe_parse_json(last_msg.content)
|
||||
return None
|
||||
# -----------------------------------------------------------------
|
||||
# Core invoke
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
def set_system_message(self, content, local=True):
|
||||
system_message = SystemMessage(content=content)
|
||||
self.log(system_message, local)
|
||||
self.long_term_memory.append(system_message)
|
||||
async def invoke(
|
||||
self,
|
||||
content: str,
|
||||
) -> Tuple[Optional[str], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Single-turn call: [system_message] + user content -> (text, logits).
|
||||
|
||||
Returns:
|
||||
text: Generated reply string (or None on error).
|
||||
logits: Tensor of shape (1, num_tokens, vocab_size) on CPU (or None).
|
||||
"""
|
||||
messages: List[ChatMessage] = []
|
||||
if self.system_message:
|
||||
messages.append({"role": "system", "content": self.system_message})
|
||||
user_msg: ChatMessage = {"role": "user", "content": content}
|
||||
messages.append(user_msg)
|
||||
|
||||
async def invoke(self, content, volatile=False, local=True):
|
||||
messages = self.long_term_memory.copy()
|
||||
if volatile:
|
||||
messages.extend(self.volatile_memory)
|
||||
else:
|
||||
messages.extend(self.short_term_memory)
|
||||
messages.append(HumanMessage(content=content))
|
||||
try:
|
||||
response = await self.model_pool.invoke(messages)
|
||||
if volatile:
|
||||
self.volatile_memory.extend([HumanMessage(content=content), response])
|
||||
else:
|
||||
self.short_term_memory.extend([HumanMessage(content=content), response])
|
||||
local_ = not volatile and local
|
||||
self.log(HumanMessage(content=content), local=local_)
|
||||
self.log(response, local=local_)
|
||||
return response.content.strip()
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
print(f"[Error] Error occurred while invoking LLM: {e}")
|
||||
text, logits = await self.model_pool.invoke(messages)
|
||||
assistant_msg: ChatMessage = {"role": "assistant", "content": text}
|
||||
self.log(user_msg)
|
||||
self.log(assistant_msg)
|
||||
return text.strip(), logits
|
||||
except Exception as e:
|
||||
print(f"[Error] invoke failed: {e}")
|
||||
return None, None
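The best-effort JSON extraction in `safe_parse_json` above can be sketched as a standalone function. This is a minimal sketch, not the repository's implementation: the name `extract_json_object` is hypothetical, and it reproduces only the control-character stripping from the cleaning step.

```python
import json
import re
from typing import Optional


def extract_json_object(text: str) -> Optional[dict]:
    """Return the JSON object spanning the first '{' to the last '}', or None."""
    if not text:
        return None
    text = text.strip()
    # Greedy match: from the first opening brace to the last closing brace.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if not match and not text.endswith("}"):
        # Tolerate a reply that was cut off before its closing brace.
        match = re.search(r"\{.*\}", text + "}", re.DOTALL)
    if not match:
        return None
    # Drop control characters (including newlines) before parsing.
    candidate = "".join(ch for ch in match.group(0) if ch.isprintable())
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


reply = 'Sure, here is the result:\n{"answer": "N2", "confidence": 0.8}\nHope this helps.'
wrapped = extract_json_object(reply)
truncated = extract_json_object('{"answer": "W"')
nothing = extract_json_object("no json here")
print(wrapped, truncated, nothing)
```

Because the pattern is greedy, a reply containing two separate objects would be captured as one span and fail to parse. In that case the function returns `None`, and the caller would fall back to its retry path.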
|
||||
|
||||
@@ -1,72 +1,630 @@
|
||||
"""
|
||||
Data loader for personalized ICL experiments.
|
||||
|
||||
Loads a target user's test split and example data from all other users.
|
||||
Expects the following directory structure:
|
||||
|
||||
<path>/
|
||||
info.json # dataset metadata (task, classes, features)
|
||||
<user_id>/
|
||||
1/ # examples (train split) - HuggingFace dataset on disk
|
||||
2/ # test split - HuggingFace dataset on disk
|
||||
<other_user>/
|
||||
1/
|
||||
...
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
Integrating a new dataset
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
The pipeline expects:
|
||||
|
||||
1) Directory layout (for DataLoader / ShuffledDataLoader):
|
||||
- <data_path>/info.json
|
||||
- <data_path>/<user_id>/1/ and <data_path>/<user_id>/2/ (HuggingFace datasets saved with datasets.Dataset.save_to_disk)
|
||||
|
||||
2) info.json schema:
|
||||
- "task": str (e.g. "Sleep stage classification from EEG")
|
||||
- "class": dict mapping label -> description (e.g. {"W": "Wake", "N1": "Non-REM 1", ...})
|
||||
- "feature": str (sensor/feature description for the LLM)
|
||||
|
||||
3) Each sample (in test and example datasets) must be a dict with:
|
||||
- "label": str (class name, must be one of the keys in info.json["class"])
|
||||
- "features": dict of str -> value (e.g. {"Fpz-Cz": "0.12, -0.05, ...", "Pz-Oz": "..."}); values are formatted for the prompt
|
||||
|
||||
Option A - Convert your dataset to this format:
|
||||
- Create info.json and per-user dirs; build HuggingFace Dataset with columns "label" and "features", then save_to_disk to <user_id>/1 and <user_id>/2.
|
||||
- Use prepare_dataset_for_sc() in this module to write from in-memory lists.
|
||||
|
||||
Option B - Use in-memory data without changing directory layout:
|
||||
- Use InMemoryDataLoader(metadata, test_samples, example_samples) which implements the same interface as DataLoader.
|
||||
- Use it in run_sc by constructing it and passing to run_single_task; for run_confidence_based / run_consistency_based / run_sc_queue_random you currently need the directory layout (or extend those runners to accept a loader instance).
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import datasets
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import datasets
|
||||
|
||||
|
||||
class DataLoader:
|
||||
"""Loads a target user's test set and ICL examples from other users."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_path,
|
||||
user_id,
|
||||
example_pool="out",
|
||||
continuous=True,
|
||||
):
|
||||
if not os.path.exists(os.path.join(data_path, "info.json")):
|
||||
return
|
||||
if not os.path.exists(os.path.join(data_path, f"{user_id}", "1")):
|
||||
return
|
||||
if not os.path.exists(os.path.join(data_path, f"{user_id}", "2")):
|
||||
return
|
||||
|
||||
self.metadata = json.load(open(os.path.join(data_path, "info.json"), "r", encoding="utf-8"))
|
||||
self.test_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user_id}", "2"))
|
||||
self.example_dataset = datasets.Dataset.from_list([])
|
||||
users = glob(os.path.join(data_path, "*"))
|
||||
users = [path.split("/")[-1] for path in users]
|
||||
if "info.json" in users:
|
||||
users.remove("info.json")
|
||||
for user in users:
|
||||
if example_pool == "out" and user == user_id:
|
||||
continue
|
||||
if example_pool == "in" and user != user_id:
|
||||
continue
|
||||
user_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user}", "1"))
|
||||
self.example_dataset = datasets.concatenate_datasets([self.example_dataset, user_dataset])
|
||||
self,
|
||||
path: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
shuffle: bool = False,
|
||||
seed: int = 0,
|
||||
data_path: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Support both path= and data_path= for compatibility with run_sc
|
||||
root = path or data_path
|
||||
if root is None or user_id is None:
|
||||
raise ValueError("path (or data_path) and user_id are required")
|
||||
metadata_path = os.path.join(root, "info.json")
|
||||
target_user_path = os.path.join(root, user_id, "2")
|
||||
|
||||
if not continuous:
|
||||
self.test_dataset = self.test_dataset.shuffle(seed=0)
|
||||
self.example_dataset = self.example_dataset.shuffle(seed=0)
|
||||
if not os.path.exists(metadata_path) or not os.path.exists(target_user_path):
|
||||
return
|
||||
|
||||
def __len__(self):
|
||||
# Metadata
|
||||
with open(metadata_path, "r", encoding="utf-8") as f:
|
||||
self.metadata: Dict[str, Any] = json.load(f)
|
||||
|
||||
# Test set for the target user
|
||||
self.test_dataset = datasets.load_from_disk(target_user_path)
|
||||
|
||||
# Example set: train splits from all *other* users
|
||||
example_paths = [
|
||||
os.path.join(user_path, "1")
|
||||
for user_path in glob(os.path.join(root, "*"))
|
||||
if os.path.isdir(user_path) and os.path.basename(user_path) != user_id
|
||||
]
|
||||
example_parts = [datasets.load_from_disk(p) for p in example_paths]
|
||||
if example_parts:
|
||||
self.example_dataset = datasets.concatenate_datasets(example_parts)
|
||||
else:
|
||||
self.example_dataset = datasets.Dataset.from_list([])
|
||||
|
||||
# Shuffle
|
||||
if shuffle:
|
||||
self.test_dataset = self.test_dataset.shuffle(seed=seed)
|
||||
self.example_dataset = self.example_dataset.shuffle(seed=seed)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.test_dataset)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
sample = self.test_dataset[idx]
|
||||
return sample
|
||||
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
||||
return self.test_dataset[idx]
|
||||
|
||||
def __iter__(self):
|
||||
for sample in self.test_dataset:
|
||||
yield sample
|
||||
yield from self.test_dataset
|
||||
|
||||
def get_examples(self):
|
||||
def get_examples(self) -> datasets.Dataset:
|
||||
return self.example_dataset
|
||||
|
||||
def get_metadata(self):
|
||||
def get_metadata(self) -> Dict[str, Any]:
|
||||
return self.metadata
|
||||
|
||||
def get_sensor_info(self):
|
||||
def get_sensor_info(self) -> str:
|
||||
return self.metadata["feature"]
|
||||
|
||||
def get_task_info(self):
|
||||
task_info = f"**Task**:\n{self.metadata['task']}\n\n"
|
||||
classes_info = [f" - {k}: {v}" for k, v in self.metadata["class"].items()]
|
||||
classes_info = "\n".join(classes_info)
|
||||
task_info += f"**Classes**:\n{classes_info}"
|
||||
return task_info
|
||||
def get_task_info(self) -> str:
|
||||
"""Return a formatted string describing the task and its classes."""
|
||||
classes_info = "\n".join(
|
||||
f" - {k}: {v}" for k, v in self.metadata["class"].items()
|
||||
)
|
||||
return f"**Task**:\n{self.metadata['task']}\n\n**Classes**:\n{classes_info}"
|
||||
|
||||
def get_classes_info(self):
|
||||
classes_info = [k for k in self.metadata["class"].keys()]
|
||||
return classes_info
|
||||
def get_classes_info(self) -> List[str]:
|
||||
return list(self.metadata["class"].keys())
|
||||
|
||||
|
||||
class InMemoryDataLoader:
|
||||
"""
|
||||
DataLoader interface backed by in-memory lists. Use this to run the SC
|
||||
pipeline on a new dataset without writing the directory layout to disk.
|
||||
|
||||
Implements the same interface as DataLoader: __len__, __getitem__, __iter__,
|
||||
get_examples(), get_metadata(), get_sensor_info(), get_task_info(), get_classes_info().
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metadata: Dict[str, Any],
|
||||
test_samples: List[Dict[str, Any]],
|
||||
example_samples: List[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
metadata: Must have "task", "class" (dict label -> description), "feature" (str).
|
||||
test_samples: List of dicts with "label" and "features" (dict).
|
||||
example_samples: Same schema; used as ICL example pool.
|
||||
"""
|
||||
self.metadata = metadata
|
||||
self.test_dataset = datasets.Dataset.from_list(test_samples)
|
||||
self.example_dataset = datasets.Dataset.from_list(example_samples)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.test_dataset)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
||||
return self.test_dataset[idx]
|
||||
|
||||
def __iter__(self):
|
||||
yield from self.test_dataset
|
||||
|
||||
def get_examples(self) -> datasets.Dataset:
|
||||
return self.example_dataset
|
||||
|
||||
def get_metadata(self) -> Dict[str, Any]:
|
||||
return self.metadata
|
||||
|
||||
def get_sensor_info(self) -> str:
|
||||
return self.metadata["feature"]
|
||||
|
||||
def get_task_info(self) -> str:
|
||||
classes_info = "\n".join(
|
||||
f" - {k}: {v}" for k, v in self.metadata["class"].items()
|
||||
)
|
||||
return f"**Task**:\n{self.metadata['task']}\n\n**Classes**:\n{classes_info}"
|
||||
|
||||
def get_classes_info(self) -> List[str]:
|
||||
return list(self.metadata["class"].keys())
|
||||
|
||||
|
||||
def prepare_dataset_for_sc(
|
||||
output_path: str,
|
||||
metadata: Dict[str, Any],
|
||||
per_user_splits: Dict[str, Dict[str, List[Dict[str, Any]]]],
|
||||
) -> None:
|
||||
"""
|
||||
Write a new dataset to the directory layout expected by DataLoader and
|
||||
ShuffledDataLoader. Call this once; then point config data_path to output_path.
|
||||
|
||||
Args:
|
||||
output_path: Root directory to create (e.g. /path/to/MyDataset).
|
||||
metadata: info.json contents: "task", "class" (dict), "feature" (str).
|
||||
per_user_splits: { user_id: { "train": [samples], "test": [samples] } }.
|
||||
Each sample is a dict with "label" and "features".
|
||||
"""
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
with open(os.path.join(output_path, "info.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
for user_id, splits in per_user_splits.items():
|
||||
user_dir = os.path.join(output_path, str(user_id))
|
||||
os.makedirs(os.path.join(user_dir, "1"), exist_ok=True)
|
||||
os.makedirs(os.path.join(user_dir, "2"), exist_ok=True)
|
||||
train_ds = datasets.Dataset.from_list(splits.get("train", []))
|
||||
test_ds = datasets.Dataset.from_list(splits.get("test", []))
|
||||
train_ds.save_to_disk(os.path.join(user_dir, "1"))
|
||||
test_ds.save_to_disk(os.path.join(user_dir, "2"))
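The schema described in the module docstring can be checked before anything is written to disk. The sketch below uses only the standard library. `find_schema_problems` is a hypothetical helper that is not part of this module, and the metadata values are illustrative.

```python
from typing import Any, Dict, List

REQUIRED_METADATA_KEYS = ("task", "class", "feature")


def find_schema_problems(metadata: Dict[str, Any], samples: List[Dict[str, Any]]) -> List[str]:
    """Hypothetical helper: list the ways the inputs deviate from the documented schema."""
    problems = [f"metadata is missing '{k}'" for k in REQUIRED_METADATA_KEYS if k not in metadata]
    labels = set(metadata.get("class", {}))
    for i, sample in enumerate(samples):
        if sample.get("label") not in labels:
            problems.append(f"sample {i}: label {sample.get('label')!r} is not a key of metadata['class']")
        if not isinstance(sample.get("features"), dict):
            problems.append(f"sample {i}: 'features' must be a dict of channel name -> value")
    return problems


metadata = {
    "task": "Sleep stage classification from EEG",
    "class": {"W": "Wake", "N1": "Non-REM 1"},
    "feature": "Two EEG channels (Fpz-Cz, Pz-Oz)",  # illustrative description
}
good = {"label": "W", "features": {"Fpz-Cz": "0.12, -0.05", "Pz-Oz": "0.03, 0.01"}}
bad = {"label": "REM", "features": "0.12, -0.05"}  # unknown label, features not a dict

# Shape accepted by prepare_dataset_for_sc: {user_id: {"train": [...], "test": [...]}}
per_user_splits = {"user_a": {"train": [good], "test": [good]}}

clean_report = find_schema_problems(metadata, [good])
dirty_report = find_schema_problems(metadata, [good, bad])
print(clean_report)
print(len(dirty_report))
```

Running a check like this before calling `prepare_dataset_for_sc` surfaces label mismatches early, instead of at prompt-construction time.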
|
||||
|
||||
|
||||
# PPGBP data loader
|
||||
|
||||
"""
|
||||
PPGBP Dataset Loader
|
||||
|
||||
A data loader/iterator for the PPG-BP dataset that reads xlsx metadata
|
||||
and corresponding signal text files, with optional batching support.
|
||||
|
||||
Metadata embeddings live under PPGBP_METADATA_EMBEDDINGS_ROOT (one .npy per
|
||||
subject ID). If missing, they are generated automatically in PPGBPLoader.__init__
|
||||
using SBERT. Set PPGBP_METADATA_EMBEDDINGS_ROOT in .env or environment.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Iterator, Tuple, Dict, List, Optional, Union
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
_ppgbp_env_loaded = False
|
||||
def _ensure_ppgbp_env():
|
||||
global _ppgbp_env_loaded
|
||||
if not _ppgbp_env_loaded:
|
||||
load_dotenv()
|
||||
_ppgbp_env_loaded = True
|
||||
except ImportError:
|
||||
def _ensure_ppgbp_env():
|
||||
pass
|
||||
|
||||
|
||||
def _get_ppgbp_embedding_root() -> str:
|
||||
_ensure_ppgbp_env()
|
||||
root = os.environ.get("PPGBP_METADATA_EMBEDDINGS_ROOT")
|
||||
if not root:
|
||||
raise ValueError(
|
||||
"PPGBP_METADATA_EMBEDDINGS_ROOT is not set. Set it in .env or environment "
|
||||
"(e.g. PPGBP_METADATA_EMBEDDINGS_ROOT=/path/to/metadata_embeddings)."
|
||||
)
|
||||
return root
|
||||
|
||||
|
||||
def _get_ppgbp_embedding_root_optional() -> Optional[str]:
|
||||
"""Return PPGBP_METADATA_EMBEDDINGS_ROOT if set, else None. Used when embeddings are not required (e.g. recruiter with feature_root)."""
|
||||
_ensure_ppgbp_env()
|
||||
root = os.environ.get("PPGBP_METADATA_EMBEDDINGS_ROOT") or ""
|
||||
return root.strip() or None
|
||||
|
||||
|
||||
def _get_missing_embedding_ids(embedding_root: str, subject_ids: np.ndarray) -> Optional[List[int]]:
|
||||
"""
|
||||
Return list of subject_ids that do not have a .npy file in embedding_root.
|
||||
Return None if embedding_root is not a directory (so caller can create and generate).
|
||||
"""
|
||||
if not os.path.isdir(embedding_root):
|
||||
return None
|
||||
missing = []
|
||||
for sid in subject_ids:
|
||||
path = os.path.join(embedding_root, f"{int(sid)}.npy")
|
||||
if not os.path.isfile(path):
|
||||
missing.append(int(sid))
|
||||
return missing
|
||||
|
||||
|
||||
def _check_embedding_root_populated(embedding_root: str, subject_ids: np.ndarray) -> None:
|
||||
"""Raise if embedding_root is missing or does not contain .npy for every subject_id."""
|
||||
missing_or_none = _get_missing_embedding_ids(embedding_root, subject_ids)
|
||||
if missing_or_none is None:
|
||||
raise FileNotFoundError(
|
||||
f"Metadata embeddings root is not a directory or does not exist: {embedding_root}."
|
||||
)
|
||||
if missing_or_none:
|
||||
raise FileNotFoundError(
|
||||
f"Metadata embeddings root {embedding_root} is missing .npy files for subject IDs: "
|
||||
f"{missing_or_none[:10]}{'...' if len(missing_or_none) > 10 else ''}."
|
||||
)
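The two helpers above distinguish three states: the root is not a directory (`None`), the root exists with some files absent (a non-empty list), and the root is fully populated (an empty list). A minimal sketch of that contract, using a temporary directory and a hypothetical `missing_ids` function in place of the module's private helper:

```python
import os
import tempfile
from typing import Iterable, List, Optional


def missing_ids(root: str, subject_ids: Iterable[int]) -> Optional[List[int]]:
    """None if root is not a directory; otherwise the IDs lacking a <id>.npy file."""
    if not os.path.isdir(root):
        return None
    return [
        int(sid)
        for sid in subject_ids
        if not os.path.isfile(os.path.join(root, f"{int(sid)}.npy"))
    ]


with tempfile.TemporaryDirectory() as root:
    # Create an empty placeholder file for subject 1 only.
    open(os.path.join(root, "1.npy"), "wb").close()
    partial = missing_ids(root, [1, 2, 3])
    complete = missing_ids(root, [1])
    absent = missing_ids(os.path.join(root, "does_not_exist"), [1])

print(partial, complete, absent)
```

Keeping `None` distinct from `[]` lets the caller decide whether to create the directory and generate everything, or to fill in only the missing subjects.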
|
||||
|
||||
|
||||
def _generate_and_save_ppgbp_metadata_embeddings(
|
||||
embedding_root: str,
|
||||
subject_path: str,
|
||||
subject_ids: np.ndarray,
|
||||
) -> None:
|
||||
"""
|
||||
Generate SBERT metadata embeddings for the given subject IDs and save as
|
||||
{subject_id}.npy under embedding_root. Uses tqdm for progress.
|
||||
"""
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
except ImportError:
|
||||
tqdm = lambda x, **kw: x # noqa: E731
|
||||
|
||||
# Lazy import to keep SBERT/gen_plot deps out of normal import path
|
||||
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
if _repo_root not in sys.path:
|
||||
sys.path.insert(0, _repo_root)
|
||||
from analysis.user_similarity.sbert_metadata_ppgbp.gen_plot import (
|
||||
load_subject_metadata,
|
||||
SBERT_Metadata,
|
||||
_normalize_sex_for_sbert,
|
||||
_normalize_hypertension_for_sbert,
|
||||
)
|
||||
|
||||
os.makedirs(embedding_root, exist_ok=True)
|
||||
subject_metadata = load_subject_metadata(subject_path)
|
||||
embedder = SBERT_Metadata()
|
||||
|
||||
ids_to_generate = [int(sid) for sid in subject_ids]
|
||||
for user_id in tqdm(ids_to_generate, desc="PPGBP metadata embeddings", unit="subject"):
|
||||
user_id_str = str(user_id)
|
||||
meta = subject_metadata.get(user_id_str, {})
|
||||
sex = _normalize_sex_for_sbert(meta.get("sex"))
|
||||
age = meta.get("age")
|
||||
height = meta.get("height")
|
||||
weight = meta.get("weight")
|
||||
sbp = meta.get("sbp")
|
||||
dbp = meta.get("dbp")
|
||||
hr = meta.get("hr")
|
||||
bmi = meta.get("bmi")
|
||||
hypertension = _normalize_hypertension_for_sbert(meta.get("hypertension"))
|
||||
if age is not None and isinstance(age, float) and np.isnan(age):
|
||||
age = None
|
||||
|
||||
emb = embedder.compute_embedding_from_metadata(
|
||||
[sex], [age], [height], [weight], [sbp], [dbp], [hr], [bmi], [hypertension]
|
||||
)
|
||||
path = os.path.join(embedding_root, f"{user_id}.npy")
|
||||
np.save(path, emb[0])
|
||||
|
||||
|
||||
def load_ppgbp_metadata_embedding(embedding_root: str, subject_id: int) -> np.ndarray:
|
||||
"""Load a single metadata embedding for subject_id from embedding_root (utility)."""
|
||||
path = os.path.join(embedding_root, f"{int(subject_id)}.npy")
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"Metadata embedding not found: {path}")
|
||||
return np.load(path)
|
||||
|
||||
|
||||
def save_ppgbp_metadata_embeddings(
|
||||
embedding_root: str,
|
||||
subject_path: str,
|
||||
subject_ids: Optional[np.ndarray] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate and save PPGBP metadata embeddings under embedding_root (one .npy per subject).
|
||||
If subject_ids is None, uses all subject IDs present in the metadata file at subject_path.
|
||||
Uses tqdm for progress.
|
||||
"""
|
||||
if subject_ids is not None and len(subject_ids) == 0:
|
||||
return
|
||||
# If subject_ids is None we need to get all from metadata; _generate_* expects an array
|
||||
if subject_ids is None:
|
||||
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
if _repo_root not in sys.path:
|
||||
sys.path.insert(0, _repo_root)
|
||||
from analysis.user_similarity.sbert_metadata_ppgbp.gen_plot import load_subject_metadata
|
||||
subject_metadata = load_subject_metadata(subject_path)
|
||||
subject_ids = np.array(sorted(int(k) for k in subject_metadata.keys()))
|
||||
_generate_and_save_ppgbp_metadata_embeddings(embedding_root, subject_path, subject_ids)
|
||||
|
||||
|
||||
class PPGBPLoader:
|
||||
"""
|
||||
Data loader for the PPG-BP dataset.
|
||||
|
||||
Reads metadata from xlsx file and corresponding PPG signal text files.
|
||||
Supports iteration over single samples or batches.
|
||||
Subject list is built from 0_subject/ and matched with metadata;
|
||||
train/test is 80/20 at subject level (random, fixed by seed).
|
||||
"""
|
||||
|
||||
COLUMN_MAPPING = {
|
||||
"Sex(M/F)": "sex",
|
||||
"Age(year)": "age",
|
||||
"Systolic Blood Pressure(mmHg)": "sysbp",
|
||||
"Diastolic Blood Pressure(mmHg)": "diasbp",
|
||||
"Heart Rate(b/m)": "hr",
|
||||
"BMI(kg/m^2)": "bmi"
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_dir: str,
|
||||
split: str = 'all',
|
||||
batch_size: Optional[int] = None,
|
||||
shuffle: bool = False,
|
||||
num_segments: int = 3,
|
||||
seed: int = 42,
|
||||
return_metadata_embeddings: bool = False,
|
||||
):
|
||||
self.base_dir = base_dir
|
||||
self.split = split
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
self.num_segments = num_segments
|
||||
self.seed = seed
|
||||
self.return_metadata_embeddings = return_metadata_embeddings
|
||||
self.rng = np.random.default_rng(seed)
|
||||
|
||||
self.signal_dir = os.path.join(base_dir, "0_subject")
|
||||
self.xlsx_path = self._find_xlsx_file()
|
||||
|
||||
self.metadata = self._load_metadata()
|
||||
self._all_subject_ids = self._get_unified_subject_ids()
|
||||
self._train_ids, self._test_ids = self._get_train_test_split(self._all_subject_ids)
|
||||
self.subject_ids = self._get_split_ids()
|
||||
self.metadata = self.metadata[self.metadata['subject_ID'].isin(self.subject_ids)]
|
||||
|
||||
self._embedding_root = _get_ppgbp_embedding_root_optional()
|
||||
if self._embedding_root is not None:
|
||||
missing = _get_missing_embedding_ids(self._embedding_root, self._all_subject_ids)
|
||||
if len(self._all_subject_ids) > 0 and (missing is None or len(missing) > 0):
|
||||
_generate_and_save_ppgbp_metadata_embeddings(
|
||||
self._embedding_root, self.xlsx_path, self._all_subject_ids
|
||||
)
|
||||
_check_embedding_root_populated(self._embedding_root, self.subject_ids)
|
||||
|
||||
self._current_idx = 0
|
||||
self._indices = np.arange(len(self.subject_ids))
|
||||
|
||||
if self.shuffle:
|
||||
self.rng.shuffle(self._indices)
|
||||
|
||||
def _find_xlsx_file(self) -> str:
|
||||
for f in os.listdir(self.base_dir):
|
||||
if f.endswith('.xlsx'):
|
||||
return os.path.join(self.base_dir, f)
|
||||
raise FileNotFoundError(f"No xlsx file found in {self.base_dir}")
|
||||
|
||||
def _load_metadata(self) -> pd.DataFrame:
|
||||
# header=1: the xlsx has a title row first, then the row of column names (subject_ID, etc.)
|
||||
df = pd.read_excel(self.xlsx_path, header=1)
|
||||
df = df.rename(columns=self.COLUMN_MAPPING)
|
||||
df = df.fillna(0)
|
||||
return df
|
||||
|
||||
def _get_unified_subject_ids(self) -> np.ndarray:
|
||||
"""Subject IDs present in both 0_subject/ (from listdir) and metadata."""
|
||||
if not os.path.isdir(self.signal_dir):
|
||||
return np.array([], dtype=np.int64)
|
||||
# Files are named {subject_id}_{segment}.txt
|
||||
ids_from_files = set()
|
||||
for f in os.listdir(self.signal_dir):
|
||||
if f.endswith(".txt"):
|
||||
try:
|
||||
sid = int(f.replace(".txt", "").split("_")[0])
|
||||
ids_from_files.add(sid)
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
meta_ids = set(self.metadata["subject_ID"].astype(int).values)
|
||||
unified = sorted(ids_from_files & meta_ids)
|
||||
return np.array(unified)
|
||||
|
||||
def _get_train_test_split(self, ids: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""80% train, 20% test at subject level; shuffle fixed by seed."""
|
||||
if len(ids) == 0:
|
||||
return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
|
||||
perm = self.rng.permutation(len(ids))
|
||||
ids_perm = ids[perm]
|
||||
n_train = max(1, int(round(0.8 * len(ids))))
|
||||
train_ids = ids_perm[:n_train]
|
||||
test_ids = ids_perm[n_train:]
|
||||
return train_ids, test_ids
|
||||
|
||||
def _get_split_ids(self) -> np.ndarray:
|
||||
if self.split == "train":
|
||||
return self._train_ids
|
||||
elif self.split == "test":
|
||||
return self._test_ids
|
||||
elif self.split == "all":
|
||||
return self._all_subject_ids
|
||||
else:
|
||||
raise ValueError(f"Unknown split: {self.split}. Use 'train', 'test', or 'all'.")
|
||||
|
||||
@property
|
||||
def train_ids(self) -> np.ndarray:
|
||||
"""Subject IDs in the train split (80%)."""
|
||||
return self._train_ids
|
||||
|
||||
@property
|
||||
def test_ids(self) -> np.ndarray:
|
||||
"""Subject IDs in the test split (20%)."""
|
||||
return self._test_ids
|
||||
|
||||
def _load_signal(self, subject_id: int) -> np.ndarray:
|
||||
segments = []
|
||||
for s in range(1, self.num_segments + 1):
|
||||
filepath = os.path.join(self.signal_dir, f"{subject_id}_{s}.txt")
|
||||
signal = pd.read_csv(filepath, sep='\t', header=None)
|
||||
signal = signal.values.squeeze()
|
||||
if len(signal) > 1:
|
||||
signal = signal[:-1]
|
||||
segments.append(signal)
|
||||
return np.array(segments, dtype=object)
|
||||
|
||||
def _get_metadata_dict(self, subject_id: int) -> Dict:
|
||||
row = self.metadata[self.metadata['subject_ID'] == subject_id].iloc[0]
|
||||
return row.to_dict()
|
||||
|
||||
def _load_metadata_embedding(self, subject_id: int) -> np.ndarray:
|
||||
return load_ppgbp_metadata_embedding(self._embedding_root, subject_id)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.subject_ids)
|
||||
|
||||
def __iter__(self) -> 'PPGBPLoader':
|
||||
self._current_idx = 0
|
||||
if self.shuffle:
|
||||
self.rng.shuffle(self._indices)
|
||||
return self
|
||||
|
||||
def __next__(self) -> Union[
|
||||
Tuple[Dict, np.ndarray],
|
||||
Tuple[Dict, np.ndarray, np.ndarray],
|
||||
Tuple[List[Dict], List[np.ndarray]],
|
||||
Tuple[List[Dict], List[np.ndarray], List[np.ndarray]],
|
||||
]:
|
||||
if self._current_idx >= len(self.subject_ids):
|
||||
raise StopIteration
|
||||
|
||||
if self.batch_size is None:
|
||||
idx = self._indices[self._current_idx]
|
||||
subject_id = self.subject_ids[idx]
|
||||
metadata = self._get_metadata_dict(subject_id)
|
||||
signal = self._load_signal(subject_id)
|
||||
self._current_idx += 1
|
||||
if self.return_metadata_embeddings and self._embedding_root is not None:
|
||||
return metadata, signal, self._load_metadata_embedding(subject_id)
|
||||
return metadata, signal
|
||||
else:
|
||||
end_idx = min(self._current_idx + self.batch_size, len(self.subject_ids))
|
||||
batch_indices = self._indices[self._current_idx:end_idx]
|
||||
|
||||
metadata_list = []
|
||||
signal_list = []
|
||||
embedding_list = [] if (self.return_metadata_embeddings and self._embedding_root is not None) else None
|
||||
|
||||
for idx in batch_indices:
|
||||
subject_id = self.subject_ids[idx]
|
||||
metadata_list.append(self._get_metadata_dict(subject_id))
|
||||
signal_list.append(self._load_signal(subject_id))
|
||||
if self.return_metadata_embeddings and self._embedding_root is not None:
|
||||
embedding_list.append(self._load_metadata_embedding(subject_id))
|
||||
|
||||
self._current_idx = end_idx
|
||||
if embedding_list is not None:
|
||||
return metadata_list, signal_list, embedding_list
|
||||
return metadata_list, signal_list
|
||||
|
||||
def __getitem__(self, idx: int) -> Union[Tuple[Dict, np.ndarray], Tuple[Dict, np.ndarray, np.ndarray]]:
|
||||
if idx < 0 or idx >= len(self.subject_ids):
|
||||
raise IndexError(f"Index {idx} out of range")
|
||||
subject_id = self.subject_ids[idx]
|
||||
metadata = self._get_metadata_dict(subject_id)
|
||||
signal = self._load_signal(subject_id)
|
||||
if self.return_metadata_embeddings and self._embedding_root is not None:
|
||||
return metadata, signal, self._load_metadata_embedding(subject_id)
|
||||
return metadata, signal
|
||||
|
||||
def get_by_subject_id(self, subject_id: int) -> Union[Tuple[Dict, np.ndarray], Tuple[Dict, np.ndarray, np.ndarray]]:
|
||||
if subject_id not in self.subject_ids:
|
||||
raise ValueError(f"Subject ID {subject_id} not found in this split.")
|
||||
metadata = self._get_metadata_dict(subject_id)
|
||||
signal = self._load_signal(subject_id)
|
||||
if self.return_metadata_embeddings and self._embedding_root is not None:
|
||||
return metadata, signal, self._load_metadata_embedding(subject_id)
|
||||
return metadata, signal
|
||||
|
||||
def reset(self):
|
||||
self._current_idx = 0
|
||||
if self.shuffle:
|
||||
self.rng.shuffle(self._indices)
|
||||
|
||||
def iter_batches(
|
||||
self, batch_size: int
|
||||
) -> Iterator[Union[Tuple[List[Dict], List[np.ndarray]], Tuple[List[Dict], List[np.ndarray], List[np.ndarray]]]]:
|
||||
indices = np.arange(len(self.subject_ids))
|
||||
if self.shuffle:
|
||||
self.rng.shuffle(indices)
|
||||
|
||||
for start_idx in range(0, len(self.subject_ids), batch_size):
|
||||
end_idx = min(start_idx + batch_size, len(self.subject_ids))
|
||||
batch_indices = indices[start_idx:end_idx]
|
||||
|
||||
metadata_list = []
|
||||
signal_list = []
|
||||
embedding_list = [] if (self.return_metadata_embeddings and self._embedding_root is not None) else None
|
||||
|
||||
for idx in batch_indices:
|
||||
subject_id = self.subject_ids[idx]
|
||||
metadata_list.append(self._get_metadata_dict(subject_id))
|
||||
signal_list.append(self._load_signal(subject_id))
|
||||
if self.return_metadata_embeddings and self._embedding_root is not None:
|
||||
embedding_list.append(self._load_metadata_embedding(subject_id))
|
||||
|
||||
if embedding_list is not None:
|
||||
yield metadata_list, signal_list, embedding_list
|
||||
else:
|
||||
yield metadata_list, signal_list
|
||||
|
||||
|
||||
def get_loaders(
|
||||
base_dir: str,
|
||||
batch_size: Optional[int] = None,
|
||||
shuffle_train: bool = True,
|
||||
seed: int = 42
|
||||
) -> Tuple[PPGBPLoader, PPGBPLoader]:
|
||||
"""Convenience function to get train and test loaders (80/20 split)."""
|
||||
train_loader = PPGBPLoader(
|
||||
base_dir=base_dir, split="train",
|
||||
batch_size=batch_size, shuffle=shuffle_train, seed=seed
|
||||
)
|
||||
test_loader = PPGBPLoader(
|
||||
base_dir=base_dir, split="test",
|
||||
batch_size=batch_size, shuffle=False, seed=seed
|
||||
)
|
||||
return train_loader, test_loader
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("PPGBPLoader ready to use.")
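The subject-level 80/20 split in `_get_train_test_split` depends on a seeded shuffle, so two loaders built with the same seed agree on which subjects are in each split. The sketch below illustrates the idea with the standard library's `random.Random` swapped in for NumPy's generator, so the actual permutation differs from the loader's. `split_subject_ids` is a hypothetical name.

```python
import random
from typing import List, Sequence, Tuple


def split_subject_ids(
    ids: Sequence[int], seed: int = 42, train_frac: float = 0.8
) -> Tuple[List[int], List[int]]:
    """Shuffle subject IDs with a seeded generator and cut at train_frac."""
    if not ids:
        return [], []
    shuffled = list(ids)
    random.Random(seed).shuffle(shuffled)
    # At least one subject always lands in the train split.
    n_train = max(1, int(round(train_frac * len(shuffled))))
    return shuffled[:n_train], shuffled[n_train:]


ids = list(range(1, 11))
train_a, test_a = split_subject_ids(ids, seed=42)
train_b, test_b = split_subject_ids(ids, seed=42)

print(len(train_a), len(test_a))
print(train_a == train_b and test_a == test_b)
print(set(train_a) & set(test_a))
```

Splitting by subject instead of by sample keeps every recording from a given person on one side of the split. The `max(1, ...)` guard means a cohort of one subject yields an empty test split.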
|
||||
|
||||
@@ -151,4 +151,38 @@ class Queue:
|
||||
"avg_survival": sum(durations) / len(durations),
|
||||
"max_survival": max(durations),
|
||||
"avg_usage": sum(usages) / len(usages) if usages else 0
|
||||
}
|
||||
}
|
||||
|
||||
def update_with_recruiter(self, results, new_example_sets, L: int = 1):
|
||||
"""
|
||||
Evict the L worst sets (highest self-certainty score in results) and push up to L new sets.
|
||||
results: list of (case, response_dict_or_text, self_certainty_score) in queue order.
|
||||
new_example_sets: list of L new cases (each case = list of indices, same as queue items).
|
||||
Lower self_certainty = better; we evict the L with highest (worst) score.
|
||||
"""
|
||||
if not results or L <= 0:
|
||||
return
|
||||
current = list(self._queue)
|
||||
if len(current) == 0:
|
||||
for case in new_example_sets[: self._queue.maxlen]:
|
||||
self._queue.append(list(case))
|
||||
self._register_stats(case)
|
||||
return
|
||||
scores = []
|
||||
for i, item in enumerate(current):
|
||||
sc = -1.0
|
||||
if i < len(results) and len(results[i]) >= 3:
|
||||
sc = float(results[i][2])
|
||||
scores.append((i, sc))
|
||||
scores.sort(key=lambda x: x[1], reverse=True)
|
||||
to_evict_idx = {scores[j][0] for j in range(min(L, len(scores)))}
|
||||
new_queue = [current[i] for i in range(len(current)) if i not in to_evict_idx]
|
||||
for idx in sorted(to_evict_idx, reverse=True):
|
||||
if idx < len(current):
|
||||
self._record_eviction_stats(current[idx])
|
||||
for case in new_example_sets[:L]:
|
||||
if len(new_queue) >= self._queue.maxlen:
|
||||
break
|
||||
new_queue.append(list(case))
|
||||
self._register_stats(case)
|
||||
self._queue = deque(new_queue, maxlen=self._queue.maxlen)
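The evict-and-refill step in `update_with_recruiter` can be tried in isolation. The following is a minimal sketch, not the repository's code: `evict_and_refill` is a hypothetical standalone helper, it assumes `scores[i]` belongs to `queue[i]`, and it treats a higher score as worse, as the docstring above does.

```python
from collections import deque


def evict_and_refill(queue, scores, new_items, n_evict=1):
    """Drop the n_evict highest-scoring items, then append up to n_evict new ones.

    Hypothetical helper for illustration: scores[i] belongs to queue[i],
    and a higher score is treated as worse.
    """
    # Queue positions ordered from worst (highest score) to best.
    ranked = sorted(range(len(queue)), key=lambda i: scores[i], reverse=True)
    evicted = set(ranked[:n_evict])
    kept = [item for i, item in enumerate(queue) if i not in evicted]
    for item in new_items[:n_evict]:
        if queue.maxlen is not None and len(kept) >= queue.maxlen:
            break  # never grow past the queue capacity
        kept.append(item)
    return deque(kept, maxlen=queue.maxlen)


q = deque([[0, 1], [2, 3], [4, 5]], maxlen=3)
result = evict_and_refill(q, scores=[0.2, 0.9, 0.5], new_items=[[6, 7]], n_evict=1)
print(list(result))
```

Here the middle set has the highest score, so it is the one replaced; the surviving sets keep their relative order and the new set goes to the back.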
|
||||
263
sc/core/model.py
@@ -1,153 +1,160 @@
|
||||
import os
|
||||
"""
|
||||
HuggingFace Causal LM wrapper for inference with logits.
|
||||
|
||||
Loads a model via transformers, generates text, and returns both
|
||||
the generated text and the per-step scores collected during generation.
|
||||
|
||||
Message format: [{"role": "user"|"assistant"|"system", "content": str}, ...]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import requests
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_together import ChatTogether
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# {"role": "user"|"assistant"|"system", "content": str}
|
||||
ChatMessage = Dict[str, str]
|
||||
|
||||
|
||||
def load_models(models, temperature=0.0, num_ctx=15000):
|
||||
model_pool = AsyncModelPool()
|
||||
for model in models:
|
||||
model_pool.add_model(Model(model, temperature=temperature, num_ctx=num_ctx))
|
||||
model_pool.init_models()
|
||||
return model_pool
|
||||
# -----------------------------------------------------------------------------
|
||||
# Model
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, model, temperature, num_ctx):
|
||||
if model.startswith("ollama:"):
|
||||
model = model.replace("ollama:", "")
|
||||
if "url:" in model: # custom parsing for local ollama models
|
||||
model = model.replace("url:", "")
|
||||
base_url = model.split("/")[0]
|
||||
model_type = model.split("/")[1]
|
||||
# self.model = ChatOllama(
|
||||
# model=model_type,
|
||||
# base_url=f"http://{base_url}",
|
||||
# temperature=temperature,
|
||||
# num_ctx=num_ctx,
|
||||
# )
|
||||
self.model = None
|
||||
self.base_url = f"http://{base_url}/api/chat"
|
||||
self.model_type = model_type
|
||||
self.temperature = temperature
|
||||
self.num_ctx = num_ctx
|
||||
else:
|
||||
self.model = ChatOllama(
|
||||
model=model.replace("ollama:", ""),
|
||||
temperature=temperature,
|
||||
num_ctx=num_ctx,
|
||||
)
|
||||
elif model.startswith("together"):
|
||||
if "TOGETHER_API_KEY" not in os.environ:
|
||||
print("[!] TOGETHER_API_KEY is not set")
|
||||
assert 0
|
||||
self.model = ChatTogether(
|
||||
model=model.replace("together:", ""),
|
||||
temperature=temperature,
|
||||
max_tokens=num_ctx,
|
||||
max_retries=3,
|
||||
)
|
||||
elif model.startswith("openai"):
|
||||
if "OPENAI_API_KEY" not in os.environ:
|
||||
print("[!] OPENAI_API_KEY is not set")
|
||||
assert 0
|
||||
self.model = ChatOpenAI(
|
||||
model=model.replace("openai:", ""),
|
||||
temperature=temperature,
|
||||
)
|
||||
else:
|
||||
self.model = init_chat_model(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
"""HuggingFace causal LM. Returns generated text + full logits."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
temperature: float = 0.0,
|
||||
max_new_tokens: int = 1024,
|
||||
enable_thinking: bool = False,
|
||||
) -> None:
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
||||
self._model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
).eval()
|
||||
self.temperature = temperature
|
||||
self.max_new_tokens = max_new_tokens
|
||||
self.enable_thinking = enable_thinking
|
||||
|
||||
def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
|
||||
"""
|
||||
Generate a reply and return (text, logits).
|
||||
|
||||
Returns:
|
||||
text: Generated text string.
|
||||
logits: Logits tensor (1, num_generated_tokens, vocab_size) on CPU.
|
||||
"""
|
||||
# Build prompt from chat messages
|
||||
prompt = self._tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=self.enable_thinking,
|
||||
)
|
||||
inputs = self._tokenizer(prompt, return_tensors="pt").to(self._model.device)
|
||||
prompt_len = inputs["input_ids"].shape[1]
|
||||
|
||||
# Generate with scores in a single pass
|
||||
with torch.inference_mode():
|
||||
gen_out = self._model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
do_sample=self.temperature > 0,
|
||||
temperature=self.temperature if self.temperature > 0 else 1.0,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
)
|
||||
|
||||
def invoke(self, messages, logprobs=False, top_logprobs=0):
|
||||
try:
|
||||
if self.model:
|
||||
response = self.model.invoke(messages)
|
||||
return response
|
||||
else:
|
||||
converted_messages = []
|
||||
for msg in messages:
|
||||
role = msg.type
|
||||
role = "user" if role == "human" else "assistant"
|
||||
content = msg.content
|
||||
converted_messages.append({"role": role, "content": content})
|
||||
response = requests.post(self.base_url, json={
|
||||
"model": self.model_type,
|
||||
"messages": converted_messages,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"temperature": self.temperature,
|
||||
"num_ctx": self.num_ctx,
|
||||
},
|
||||
"logprobs": logprobs,
|
||||
"top_logprobs": top_logprobs,
|
||||
})
|
||||
response = response.json()
|
||||
resp_msg = AIMessage(content=response["message"]["content"])
|
||||
if logprobs:
|
||||
return resp_msg, response["logprobs"]
|
||||
else:
|
||||
return resp_msg
|
||||
return resp_msg, response["logprobs"]
|
||||
except Exception as e:
|
||||
print(f"[Error] Error occurred while invoking LLM: {e}")
|
||||
return e
|
||||
# Decode generated tokens
|
||||
gen_ids = gen_out.sequences[0][prompt_len:]
|
||||
text = self._tokenizer.decode(gen_ids, skip_special_tokens=True)
|
||||
|
||||
# Strip internal thinking tags (e.g. "analysis...assistantfinal...")
|
||||
if "assistantfinal" in text:
|
||||
text = text.split("assistantfinal", 1)[1].strip()
|
||||
|
||||
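# Stack per-step scores into a (1, num_generated_tokens, vocab_size) tensor.
# Note: `scores` are the values after logits processors (e.g. temperature);
# recent transformers versions expose raw logits via output_logits=True.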
|
||||
logits = torch.stack(gen_out.scores, dim=1) # scores: tuple of (batch, vocab)
|
||||
|
||||
return text, logits.cpu()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Async wrappers
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AsyncModel:
|
||||
def __init__(self, model):
|
||||
"""Runs Model.invoke in a thread executor for async usage."""
|
||||
|
||||
def __init__(self, model: Model) -> None:
|
||||
self.model = model
|
||||
|
||||
async def invoke(self, content, logprobs=False, top_logprobs=0):
|
||||
loop = asyncio.get_event_loop()
|
||||
if logprobs:
|
||||
response, logprobs = await loop.run_in_executor(
|
||||
async def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.model.invoke(content, logprobs=logprobs, top_logprobs=top_logprobs),
|
||||
)
|
||||
return response, logprobs
|
||||
else:
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.model.invoke(content),
|
||||
)
|
||||
return response
|
||||
lambda: self.model.invoke(messages),
|
||||
)
|
||||
|
||||
|
||||
class AsyncModelPool:
|
||||
def __init__(self):
|
||||
self.models = []
|
||||
self._available_models = None
|
||||
self._model_semaphore = None
|
||||
"""Round-robin pool of async models."""
|
||||
|
||||
def add_model(self, model):
|
||||
def __init__(self) -> None:
|
||||
self.models: List[Model] = []
|
||||
self._queue: Optional[asyncio.Queue[AsyncModel]] = None
|
||||
|
||||
def add_model(self, model: Model) -> None:
|
||||
self.models.append(model)
|
||||
|
||||
def init_models(self):
|
||||
print(f"Initializing {len(self.models)} models...")
|
||||
self._available_models = asyncio.Queue()
|
||||
def init_models(self) -> None:
|
||||
print(f"Initializing {len(self.models)} model(s)...")
|
||||
self._queue = asyncio.Queue()
|
||||
for model in self.models:
|
||||
async_model = AsyncModel(model)
|
||||
self._available_models.put_nowait(async_model)
|
||||
self._model_semaphore = asyncio.Semaphore(len(self.models))
|
||||
self._queue.put_nowait(AsyncModel(model))
|
||||
|
||||
async def invoke(self, content, logprobs=False, top_logprobs=0):
|
||||
if self._available_models is None:
|
||||
raise RuntimeError("Model pool not initialized. Call init_models() first.")
|
||||
async_model = await self._available_models.get()
|
||||
async def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
|
||||
"""Acquire a model, invoke, release. Returns (text, logits)."""
|
||||
if self._queue is None:
|
||||
raise RuntimeError("Call init_models() first.")
|
||||
|
||||
async_model = await self._queue.get()
|
||||
try:
|
||||
if logprobs:
|
||||
response, logprobs = await async_model.invoke(content, logprobs=logprobs, top_logprobs=top_logprobs)
|
||||
return response, logprobs
|
||||
else:
|
||||
response = await async_model.invoke(content)
|
||||
return response
|
||||
return await async_model.invoke(messages)
|
||||
finally:
|
||||
self._available_models.put_nowait(async_model)
|
||||
self._queue.put_nowait(async_model)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Public API
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_models(
|
||||
models: List[str],
|
||||
temperature: float = 0.0,
|
||||
max_new_tokens: int = 1024,
|
||||
enable_thinking: bool = False,
|
||||
) -> AsyncModelPool:
|
||||
"""Create and initialize a model pool from a list of model paths."""
|
||||
pool = AsyncModelPool()
|
||||
for path in models:
|
||||
pool.add_model(
|
||||
Model(
|
||||
path,
|
||||
temperature=temperature,
|
||||
max_new_tokens=max_new_tokens,
|
||||
enable_thinking=enable_thinking,
|
||||
)
|
||||
)
|
||||
pool.init_models()
|
||||
return pool
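The acquire, invoke, release pattern used by `AsyncModelPool.invoke` can be sketched without loading any model. This is only an illustration under stated assumptions: `EchoModel` and `TinyPool` are hypothetical stand-ins, and the blocking call is a trivial string format rather than a HuggingFace `generate`.

```python
import asyncio


class EchoModel:
    """Hypothetical stand-in for a blocking model; not part of the repository."""

    def __init__(self, name):
        self.name = name

    def invoke(self, messages):
        return f"{self.name}:{messages[-1]['content']}"


class TinyPool:
    """Hands out models from an asyncio.Queue and returns them after use."""

    def __init__(self, models):
        self._queue = asyncio.Queue()
        for m in models:
            self._queue.put_nowait(m)

    async def invoke(self, messages):
        model = await self._queue.get()  # wait until some model is free
        try:
            loop = asyncio.get_running_loop()
            # Run the blocking call in the default thread pool.
            return await loop.run_in_executor(None, lambda: model.invoke(messages))
        finally:
            self._queue.put_nowait(model)  # always hand the model back


async def main():
    pool = TinyPool([EchoModel("m0"), EchoModel("m1")])
    prompts = [[{"role": "user", "content": f"q{i}"}] for i in range(4)]
    return await asyncio.gather(*(pool.invoke(p) for p in prompts))


replies = asyncio.run(main())
print(replies)
```

With two models and four prompts, at most two calls run at once; `asyncio.gather` returns replies in prompt order, while which model serves which prompt may vary between runs.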
|
||||
|
||||
487
sc/core/recruiter_agent.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""
|
||||
RecruiterAgent: MAB-based example set selection for personalized ICL.
|
||||
|
||||
Components:
|
||||
- self_certainty(logits): KL-div of LLM output from uniform (reward signal)
|
||||
- borda_vote(results): aggregate answers across example sets
|
||||
- FeatureIndex: loads index.json + embeddings.npy for O(1) lookup
|
||||
- MABSelector: contextual bandit adapted from CASE/sample_case.py top_m_arm
|
||||
- ArmPool: manages candidate example sets and feature vectors
|
||||
- RecruiterAgent: orchestrates everything
|
||||
|
||||
Feature file schema:
|
||||
<feature_root>/index.json + <feature_root>/embeddings.npy
|
||||
|
||||
index.json:
|
||||
{
|
||||
"version": "1.0",
|
||||
"embedding_type": "sbert_metadata",
|
||||
"embedding_dim": 384,
|
||||
"samples": [
|
||||
{"row": 0, "user_id": "001", "sample_id": 42, "label": "Hypertension", "session": "1"},
|
||||
...
|
||||
]
|
||||
}
|
||||
embeddings.npy: float32 array shape (N_samples, embedding_dim)
|
||||
Switch embedding type by pointing to a different feature_root directory.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import collections, json, math, os, random
|
||||
from random import sample as random_sample
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
_TORCH = True
|
||||
except ImportError:
|
||||
_TORCH = False
|
||||
|
||||
|
||||
# ── Utilities ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def self_certainty(logits) -> float:
    """KL divergence of LLM logits from uniform. Lower = more certain. Negate for MAB."""
    if logits is None or not _TORCH:
        return 0.0
    logits = logits.squeeze(0).float()
    log_probs = F.log_softmax(logits, dim=-1)
    score = (-1.0 / logits.shape[-1]) * log_probs.sum(dim=-1) - math.log(logits.shape[-1])
    return score.mean().item()
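To see what the expression in `self_certainty` evaluates to, here is a plain-Python sketch of the same per-token formula, `(-1/V) * sum(log p) - log V`, which equals KL(U || p) for a uniform U over a vocabulary of size V. The function name `kl_uniform_to_softmax` and the toy logits are assumptions made for illustration.

```python
import math


def kl_uniform_to_softmax(logits):
    """KL(U || softmax(logits)) for a single token position, in plain Python."""
    v = len(logits)
    m = max(logits)
    # Numerically stable log-sum-exp for the softmax normalizer.
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    return -sum(log_probs) / v - math.log(v)


flat = kl_uniform_to_softmax([0.0, 0.0, 0.0, 0.0])     # uniform distribution
peaked = kl_uniform_to_softmax([10.0, 0.0, 0.0, 0.0])  # one dominant token
print(flat, peaked)
```

Under this formula a uniform distribution scores approximately zero and a sharply peaked one scores higher. That may be worth checking against the "lower = more certain" convention stated in the docstring and against the sign used when the score is negated into a reward.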
|
||||
|
||||
|
||||
def borda_vote(
    results: List[Tuple[int, Optional[Dict[str, Any]], float]],
    valid_classes: Optional[List[str]] = None,
) -> Optional[str]:
    """
    Borda-count vote over (arm_idx, response_dict, certainty_score).
    Lower certainty_score = more certain = higher Borda rank.
    """
    valid = [
        (r[1]["ANSWER"], r[2]) for r in results
        if r[1] is not None and "ANSWER" in r[1]
        and (valid_classes is None or r[1]["ANSWER"] in valid_classes)
    ]
    if not valid:
        return None
    sorted_v = sorted(valid, key=lambda x: x[1])
    M = len(sorted_v)
    borda: Dict[str, float] = {}
    for rank, (ans, _) in enumerate(sorted_v):
        borda[ans] = borda.get(ans, 0.0) + (M - rank)
    top = max(borda.values())
    cands = [a for a, s in borda.items() if s == top]
    if len(cands) == 1:
        return cands[0]
    means = {a: float(np.mean([sc for x, sc in valid if x == a])) for a in cands}
    return min(means, key=means.get)
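A small worked example may help show how the Borda points in `borda_vote` add up. The sketch below is a simplified re-implementation over `(answer, score)` pairs without the tie-break step; the sleep-stage labels and scores are made up for illustration.

```python
def borda(results):
    """Toy Borda count over (answer, score) pairs; a lower score ranks higher."""
    ranked = sorted(results, key=lambda r: r[1])
    m = len(ranked)
    points = {}
    for rank, (answer, _) in enumerate(ranked):
        # The best-ranked vote earns m points, the next m - 1, and so on.
        points[answer] = points.get(answer, 0) + (m - rank)
    return points


votes = [("Wake", 0.10), ("N2", 0.30), ("N2", 0.40), ("REM", 0.90)]
points = borda(votes)
winner = max(points, key=points.get)
print(points, winner)
```

The single best-ranked vote for "Wake" earns 4 points, but the two middle-ranked votes for "N2" earn 3 + 2 = 5, so agreement across example sets can outweigh one highly ranked answer.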
|
||||
|
||||
|
||||
# ── FeatureIndex ───────────────────────────────────────────────────────────────
|
||||
|
||||
class FeatureIndex:
|
||||
"""Loads embeddings from disk; O(1) lookup by (user_id, sample_id)."""
|
||||
|
||||
def __init__(self, feature_root: str) -> None:
|
||||
idx_p = os.path.join(feature_root, "index.json")
|
||||
emb_p = os.path.join(feature_root, "embeddings.npy")
|
||||
if not os.path.isfile(idx_p):
|
||||
raise FileNotFoundError(f"index.json not found: {idx_p}")
|
||||
if not os.path.isfile(emb_p):
|
||||
raise FileNotFoundError(f"embeddings.npy not found: {emb_p}")
|
||||
with open(idx_p, "r", encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
self.embedding_type: str = meta.get("embedding_type", "unknown")
|
||||
self.embedding_dim: int = meta["embedding_dim"]
|
||||
self.embeddings: np.ndarray = np.load(emb_p).astype(np.float32)
|
||||
self.samples: List[Dict[str, Any]] = meta["samples"]
|
||||
self._lookup: Dict[Tuple[str, int], int] = {
|
||||
(str(e["user_id"]), int(e["sample_id"])): int(e["row"])
|
||||
for e in self.samples
|
||||
}
|
||||
print(f"[FeatureIndex] {len(self.embeddings)} embeddings "
|
||||
f"(dim={self.embedding_dim}, type={self.embedding_type})")
|
||||
|
||||
def get(self, user_id: str, sample_id: int) -> Optional[np.ndarray]:
|
||||
row = self._lookup.get((str(user_id), int(sample_id)))
|
||||
return self.embeddings[row] if row is not None else None
|
||||
|
||||
def get_user_embedding(self, user_id: str) -> Optional[np.ndarray]:
|
||||
"""Mean embedding across all samples for a user."""
|
||||
rows = [self.embeddings[e["row"]] for e in self.samples
|
||||
if str(e["user_id"]) == str(user_id)]
|
||||
return np.mean(rows, axis=0).astype(np.float32) if rows else None
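The `(user_id, sample_id) -> row` lookup and the per-user mean can be illustrated with an in-memory index. This is a sketch under assumptions: the three samples, the "Normal" label, and the 2-dimensional embeddings are invented, and plain lists stand in for the `embeddings.npy` array.

```python
index = {
    "embedding_dim": 2,
    "samples": [
        {"row": 0, "user_id": "001", "sample_id": 42, "label": "Hypertension"},
        {"row": 1, "user_id": "001", "sample_id": 43, "label": "Normal"},
        {"row": 2, "user_id": "002", "sample_id": 7, "label": "Normal"},
    ],
}
embeddings = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # stand-in for embeddings.npy

# Same key normalization as FeatureIndex: user_id as str, sample_id as int.
lookup = {
    (str(s["user_id"]), int(s["sample_id"])): int(s["row"])
    for s in index["samples"]
}


def get(user_id, sample_id):
    """Return the embedding row for one sample, or None if it is unknown."""
    row = lookup.get((str(user_id), int(sample_id)))
    return embeddings[row] if row is not None else None


def user_mean(user_id):
    """Mean embedding over all samples of one user, or None if there are none."""
    rows = [embeddings[s["row"]] for s in index["samples"]
            if str(s["user_id"]) == str(user_id)]
    if not rows:
        return None
    return [sum(col) / len(rows) for col in zip(*rows)]


print(get("001", 43), get("003", 1), user_mean("001"))
```

Normalizing the key types at build time and at query time means a caller can pass `sample_id` as either `"43"` or `43` and still hit the same row.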
|
||||
|
||||
|
||||
# ── MABSelector ────────────────────────────────────────────────────────────────
|
||||
|
||||
class MABSelector:
|
||||
"""
|
||||
Contextual bandit for top-m arm identification.
|
||||
Adapted from CASE / sample_case.py top_m_arm.
|
||||
|
||||
Key differences from original:
|
||||
- Rewards injected via observe(arm, reward) — no LLM calls, no file IO
|
||||
- update() does one rank-1 theta update per call
|
||||
- No initialization() warm-up loop required
|
||||
|
||||
X: np.ndarray shape (embedding_dim, num_arms)
|
||||
m: number of top arms to track (= queue_size)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, X: np.ndarray, m: int,
|
||||
n_das: int = 5, delta: float = 0.05,
|
||||
epsilon: float = 0.11, sigma: float = 0.5,
|
||||
) -> None:
|
||||
assert X.ndim == 2, "X must be (embedding_dim, num_arms)"
|
||||
self.X = X.astype(np.float64)
|
||||
self.N, self.K = X.shape
|
||||
self.m = min(m, self.K)
|
||||
self.n_das = n_das
|
||||
self.delta = delta
|
||||
self.epsilon = epsilon
|
||||
self.sigma = sigma
|
||||
self.arms = list(range(self.K))
|
||||
self._it = -1
|
||||
self._Bdi: Dict[Tuple[int, int], float] = {}
|
||||
self._reset()
|
||||
|
||||
def _reset(self) -> None:
|
||||
self.t = 0
|
||||
self.rewards: List[float] = []
|
||||
self.pulled_arms: List[int] = []
|
||||
self.means = np.zeros(self.K)
|
||||
self.na = np.zeros(self.K)
|
||||
self.B_inv = np.eye(self.N, dtype=np.float64)
|
||||
self.b = np.zeros(self.N, dtype=np.float64)
|
||||
self.theta = np.random.normal(0, 1, size=(1, self.N))
|
||||
self.J: List[int] = []
|
||||
self.notJ: List[int] = list(range(self.K))
|
||||
self.N_t: List[int] = []
|
||||
self.best_arm: Optional[int] = None
|
||||
self.worst_arm: Optional[int] = None
|
||||
self.challenger: Optional[int] = None
|
||||
self.c_t: Optional[int] = None
|
||||
self.patience = 0
|
||||
self.previous_J: List[int] = []
|
||||
|
||||
# ── math helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
def _mn(self, x: np.ndarray, A: np.ndarray) -> float:
|
||||
x = np.asarray(x).flatten()
|
||||
return float(np.sqrt(max(float(x @ A @ x), 0.0)))
|
||||
|
||||
def _iu(self, A: np.ndarray, x: np.ndarray) -> np.ndarray:
|
||||
"""Rank-1 inverse update: A - A x xT A / (1 + xT A x)."""
|
||||
x = np.asarray(x).flatten()
|
||||
d = 1.0 + self._mn(x, A) ** 2
|
||||
return A - np.outer(A @ x, x @ A) / d
|
||||
|
||||
def _beta(self) -> float:
|
||||
return math.log((math.log(max(self.t, 2)) + 1) / max(self.delta, 1e-12))
|
||||
|
||||
def _gap(self, i: int, j: Optional[int] = None) -> float:
|
||||
xi = self.X[:, i] - (self.X[:, j] if j is not None else 0.0)
|
||||
return float(self.theta.flatten() @ xi)
|
||||
|
||||
def _var(self, i: int, j: Optional[int] = None) -> float:
|
||||
S = (self.sigma ** 2) * self.B_inv
|
||||
xi = self.X[:, i] - (self.X[:, j] if j is not None else 0.0)
|
||||
return self._mn(xi, S) * math.sqrt(2 * self._beta())
|
||||
|
||||
def _Bij(self, i: int, j: int) -> float:
|
||||
t = max(self.t, 1)
|
||||
if t != self._it:
|
||||
self._Bdi = {}
|
||||
self._it = t
|
||||
k = (i, j)
|
||||
if k not in self._Bdi:
|
||||
self._Bdi[k] = self._gap(i, j) + self._var(i, j)
|
||||
return self._Bdi[k]
|
||||
|
||||
def _randf(self, x: List[float], f) -> int:
|
||||
arr = np.array(x, dtype=float)
|
||||
val = f(arr)
|
||||
c = np.argwhere(arr == val).flatten().tolist()
|
||||
return random_sample(c, 1)[0]
|
||||
|
||||
def _mmax(self, x: List[float], m: int) -> List[int]:
|
||||
x = list(x)
|
||||
ids = []
|
||||
for _ in range(min(m, len(x))):
|
||||
idx = self._randf(x, np.max)
|
||||
ids.append(idx)
|
||||
x[idx] = -float("inf")
|
||||
return ids
|
||||
|
||||
# ── sample step ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _sel(self, nt: Optional[int]) -> Optional[int]:
|
||||
if nt is None or self.c_t is None:
|
||||
return None
|
||||
if self.means[self.c_t] > self.means[nt]:
|
||||
s = self.c_t
|
||||
if s in self.N_t:
|
||||
self.N_t.remove(s)
|
||||
return s
|
||||
return None
|
||||
|
||||
def _Jt(self) -> List[int]:
|
||||
if self.t == 0:
|
||||
return self._mmax(self.means.tolist(), self.m)
|
||||
nt = self.worst_arm
|
||||
sel = self._sel(nt)
|
||||
self.previous_J = list(self.J)
|
||||
if sel is not None:
|
||||
self.J = [a for a in self.J if a != nt] + [sel]
|
||||
if nt is not None and nt not in self.N_t:
|
||||
self.N_t.append(nt)
|
||||
order = np.argsort(self.means[self.J])[::-1].tolist()
|
||||
self.J = [self.J[i] for i in order]
|
||||
self.patience = 0 if self.J != self.previous_J else self.patience + 1
|
||||
return self.J
|
||||
|
||||
def sample(self) -> List[int]:
|
||||
"""Select next arm(s) to evaluate. Returns list of arm indices."""
|
||||
self.J = self._Jt()
|
||||
self.notJ = [a for a in self.arms if a not in self.J]
|
||||
|
||||
if self.t == 0 or not self.N_t:
|
||||
n = min(self.n_das, len(self.notJ))
|
||||
self.N_t = list(np.random.choice(self.notJ, n, replace=False)) if self.notJ else []
|
||||
else:
|
||||
Qt = list(np.random.choice(self.notJ, min(self.n_das, len(self.notJ)), replace=False)) if self.notJ else []
|
||||
self.N_t = list({*self.N_t, *Qt} - set(self.J))
|
||||
if self.N_t:
|
||||
top_n = min(self.n_das, len(self.N_t))
|
||||
tv = np.sort(self.means[self.N_t])[::-1][:top_n]
|
||||
picked: List[int] = []
|
||||
for mv in tv:
|
||||
c = [a for a in self.N_t if self.means[a] == mv and a not in picked]
|
||||
if c:
|
||||
picked.append(random.choice(c))
|
||||
self.N_t = picked
|
||||
|
||||
if self.N_t and self.J:
|
||||
jm = self.means[self.J]
|
||||
ntc = [a for a in self.J if self.means[a] == jm.min()]
|
||||
self.worst_arm = random.choice(ntc)
|
||||
bti = [self._Bij(a, self.worst_arm) for a in self.J]
|
||||
self.best_arm = self.J[self._randf(bti, np.max)]
|
||||
chi = [self._Bij(a, self.best_arm) for a in self.N_t]
|
||||
self.challenger = self.N_t[self._randf(chi, np.max)]
|
||||
nm = self.means[self.N_t]
|
||||
c = [a for a in self.N_t if self.means[a] == nm.max()]
|
||||
self.c_t = random.choice(c)
|
||||
else:
|
||||
self.worst_arm = self.J[0] if self.J else None
|
||||
self.best_arm = self.challenger = self.c_t = None
|
||||
|
||||
# Greedy arm pull
|
||||
if self.best_arm is not None and self.challenger is not None:
|
||||
d = self.X[:, self.best_arm] - self.X[:, self.challenger]
|
||||
u = [self._mn(d, self._iu(self.B_inv, self.X[:, i])) for i in self.arms]
|
||||
return [self.arms[self._randf(u, np.min)]]
|
||||
return [random.choice(self.arms)]
|
||||
|
||||
def observe(self, arm: int, reward: float) -> None:
|
||||
"""Record reward for arm."""
|
||||
self.rewards.append(reward)
|
||||
self.pulled_arms.append(arm)
|
||||
self.na[arm] += 1
|
||||
self.t += 1
|
||||
|
||||
def update(self) -> None:
|
||||
"""Rank-1 theta/B_inv update from the last observation."""
|
||||
if not self.pulled_arms:
|
||||
return
|
||||
x = self.X[:, self.pulled_arms[-1]].flatten()
|
||||
self.B_inv = self._iu(self.B_inv, x)
|
||||
self.b += self.rewards[-1] * x
|
||||
self.theta = (self.B_inv @ self.b).reshape(1, self.N)
|
||||
self.means = (self.theta @ self.X).flatten()
|
||||
|
||||
def recommend(self, n: int) -> List[int]:
|
||||
"""Return n arm indices with highest estimated mean rewards."""
|
||||
return self._mmax(self.means.tolist(), min(n, self.K))
|
||||
|
||||
def stopping_rule(self) -> bool:
|
||||
if None in (self.best_arm, self.worst_arm, self.challenger):
|
||||
return False
|
||||
return round(self._Bij(self.challenger, self.best_arm), 2) <= self.epsilon
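The rank-1 inverse update in `_iu`, which `update()` applies to `B_inv`, is the Sherman-Morrison identity. The sketch below checks it on a 2x2 case in plain Python; the helper names are hypothetical, and it assumes a symmetric `A_inv`, which holds when starting from the identity as `_reset` does.

```python
def matvec(A, x):
    return [sum(a * b for a, b in zip(row, x)) for row in A]


def rank1_inverse_update(A_inv, x):
    """Sherman-Morrison: inverse of (A + x x^T), given a symmetric A_inv."""
    Ax = matvec(A_inv, x)
    denom = 1.0 + sum(xi * axi for xi, axi in zip(x, Ax))
    n = len(x)
    # For symmetric A_inv, (A_inv x)(x^T A_inv) is the outer product of Ax with itself.
    return [[A_inv[i][j] - Ax[i] * Ax[j] / denom for j in range(n)]
            for i in range(n)]


def inverse_2x2(M):
    (a, b), (c, d) = M
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


x = [1.0, 2.0]
identity = [[1.0, 0.0], [0.0, 1.0]]
updated = rank1_inverse_update(identity, x)
# Direct inverse of I + x x^T = [[2, 2], [2, 5]] for comparison.
direct = inverse_2x2([[2.0, 2.0], [2.0, 5.0]])
print(updated, direct)
```

Both routes give the same matrix, which is why the selector can keep `B_inv` up to date at O(d^2) cost per observation instead of re-inverting a d x d matrix each time.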
|
||||
|
||||
|
||||
# ── ArmPool ────────────────────────────────────────────────────────────────────
|
||||
# TODO: Add a way to sample arms from the larger pool with replacement, instead of fixing the pool at num_arms arms.
|
||||
class ArmPool:
|
||||
"""
|
||||
Manages candidate example sets (arms).
|
||||
Each arm = list of sample dicts (N-way x k_shot).
|
||||
Arm feature = mean embedding of its constituent samples.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source_samples: List[Dict[str, Any]],
|
||||
classes: List[str],
|
||||
k_shot: int,
|
||||
feature_index: Optional[FeatureIndex],
|
||||
num_arms: int,
|
||||
rng: np.random.Generator,
|
||||
) -> None:
|
||||
self.classes = classes
|
||||
self.k_shot = k_shot
|
||||
self.feature_index = feature_index
|
||||
self._rng = rng
|
||||
self._by: Dict[str, List[Dict[str, Any]]] = {c: [] for c in classes}
|
||||
for s in source_samples:
|
||||
lbl = s.get("label")
|
||||
if lbl in self._by:
|
||||
self._by[lbl].append(s)
|
||||
self.arms: List[List[Dict[str, Any]]] = []
|
||||
self.arm_feats: List[Optional[np.ndarray]] = []
|
||||
for _ in range(num_arms):
|
||||
arm, feat = self._make()
|
||||
self.arms.append(arm)
|
||||
self.arm_feats.append(feat)
|
||||
print(f"[ArmPool] {len(self.arms)} arms ({len(classes)}-way {k_shot}-shot)")
|
||||
|
||||
def _make(self) -> Tuple[List[Dict[str, Any]], Optional[np.ndarray]]:
|
||||
s: List[Dict[str, Any]] = []
|
||||
for c in self.classes:
|
||||
pool = self._by.get(c, [])
|
||||
if pool:
|
||||
n = min(self.k_shot, len(pool))
|
||||
idxs = self._rng.choice(len(pool), size=n, replace=False)
|
||||
s.extend([pool[i] for i in idxs])
|
||||
return s, self._feat(s)
|
||||
|
||||
def _feat(self, samples: List[Dict[str, Any]]) -> Optional[np.ndarray]:
|
||||
if self.feature_index is None:
|
||||
return None
|
||||
vecs = []
|
||||
for s in samples:
|
||||
v = self.feature_index.get(
|
||||
str(s.get("user_id", "")),
|
||||
int(s.get("sample_id", s.get("idx", -1))),
|
||||
)
|
||||
if v is not None:
|
||||
vecs.append(v)
|
||||
return np.mean(vecs, axis=0).astype(np.float32) if vecs else None
|
||||
|
||||
def get_arm(self, i: int) -> List[Dict[str, Any]]:
|
||||
return self.arms[i]
|
||||
|
||||
def build_X(self) -> np.ndarray:
|
||||
"""Feature matrix X of shape (embedding_dim, num_arms). Falls back to identity."""
|
||||
feats = [f for f in self.arm_feats if f is not None]
|
||||
if not feats:
|
||||
return np.eye(len(self.arms), dtype=np.float32)
|
||||
X = np.zeros((feats[0].shape[0], len(self.arms)), dtype=np.float32)
|
||||
for i, f in enumerate(self.arm_feats):
|
||||
if f is not None:
|
||||
X[:, i] = f
|
||||
return X
|
||||
|
||||
|
||||
# ── RecruiterAgent ────────────────────────────────────────────────────────────
|
||||
|
||||
class RecruiterAgent:
|
||||
"""
|
||||
MAB-driven recruiter for personalized ICL example set selection.
|
||||
|
||||
Usage:
|
||||
recruiter = RecruiterAgent(source_samples, classes, ...)
|
||||
initial = recruiter.recruit(queue_size) # cold start (random)
|
||||
|
||||
for test_sample in target_stream:
|
||||
results = [(arm_idx, response_dict, score), ...]
|
||||
recruiter.update(results) # MAB update lives here
|
||||
new_sets = recruiter.recruit(L) # MAB-guided new sets
|
||||
queue.update_with_recruiter(results, new_sets)
|
||||
|
||||
MAB update logic inside update():
|
||||
1. reward = -score (negate: lower certainty = better = higher reward)
|
||||
2. mab.observe(arm, reward)
|
||||
3. mab.update() -> rank-1 theta/B_inv update
|
||||
4. mab.sample() -> select next arms to explore (stored for next recruit())
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source_samples: List[Dict[str, Any]],
|
||||
classes: List[str],
|
||||
feature_index: Optional[FeatureIndex] = None,
|
||||
target_user_id: Optional[str] = None,
|
||||
k_shot: int = 1,
|
||||
queue_size: int = 5,
|
||||
num_arms: int = 50,
|
||||
n_das: int = 5,
|
||||
delta: float = 0.05,
|
||||
epsilon: float = 0.11,
|
||||
sigma: float = 0.5,
|
||||
seed: int = 42,
|
||||
) -> None:
|
||||
self.classes = classes
|
||||
self.k_shot = k_shot
|
||||
self.queue_size = queue_size
|
||||
self.target_user_id = target_user_id
|
||||
|
||||
rng = np.random.default_rng(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
self.pool = ArmPool(source_samples, classes, k_shot, feature_index, num_arms, rng)
|
||||
self.mab = MABSelector(
|
||||
self.pool.build_X(), m=queue_size,
|
||||
n_das=n_das, delta=delta, epsilon=epsilon, sigma=sigma,
|
||||
)
|
||||
self._next: Optional[List[int]] = None
|
||||
self._initialized = False
|
||||
print(f"[RecruiterAgent] {len(self.pool.arms)} arms, "
|
||||
f"queue={queue_size}, classes={list(set(classes))}")
|
||||
|
||||
def recruit(self, n: int) -> List[Tuple[int, List[Dict[str, Any]]]]:
|
||||
"""
|
||||
Return n (arm_idx, example_set) pairs for insertion into the queue.
|
||||
Cold start: random. Subsequent calls: MAB-recommended.
|
||||
"""
|
||||
if not self._initialized:
|
||||
indices = random.sample(range(len(self.pool.arms)), min(n, len(self.pool.arms)))
|
||||
self._initialized = True
|
||||
elif self._next is not None:
|
||||
|
||||
indices = self._next[:n]
|
||||
if len(indices) < n:
|
||||
rest = [i for i in range(len(self.pool.arms)) if i not in indices]
|
||||
indices += random.sample(rest, min(n - len(indices), len(rest)))
|
||||
else:
|
||||
indices = self.mab.recommend(n)
|
||||
return [(i, self.pool.get_arm(i)) for i in indices]
|
||||
|
||||
def update(self, results: List[Tuple[int, Optional[Dict[str, Any]], float]]) -> None:
|
||||
"""
|
||||
Update MAB from (arm_idx, response_dict, self_certainty_score).
|
||||
|
||||
Steps:
|
||||
1. reward = -score (lower certainty = better = higher reward)
|
||||
2. mab.observe(arm, reward)
|
||||
3. mab.update() -> rank-1 theta/B_inv update
|
||||
4. mab.sample() -> selects next arms to explore
|
||||
"""
|
||||
if not results:
|
||||
return
|
||||
for arm_idx, _, score in results:
|
||||
if 0 <= arm_idx < len(self.pool.arms):
|
||||
self.mab.observe(arm_idx, -float(score))
|
||||
self.mab.update()
|
||||
try:
|
||||
next_cands = self.mab.sample()
|
||||
top = self.mab.recommend(self.queue_size)
|
||||
self._next = list(dict.fromkeys(next_cands + top))
|
||||
except Exception as e:
|
||||
print(f"[RecruiterAgent] MAB sample error ({e}), using recommend()")
|
||||
self._next = self.mab.recommend(self.queue_size)
|
||||
20
sc/logger.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import os
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class Logger:
    def __init__(self, log_path: str):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_path = f"{log_path}_{timestamp}"
        os.makedirs(self.log_path, exist_ok=True)
        # log_path is a directory, so messages go to a file inside it.
        # The file name "log.txt" is an assumption, not taken from the original code.
        self.log_file = os.path.join(self.log_path, "log.txt")

    def log(self, message: str):
        print(message)
        # Opening the directory itself would raise IsADirectoryError.
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(message + "\n")

    def log_config(self, config: dict):
        dumped = yaml.dump(config, default_flow_style=False, sort_keys=False)
        msg = "Config:\n" + dumped
        self.log(msg)
|
||||
@@ -1,31 +1,149 @@
|
||||
"""
|
||||
HuggingFace Causal LM wrapper for inference with logits.
|
||||
|
||||
Loads a model via transformers, generates text, and returns both
|
||||
the generated text and full logits via a forward pass.
|
||||
|
||||
Message format: [{"role": "user"|"assistant"|"system", "content": str}, ...]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model_dir = "/mnt/sting/hjyoon/projects/llm/huggingface/gptoss20b"
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
# {"role": "user"|"assistant"|"system", "content": str}
|
||||
ChatMessage = Dict[str, str]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_dir,
|
||||
torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
|
||||
device_map="auto" if device == "cuda" else None,
|
||||
).eval()
|
||||
|
||||
# -----------------------------------------------------------------------------
# Model
# -----------------------------------------------------------------------------


class Model:
    """HuggingFace causal LM. Returns generated text + full logits."""

    def __init__(self, model_path: str, temperature: float = 0.0, max_new_tokens: int = 1024) -> None:
        self._tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
        self._model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        ).eval()
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens

    def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
        """
        Generate a reply and return (text, logits).

        Returns:
            text: Generated text string.
            logits: Full logits tensor (1, seq_len, vocab_size) on CPU.
                seq_len covers prompt + generated tokens.
        """
        # Build prompt from chat messages
        prompt = self._tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
        inputs = self._tokenizer(prompt, return_tensors="pt").to(self._model.device)
        prompt_len = inputs["input_ids"].shape[1]

        # Generate
        with torch.no_grad():
            gen_out = self._model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=self.temperature > 0,
                temperature=self.temperature if self.temperature > 0 else 1.0,
            )

        # Decode generated tokens
        gen_ids = gen_out[0][prompt_len:]
        text = self._tokenizer.decode(gen_ids, skip_special_tokens=True)

        # Forward pass on full sequence to get logits
        with torch.no_grad():
            logits = self._model(gen_out).logits  # (1, seq_len, vocab_size)

        return text, logits.cpu()


# -----------------------------------------------------------------------------
# Async wrappers
# -----------------------------------------------------------------------------


class AsyncModel:
    """Runs Model.invoke in a thread executor for async usage."""

    def __init__(self, model: Model) -> None:
        self.model = model

    async def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, lambda: self.model.invoke(messages),
        )


class AsyncModelPool:
    """Round-robin pool of async models."""

    def __init__(self) -> None:
        self.models: List[Model] = []
        self._queue: Optional[asyncio.Queue[AsyncModel]] = None

    def add_model(self, model: Model) -> None:
        self.models.append(model)

    def init_models(self) -> None:
        print(f"Initializing {len(self.models)} model(s)...")
        self._queue = asyncio.Queue()
        for model in self.models:
            self._queue.put_nowait(AsyncModel(model))

    async def invoke(self, messages: List[ChatMessage]) -> Tuple[str, torch.Tensor]:
        """Acquire a model, invoke, release. Returns (text, logits)."""
        if self._queue is None:
            raise RuntimeError("Call init_models() first.")

        async_model = await self._queue.get()
        try:
            return await async_model.invoke(messages)
        finally:
            self._queue.put_nowait(async_model)


# -----------------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------------


def load_models(
    models: List[str],
    temperature: float = 0.0,
    max_new_tokens: int = 1024,
) -> AsyncModelPool:
    """Create and initialize a model pool from a list of model paths."""
    pool = AsyncModelPool()
    for path in models:
        pool.add_model(
            Model(path, temperature=temperature, max_new_tokens=max_new_tokens)
        )
    pool.init_models()
    return pool


if __name__ == "__main__":
    # Smoke test. Guarded so that importing sc.core.model (e.g. for load_models)
    # does not load a model as a side effect.
    model = Model("/mnt/sting/hjyoon/projects/llm/huggingface/gptoss20b")
    messages = [
        {"role": "user", "content": "Explain quantum mechanics clearly and concisely."},
    ]
    text, logits = model.invoke(messages)
    print(text)
    print(logits.shape)
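Review note (illustrative, not part of the diff): the acquire / invoke / release hand-off in `AsyncModelPool.invoke` can be exercised without loading any weights. The sketch below is a minimal stand-in under stated assumptions: `FakeModel`, `TinyPool`, and `demo` are hypothetical names introduced here, and the pool is reduced to the `asyncio.Queue` hand-off plus the thread-executor call.

```python
import asyncio
from typing import List, Tuple


class FakeModel:
    """Hypothetical stand-in for Model: returns a canned reply, loads nothing."""

    def __init__(self, name: str) -> None:
        self.name = name

    def invoke(self, messages: List[dict]) -> Tuple[str, None]:
        return f"{self.name}: {messages[-1]['content']}", None


class TinyPool:
    """Same hand-off as the pool in the diff: get a model, run it, put it back."""

    def __init__(self, models: List[FakeModel]) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()
        for m in models:
            self._queue.put_nowait(m)

    async def invoke(self, messages: List[dict]) -> Tuple[str, None]:
        model = await self._queue.get()  # waits while every model is busy
        try:
            loop = asyncio.get_running_loop()
            # The blocking call runs in the default thread executor.
            return await loop.run_in_executor(None, lambda: model.invoke(messages))
        finally:
            self._queue.put_nowait(model)  # always hand the model back


async def demo() -> List[str]:
    pool = TinyPool([FakeModel("m0"), FakeModel("m1")])
    prompts = [[{"role": "user", "content": f"q{i}"}] for i in range(4)]
    # Four concurrent requests share two models; gather keeps request order.
    replies = await asyncio.gather(*(pool.invoke(p) for p in prompts))
    return [text for text, _ in replies]


replies = asyncio.run(demo())
print(replies)
```

Which model serves which request depends on scheduling, so only the order of the replies is stable, not the model name attached to each one.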
99
sc/preprocess/build_feature_index.py
Normal file
@@ -0,0 +1,99 @@
"""
Build unified feature index from PPGBP per-subject .npy embeddings.
Output: <output_dir>/index.json + <output_dir>/embeddings.npy

Usage:
  python -m sc.preprocess.build_feature_index --embedding_root $ROOT --output_dir ./features/sbert_metadata
"""
from __future__ import annotations
import argparse
import json
import os
from typing import Any, Dict, List, Optional
import numpy as np


def _load_labels(base_dir: Optional[str]) -> Dict[int, str]:
    """Load subject_id -> label from xlsx. Prefers Hypertension column (PPGBP), else class/Class."""
    if not base_dir:
        return {}
    try:
        import pandas as pd
    except ImportError:
        return {}
    xlsx = None
    # Sorted so the same file is picked on every run if several .xlsx files exist.
    for f in sorted(os.listdir(base_dir)):
        if f.endswith(".xlsx"):
            xlsx = os.path.join(base_dir, f)
            break
    if not xlsx or not os.path.isfile(xlsx):
        return {}
    df = pd.read_excel(xlsx, header=1)
    if "subject_ID" not in df.columns:
        return {}
    label_col = None
    for c in ("Hypertension", "class", "Class"):
        if c in df.columns:
            label_col = c
            break
    out: Dict[int, str] = {}
    for _, row in df.iterrows():
        if pd.isna(row["subject_ID"]):
            continue  # blank row: int(NaN) would raise ValueError
        sid = int(row["subject_ID"])
        out[sid] = str(row.get(label_col, "unknown")).strip() if label_col else "unknown"
    return out


def build_index(
    embedding_root: str,
    output_dir: str,
    embedding_type: str = "sbert_metadata",
    base_dir: Optional[str] = None,
) -> None:
    if not os.path.isdir(embedding_root):
        raise FileNotFoundError(embedding_root)
    subject_ids = []
    for f in sorted(os.listdir(embedding_root)):
        if not f.endswith(".npy"):
            continue
        try:
            subject_ids.append(int(f.replace(".npy", "")))
        except ValueError:
            continue
    subject_ids = sorted(subject_ids)
    if not subject_ids:
        raise ValueError("No .npy files in " + embedding_root)
    labels = _load_labels(base_dir)
    embeddings_list = []
    samples: List[Dict[str, Any]] = []
    for row, sid in enumerate(subject_ids):
        emb = np.load(os.path.join(embedding_root, f"{sid}.npy")).astype(np.float32)
        embeddings_list.append(emb)
        samples.append({
            "row": row,
            "user_id": str(sid),
            "sample_id": sid,
            "label": labels.get(sid, "unknown"),
            "session": "1",
        })
    embeddings = np.stack(embeddings_list, axis=0)
    dim = int(embeddings.shape[1])
    meta = {"version": "1.0", "embedding_type": embedding_type, "embedding_dim": dim, "samples": samples}
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "index.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    np.save(os.path.join(output_dir, "embeddings.npy"), embeddings)
    print(f"[build_feature_index] Wrote {len(samples)} samples, dim={dim}")


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--embedding_root", required=True, help="Directory of per-subject .npy files")
    ap.add_argument("--output_dir", required=True, help="Output dir for index.json and embeddings.npy")
    ap.add_argument("--embedding_type", default="sbert_metadata")
    ap.add_argument("--base_dir", default=None, help="Optional PPGBP data dir (xlsx) for labels")
    a = ap.parse_args()
    build_index(a.embedding_root, a.output_dir, a.embedding_type, a.base_dir)


if __name__ == "__main__":
    main()
134
sc/preprocess/build_ppgbp_embeddings.py
Normal file
@@ -0,0 +1,134 @@
"""
Build PPGBP embeddings under PPGBP_EMBEDDINGS_ROOT with two top-level dirs:
  metadata_embs/<emb_type>/<subject_id>.npy
  sig_embs/<emb_type>/<subject_id>.npy

For now only metadata embeddings are produced: for each subject, take the xlsx row,
vectorize (one-hot for Sex and Hypertension, numeric for the rest) and save.

Usage:
  python -m sc.preprocess.build_ppgbp_embeddings \\
    --embeddings_root /path/to/ppgbp_embeddings \\
    --data_dir /path/to/ppgbp_data \\
    --metadata_emb_type onehot

Then build the feature index (for downstream RecruiterAgent etc.):
  python -m sc.preprocess.build_feature_index \\
    --embedding_root $PPGBP_EMBEDDINGS_ROOT/metadata_embs/onehot \\
    --output_dir $DATA_INDEX_ROOT/metadata_onehot \\
    --embedding_type metadata_onehot \\
    --base_dir $PPGBP_DATA_DIR
"""

from __future__ import annotations

import argparse
import os
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd


# Xlsx column names (header=1)
SUBJECT_ID_COL = "subject_ID"
SEX_COL = "Sex(M/F)"
AGE_COL = "Age(year)"
HEIGHT_COL = "Height(cm)"
WEIGHT_COL = "Weight(kg)"
SBP_COL = "Systolic Blood Pressure(mmHg)"
DBP_COL = "Diastolic Blood Pressure(mmHg)"
HR_COL = "Heart Rate(b/m)"
BMI_COL = "BMI(kg/m^2)"
HYP_COL = "Hypertension"

NUMERIC_COLS = [AGE_COL, HEIGHT_COL, WEIGHT_COL, SBP_COL, DBP_COL, HR_COL, BMI_COL]
CATEGORICAL_COLS = [SEX_COL, HYP_COL]


def _find_xlsx(data_dir: str) -> str:
    for f in os.listdir(data_dir):
        if f.endswith(".xlsx"):
            return os.path.join(data_dir, f)
    raise FileNotFoundError(f"No .xlsx in {data_dir}")


def _onehot_and_numeric(df: pd.DataFrame) -> Tuple[np.ndarray, List[str], Dict[str, List[str]]]:
    """
    One-hot encode categoricals (Sex, Hypertension) and stack with numeric columns.
    Returns (matrix, column_names, categories_used).
    """
    out_cols: List[str] = []
    pieces: List[np.ndarray] = []
    categories_used: Dict[str, List[str]] = {}

    for col in CATEGORICAL_COLS:
        if col not in df.columns:
            continue
        vals = df[col].fillna("unknown").astype(str).str.strip()
        uniq = sorted(vals.unique())
        categories_used[col] = uniq
        for u in uniq:
            out_cols.append(f"{col}__{u}")
        # Broadcast (N, 1) against (K,) to get an (N, K) one-hot matrix.
        onehot = (vals.values.reshape(-1, 1) == np.array(uniq)).astype(np.float32)
        pieces.append(onehot)

    for col in NUMERIC_COLS:
        if col not in df.columns:
            continue
        out_cols.append(col)
        vec = df[col].fillna(0).astype(np.float32).values.reshape(-1, 1)
        pieces.append(vec)

    if not pieces:
        raise ValueError("No columns found; check xlsx has expected column names")
    X = np.hstack(pieces)
    return X, out_cols, categories_used


def create_embedding_dirs(embeddings_root: str, metadata_emb_type: str, sig_emb_type: str = "placeholder") -> None:
    """Create metadata_embs/<emb_type>/ and sig_embs/<emb_type>/."""
    os.makedirs(embeddings_root, exist_ok=True)
    meta_dir = os.path.join(embeddings_root, "metadata_embs", metadata_emb_type)
    sig_dir = os.path.join(embeddings_root, "sig_embs", sig_emb_type)
    os.makedirs(meta_dir, exist_ok=True)
    os.makedirs(sig_dir, exist_ok=True)


def build_metadata_embeddings_onehot(
    data_dir: str,
    embeddings_root: str,
    metadata_emb_type: str = "onehot",
) -> None:
    """
    Read xlsx from data_dir, vectorize each subject row (one-hot + numeric), save as
    metadata_embs/<metadata_emb_type>/<subject_id>.npy.
    """
    xlsx_path = _find_xlsx(data_dir)
    df = pd.read_excel(xlsx_path, header=1)
    if SUBJECT_ID_COL not in df.columns:
        raise ValueError(f"xlsx must have column {SUBJECT_ID_COL}")

    X, _cols, _cats = _onehot_and_numeric(df)
    subject_ids = df[SUBJECT_ID_COL].astype(int).values
    out_dir = os.path.join(embeddings_root, "metadata_embs", metadata_emb_type)
    os.makedirs(out_dir, exist_ok=True)
    for i, sid in enumerate(subject_ids):
        path = os.path.join(out_dir, f"{int(sid)}.npy")
        np.save(path, X[i].astype(np.float32))
    print(f"[build_ppgbp_embeddings] Wrote {len(subject_ids)} metadata embeddings to {out_dir} (dim={X.shape[1]})")


def main() -> None:
    ap = argparse.ArgumentParser(description="Build PPGBP metadata (and optionally signal) embeddings under embeddings_root.")
    ap.add_argument("--embeddings_root", required=True, help="Root for metadata_embs/ and sig_embs/ (e.g. PPGBP_EMBEDDINGS_ROOT)")
    ap.add_argument("--data_dir", required=True, help="PPGBP data dir containing the xlsx file")
    ap.add_argument("--metadata_emb_type", default="onehot", help="Subdir name under metadata_embs/ (e.g. onehot)")
    ap.add_argument("--sig_emb_type", default="placeholder", help="Subdir under sig_embs/ (created empty for now)")
    args = ap.parse_args()
    create_embedding_dirs(args.embeddings_root, args.metadata_emb_type, args.sig_emb_type)
    build_metadata_embeddings_onehot(args.data_dir, args.embeddings_root, args.metadata_emb_type)


if __name__ == "__main__":
    main()
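Review note (illustrative, not part of the diff): the module docstring describes the metadata vector as one-hot columns for the categoricals followed by the numeric columns. The sketch below shows that layout on two made-up rows using only the standard library. `vectorize_rows` is a hypothetical helper, and the sample values for Sex and Hypertension are invented for the example rather than taken from the PPGBP spreadsheet.

```python
from typing import Dict, List, Optional, Tuple

Row = Dict[str, Optional[object]]


def vectorize_rows(
    rows: List[Row], categorical: List[str], numeric: List[str]
) -> Tuple[List[List[float]], List[str]]:
    """One-hot encode categorical fields, then append numeric fields (missing -> 0.0)."""
    columns: List[str] = []
    categories: Dict[str, List[str]] = {}
    for col in categorical:
        # Categories are the sorted unique values seen in the data.
        uniq = sorted({str(r.get(col, "unknown")).strip() for r in rows})
        categories[col] = uniq
        columns.extend(f"{col}__{u}" for u in uniq)
    columns.extend(numeric)

    matrix: List[List[float]] = []
    for r in rows:
        vec: List[float] = []
        for col in categorical:
            val = str(r.get(col, "unknown")).strip()
            vec.extend(1.0 if val == u else 0.0 for u in categories[col])
        for col in numeric:
            v = r.get(col)
            vec.append(float(v) if v is not None else 0.0)
        matrix.append(vec)
    return matrix, columns


# Made-up rows for illustration only.
rows = [
    {"Sex(M/F)": "Female", "Hypertension": "Normal", "Age(year)": 45},
    {"Sex(M/F)": "Male", "Hypertension": "Stage 1 hypertension", "Age(year)": None},
]
matrix, columns = vectorize_rows(rows, ["Sex(M/F)", "Hypertension"], ["Age(year)"])
print(columns)
print(matrix)
```

Because the categories come from the data at hand, the vector length and column order change whenever a new category value appears, so embeddings built from different spreadsheets are not directly comparable. The Hypertension one-hot is also part of the vector, which matters if these embeddings are later used to retrieve examples for a hypertension classifier.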
76
sc/preprocess/package_ppgbp_embeddings.py
Normal file
@@ -0,0 +1,76 @@
"""
One-shot: build PPGBP one-hot metadata embeddings and then package them into
the feature index at DATA_INDEX_ROOT for downstream use (RecruiterAgent, etc.).

Step 1: Create metadata_embs/onehot/ and sig_embs/ under PPGBP_EMBEDDINGS_ROOT,
        write one-hot metadata vectors per subject.
Step 2: Run build_feature_index to produce index.json + embeddings.npy at
        DATA_INDEX_ROOT (compatible with FeatureIndex and RecruiterAgent).

Usage:
  export PPGBP_EMBEDDINGS_ROOT=/scratch/.../embeddings_full/ppgbp
  export PPGBP_DATA_DIR=/path/to/ppgbp_data   # dir with xlsx and 0_subject/
  # or: export PPGBP_DATA_ROOT=/path/to/data/tsllm  (same as data dir if xlsx lives there)
  export DATA_INDEX_ROOT=/path/to/feature_index_output

  python -m sc.preprocess.package_ppgbp_embeddings

Or with explicit args:
  python -m sc.preprocess.package_ppgbp_embeddings \\
    --embeddings_root $PPGBP_EMBEDDINGS_ROOT \\
    --data_dir $PPGBP_DATA_DIR \\
    --index_output_dir $DATA_INDEX_ROOT \\
    --metadata_emb_type onehot \\
    --embedding_type metadata_onehot
"""

from __future__ import annotations

import argparse
import os
import sys


def main() -> None:
    ap = argparse.ArgumentParser(description="Build PPGBP metadata embeddings and package feature index.")
    ap.add_argument("--embeddings_root", default=None, help="Default: PPGBP_EMBEDDINGS_ROOT env")
    ap.add_argument("--data_dir", default=None, help="PPGBP data dir (xlsx + 0_subject/); default: PPGBP_DATA_DIR or PPGBP_DATA_ROOT env")
    ap.add_argument("--index_output_dir", default=None, help="Where to write index.json + embeddings.npy; default: DATA_INDEX_ROOT env")
    ap.add_argument("--metadata_emb_type", default="onehot")
    ap.add_argument("--embedding_type", default="metadata_onehot", help="embedding_type string in index.json")
    args = ap.parse_args()

    embeddings_root = args.embeddings_root or os.environ.get("PPGBP_EMBEDDINGS_ROOT")
    data_dir = args.data_dir or os.environ.get("PPGBP_DATA_DIR") or os.environ.get("PPGBP_DATA_ROOT") or embeddings_root
    index_output_dir = args.index_output_dir or os.environ.get("DATA_INDEX_ROOT")

    if not embeddings_root:
        print("Set PPGBP_EMBEDDINGS_ROOT or pass --embeddings_root", file=sys.stderr)
        sys.exit(1)
    if not os.path.isdir(data_dir):
        print(f"Data dir not found: {data_dir}", file=sys.stderr)
        sys.exit(1)
    if not index_output_dir:
        print("Set DATA_INDEX_ROOT or pass --index_output_dir", file=sys.stderr)
        sys.exit(1)

    from sc.preprocess.build_ppgbp_embeddings import (
        create_embedding_dirs,
        build_metadata_embeddings_onehot,
    )
    create_embedding_dirs(embeddings_root, args.metadata_emb_type)
    build_metadata_embeddings_onehot(data_dir, embeddings_root, args.metadata_emb_type)

    embedding_root_for_index = os.path.join(embeddings_root, "metadata_embs", args.metadata_emb_type)
    from sc.preprocess.build_feature_index import build_index
    build_index(
        embedding_root=embedding_root_for_index,
        output_dir=index_output_dir,
        embedding_type=args.embedding_type,
        base_dir=data_dir,
    )
    print(f"[package_ppgbp_embeddings] Feature index written to {index_output_dir}")


if __name__ == "__main__":
    main()
@@ -30,7 +30,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
 from sc.core.scagent import SCAgent
 from sc.core.model import load_models
-from sc.core.queue import Queue
+from sc.core.example_queue import Queue
 from sc.core.agent_pool import AgentPool
 from sc import debug_log

@@ -425,7 +425,7 @@ def main(
     model_pool = load_models(
         config["models"],
         temperature=config.get("temperature", 0.0),
-        num_ctx=config.get("num_ctx", 15000),
+        max_new_tokens=config.get("max_new_tokens", 1024),
     )

     # Run experiment
@@ -32,7 +32,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
 from sc.core.scagent import SCAgent
 from sc.core.model import load_models
-from sc.core.queue import Queue
+from sc.core.example_queue import Queue
 from sc.core.agent_pool import AgentPool
 from sc import debug_log

@@ -435,7 +435,7 @@ def main(
     model_pool = load_models(
         config["models"],
         temperature=config.get("temperature", 0.0),
-        num_ctx=config.get("num_ctx", 15000),
+        max_new_tokens=config.get("max_new_tokens", 1024),
     )

     # Run experiment
390
sc/run_recruiter.py
Normal file
@@ -0,0 +1,390 @@
"""
Recruiter-based Self-Consistency runner for PPGBP.

Uses RecruiterAgent (MAB) to select example sets, self_certainty(logits) as reward,
and borda_vote for aggregation. Requires PPGBPLoader and optional feature index.

Usage:
  python -m sc.run_recruiter sc/config/ppgbp_recruiter.yaml
"""
from __future__ import annotations

import asyncio
import json
import os
import sys
from collections import deque
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import yaml
from fire import Fire

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from sc.core.data_loader import PPGBPLoader, get_loaders
from sc.core.example_queue import Queue
from sc.core.model import Model, load_models
from sc.core.recruiter_agent import RecruiterAgent, borda_vote, self_certainty
from sc.core.recruiter_agent import FeatureIndex


def _ppgbp_sample_to_dict(metadata: Dict, signal, subject_id: int, label_key: str = "label") -> Dict[str, Any]:
    """Turn (metadata, signal) into a sample dict with user_id, sample_id, label, features."""
    # Look up the label under label_key first, then fall back to
    # "hypertension", "Hypertension", "class", "Class".
    label = str(metadata.get(label_key))
    if label == "None" or label == "unknown":
        for key in ["hypertension", "Hypertension", "class", "Class"]:
            val = metadata.get(key)
            if val is not None:
                label = str(val)
                break
    if label == "None":
        label = "unknown"

    # Drop every label-bearing key, including the fallback keys above, so the
    # ground truth cannot leak into the features that are rendered into the prompt.
    label_keys = {label_key.lower(), "label", "hypertension", "class"}
    features = {k: str(v) for k, v in metadata.items() if k.lower() not in label_keys}
    return {
        "user_id": str(subject_id),
        "sample_id": subject_id,
        "idx": subject_id,
        "label": label,
        "features": features,
    }


def _build_prompt(system_msg: str, task_info: str, classes_info: List[str], test_sample: Dict, examples: List[Dict]) -> List[Dict[str, str]]:
    """Build chat messages for the model."""
    parts = [f"**Task**: {task_info}\n\n**Classes**: {', '.join(classes_info)}\n\n"]
    parts.append("**Examples**\n")
    for ex in examples:
        parts.append(f"*Example of {ex['label']}*:\n")
        for k, v in (ex.get("features") or {}).items():
            parts.append(f"  - {k}: {v}\n")
    parts.append("\n**Current sample**\n")
    for k, v in (test_sample.get("features") or {}).items():
        parts.append(f"  - {k}: {v}\n")
    parts.append('\nRespond in JSON: {"REASON": "...", "CONFIDENCE": 0.0-1.0, "ANSWER": "<class>"}')
    content = "".join(parts)
    return [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": content},
    ]


def _parse_answer(text: str, classes: List[str]) -> Optional[str]:
    """Extract ANSWER from JSON in text."""
    try:
        start = text.find("{")
        end = text.rfind("}") + 1
        if start >= 0 and end > start:
            obj = json.loads(text[start:end])
            ans = (obj.get("ANSWER") or "").strip()
            if ans in classes:
                return ans
    except Exception:
        pass
    return None


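Review note (illustrative, not part of the diff): `_parse_answer` takes everything between the first `{` and the last `}` and parses it as one JSON object. The standalone sketch below mirrors that slicing to show which replies it accepts. `parse_answer` is a hypothetical copy written for this note, and the sample replies are made up.

```python
import json
from typing import List, Optional


def parse_answer(text: str, classes: List[str]) -> Optional[str]:
    """Slice from the first '{' to the last '}' and read ANSWER from that JSON object."""
    start = text.find("{")
    end = text.rfind("}") + 1
    if start < 0 or end <= start:
        return None
    try:
        obj = json.loads(text[start:end])
    except json.JSONDecodeError:
        return None
    if not isinstance(obj, dict):
        return None
    ans = str(obj.get("ANSWER") or "").strip()
    return ans if ans in classes else None


classes = ["Normal", "Hypertension"]

# Prose around a single JSON object is tolerated.
wrapped = 'Sure. {"REASON": "high SBP", "CONFIDENCE": 0.8, "ANSWER": "Hypertension"} Done.'
# Two objects in one reply: the slice spans both and is not valid JSON.
two_objects = '{"ANSWER": "Normal"} and then {"ANSWER": "Normal"}'

results = [
    parse_answer(wrapped, classes),
    parse_answer(two_objects, classes),
    parse_answer("no json here", classes),
]
print(results)
```

A reply that repeats the JSON object, or that includes braces in surrounding prose, parses to `None`, so such slots contribute no vote even when the model's answer was well formed.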
def make_gt_scorer(
    num_arms: int,
    best_arm_id: int = 20,
    sigma: float = 5.0,
) -> Callable[[int], float]:
    """
    Ground-truth: deterministic. The distribution *over arms* is normal: reward(arm_id)
    is a Gaussian in arm index with peak at best_arm_id. So reward(i) = exp(-(i - best)^2 / (2 sigma^2)).
    Score = -reward so lower = better. No randomness.
    """
    best_arm_id = max(0, min(best_arm_id, num_arms - 1))
    sigma = max(1e-6, float(sigma))

    def scorer(arm_id: int) -> float:
        if arm_id < 0 or arm_id >= num_arms:
            return 0.0
        reward = np.exp(-((arm_id - best_arm_id) ** 2) / (2 * sigma**2))
        return -float(reward)

    return scorer


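Review note (illustrative, not part of the diff): the ground-truth score is the negated Gaussian bump `-exp(-(i - best)^2 / (2 sigma^2))` over the arm index, so the best arm scores -1.0 and distant or out-of-range arms approach or equal 0.0. The sketch below re-implements the same formula with `math` instead of NumPy so it runs anywhere; `make_gaussian_scorer` is a hypothetical name used only here.

```python
import math
from typing import Callable


def make_gaussian_scorer(num_arms: int, best_arm_id: int = 20, sigma: float = 5.0) -> Callable[[int], float]:
    """Deterministic score per arm: negated Gaussian in the arm index (lower = better)."""
    best = max(0, min(best_arm_id, num_arms - 1))  # clamp the peak into range
    sigma = max(1e-6, float(sigma))                # guard against division by zero

    def scorer(arm_id: int) -> float:
        if not 0 <= arm_id < num_arms:
            return 0.0  # out-of-range arms get the worst possible score
        return -math.exp(-((arm_id - best) ** 2) / (2 * sigma ** 2))

    return scorer


scorer = make_gaussian_scorer(num_arms=50)
# Sort arms from best (most negative score) to worst.
ranked = sorted(range(50), key=scorer)
print(ranked[:5])
print(scorer(20), scorer(25), scorer(99))
```

Under this convention a bandit that converges should end with queue arms clustered around `best_arm_id`, which is what the `[GT Summary]` line in the runner lets you check.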
async def run_single_user(
    config: Dict[str, Any],
    model_pool: Optional[Any],
    test_loader: Optional[PPGBPLoader],
    train_loader: PPGBPLoader,
    seed: int,
) -> List[Dict[str, Any]]:
    """Run recruiter pipeline. If model_pool is None, use ground-truth scorer (arm_id -> score) for MAB convergence."""
    queue_size = config.get("queue_size", 5)
    recruit_size = config.get("recruit_size", 2)
    n_way = config.get("n_way", 2)
    k_shot = config.get("k_shot", 1)
    task_info = config.get("task_info", "Classify the subject.")
    classes_info = config.get("classes_info", ["Normal", "Hypertension"])
    system_msg = config.get("system_message", "You are a helpful assistant. Respond in JSON with REASON, CONFIDENCE, ANSWER.")
    use_gt_only = model_pool is None

    np.random.seed(seed)
    train_loader._indices = np.arange(len(train_loader.subject_ids))
    if train_loader.shuffle:
        train_loader.rng.shuffle(train_loader._indices)

    example_dataset = []
    for idx in train_loader._indices:
        sid = train_loader.subject_ids[idx]
        meta = train_loader._get_metadata_dict(sid)
        signal = train_loader._load_signal(sid)
        example_dataset.append(_ppgbp_sample_to_dict(meta, signal, int(sid)))
    if len(example_dataset) < n_way * k_shot:
        print("[run_recruiter] Not enough train samples")
        return []

    # Sorted so the class order does not depend on set iteration order,
    # which varies between runs for strings and would defeat the fixed seed.
    classes = sorted({s["label"] for s in example_dataset})
    if not classes:
        classes = classes_info
    class_indices = {c: [i for i, s in enumerate(example_dataset) if s["label"] == c] for c in classes}
    for c in classes:
        if not class_indices[c]:
            class_indices[c] = [0]
    dataset_idx = {(str(s["user_id"]), s["sample_id"]): i for i, s in enumerate(example_dataset)}

    feature_index = None
    if config.get("feature_root") and os.path.isdir(config["feature_root"]):
        try:
            feature_index = FeatureIndex(config["feature_root"])
        except Exception as e:
            print(f"[run_recruiter] Feature index load failed: {e}")

    recruiter = RecruiterAgent(
        source_samples=example_dataset,
        classes=classes,
        feature_index=feature_index,
        target_user_id=None,
        k_shot=k_shot,
        queue_size=queue_size,
        num_arms=config.get("num_arms", 50),
        n_das=config.get("n_das", 5),
        delta=config.get("delta", 0.05),
        epsilon=config.get("epsilon", 0.11),
        seed=seed,
    )

    num_arms = len(recruiter.pool.arms)
    if use_gt_only:
        gt_steps = config.get("gt_steps", 200)
        gt_best_arm_id = config.get("gt_best_arm_id", 20)
        gt_sigma = config.get("gt_sigma", 5.0)
        gt_scorer = make_gt_scorer(num_arms, best_arm_id=gt_best_arm_id, sigma=gt_sigma)
        print(f"[run_recruiter] GT-only mode: {gt_steps} steps, best_arm_id={gt_best_arm_id}, sigma={gt_sigma}")

    initial_sets = recruiter.recruit(queue_size)
    initial_cases = []
    slot_arms = []
    for arm_idx, ex_set in initial_sets:
        case = []
        for d in ex_set:
            key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
            if key in dataset_idx:
                case.append(dataset_idx[key])
            else:
                for i, s in enumerate(example_dataset):
                    if s.get("user_id") == str(d.get("user_id")) and s.get("sample_id") == d.get("sample_id"):
                        case.append(i)
                        break
        if len(case) >= n_way:
            initial_cases.append(case[:n_way])
            slot_arms.append(arm_idx)
    while len(initial_cases) < queue_size and len(initial_sets) > len(initial_cases):
        arm_idx, ex_set = initial_sets[len(initial_cases)]
        case = []
        for d in ex_set:
            key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
            case.append(dataset_idx.get(key, 0))
        if len(case) < n_way:
            case.extend([0] * (n_way - len(case)))
        initial_cases.append(case[:n_way])
        slot_arms.append(arm_idx)

    ex_queue = Queue(class_indices, queue_size)
    ex_queue._queue = deque(initial_cases[:queue_size], maxlen=queue_size)
    for case in ex_queue._queue:
        ex_queue._register_stats(case)
    slot_arms = slot_arms[:queue_size]

    results = []
    processed = 0
    cumulative_correct = 0

    if use_gt_only:
        # Loop: query gt_scorer(arm_id) for each slot, update MAB, recruit, update queue
        for step in range(gt_steps):
            ex_queue.set_current_time(step)
            queue_cases = list(ex_queue._queue)
            run_results = []
            for slot_i in range(len(slot_arms)):
                arm_idx = slot_arms[slot_i]
                score = gt_scorer(arm_idx)
                response_dict = {"ANSWER": None, "REASON": "gt", "CONFIDENCE": 0.5}
                run_results.append((arm_idx, response_dict, score))

            recruiter.update(run_results)
            new_sets = recruiter.recruit(recruit_size)
            new_cases = []
            new_arm_ids = []
            for arm_idx, ex_set in new_sets:
                case = []
                for d in ex_set:
                    key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
                    case.append(dataset_idx.get(key, 0))
                if len(case) < n_way:
                    case.extend([0] * (n_way - len(case)))
                new_cases.append(case[:n_way])
                new_arm_ids.append(arm_idx)
            ex_queue.update_with_recruiter(run_results, new_cases, recruit_size)
            # Lower score = better, so the recruit_size highest-scoring slots are
            # evicted and replaced by the newly recruited arms.
            scores_with_idx = [(i, run_results[i][2]) for i in range(len(run_results))]
            scores_with_idx.sort(key=lambda x: x[1], reverse=True)
            to_evict = {scores_with_idx[j][0] for j in range(min(recruit_size, len(scores_with_idx)))}
            kept = [i for i in range(len(slot_arms)) if i not in to_evict]
            slot_arms = [slot_arms[i] for i in kept] + new_arm_ids[:recruit_size]

            results.append({
                "sample_idx": step,
                "answer": None,
                "ground_truth": None,
                "is_correct": None,
                "gt_step": step,
                "queue_arms": list(slot_arms),
                "step_rewards": [r[2] for r in run_results],
            })
        print(f"[GT Summary] Final queue arms: {slot_arms}")
        return results

# Model mode: iterate test loader, call LLM, self_certainty, borda_vote
|
||||
for step, (metadata, signal) in enumerate(test_loader):
|
||||
test_sample = _ppgbp_sample_to_dict(metadata, signal, int(metadata.get("subject_ID", step)))
|
||||
ex_queue.set_current_time(processed)
|
||||
queue_cases = list(ex_queue._queue)
|
||||
run_results = []
|
||||
for slot_i, case in enumerate(queue_cases):
|
||||
examples = [example_dataset[i] for i in case if i < len(example_dataset)]
|
||||
if len(examples) < n_way:
|
||||
continue
|
||||
messages = _build_prompt(system_msg, task_info, classes_info, test_sample, examples)
|
||||
text, logits = await model_pool.invoke(messages)
|
||||
score = self_certainty(logits)
|
||||
parsed = _parse_answer(text, classes_info)
|
||||
response_dict = {"ANSWER": parsed, "REASON": text[:200], "CONFIDENCE": 0.5}
|
||||
arm_idx = slot_arms[slot_i] if slot_i < len(slot_arms) else slot_i
|
||||
run_results.append((arm_idx, response_dict, score))
|
||||
|
||||
if not run_results:
|
||||
processed += 1
|
||||
continue
|
||||
answer = borda_vote(run_results, valid_classes=classes_info)
|
||||
gt = test_sample.get("label", "")
|
||||
is_correct = answer == gt
|
||||
cumulative_correct += 1 if is_correct else 0
|
||||
acc = cumulative_correct / (processed + 1)
|
||||
|
||||
recruiter.update(run_results)
|
||||
new_sets = recruiter.recruit(recruit_size)
|
||||
new_cases = []
|
||||
new_arm_ids = []
|
||||
for arm_idx, ex_set in new_sets:
|
||||
case = []
|
||||
for d in ex_set:
|
||||
key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
|
||||
case.append(dataset_idx.get(key, 0))
|
||||
if len(case) < n_way:
|
||||
case.extend([0] * (n_way - len(case)))
|
||||
new_cases.append(case[:n_way])
|
||||
new_arm_ids.append(arm_idx)
|
||||
ex_queue.update_with_recruiter(run_results, new_cases, recruit_size)
|
||||
scores_with_idx = [(i, run_results[i][2]) for i in range(len(run_results))]
|
||||
scores_with_idx.sort(key=lambda x: x[1], reverse=True)
|
||||
to_evict = {scores_with_idx[j][0] for j in range(min(recruit_size, len(scores_with_idx)))}
|
||||
kept = [i for i in range(len(slot_arms)) if i not in to_evict]
|
||||
slot_arms = [slot_arms[i] for i in kept] + new_arm_ids[:recruit_size]
|
||||
|
||||
results.append({
|
||||
"sample_idx": processed,
|
||||
"answer": answer,
|
||||
"ground_truth": gt,
|
||||
"is_correct": is_correct,
|
||||
"cumulative_accuracy": acc,
|
||||
})
|
||||
processed += 1
|
||||
|
||||
return results
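The slot bookkeeping at the end of the loop above (rank slots by score, drop `recruit_size` of them, append the newly recruited arms) can be sketched in isolation as follows. This is a minimal sketch, not the author's code: `update_slot_arms` and the demo values are hypothetical names, and it assumes exactly one score per slot (in the diff, slots skipped by `continue` would make `run_results` indices diverge from slot indices). Like the diff, it evicts the highest-scoring slots; whether a high `self_certainty` score should mean "evict" is not shown in this hunk.

```python
from typing import List, Tuple


def update_slot_arms(
    slot_arms: List[int],
    scores: List[float],
    new_arm_ids: List[int],
    recruit_size: int,
) -> Tuple[List[int], List[int]]:
    """Return (updated slot_arms, evicted slot positions).

    Mirrors the diff: sort slot positions by score in descending order,
    evict the first `recruit_size` positions, then append the new arms.
    """
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    to_evict = set(ranked[:recruit_size])
    kept = [i for i in range(len(slot_arms)) if i not in to_evict]
    updated = [slot_arms[i] for i in kept] + new_arm_ids[:recruit_size]
    return updated, sorted(to_evict)


# Hypothetical values, for illustration only.
demo_arms, demo_evicted = update_slot_arms(
    slot_arms=[10, 11, 12, 13],
    scores=[0.2, 0.9, 0.5, 0.1],
    new_arm_ids=[20, 21],
    recruit_size=1,
)
```

Keeping this step as a pure function makes it easy to check that the queue length stays constant whenever `len(new_arm_ids) >= recruit_size`.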
|
||||
|
||||
|
||||
def _expand_env_in_config(obj: Any) -> Any:
|
||||
"""Recursively expand $VAR and ${VAR} in config strings."""
|
||||
if isinstance(obj, dict):
|
||||
return {k: _expand_env_in_config(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_expand_env_in_config(v) for v in obj]
|
||||
if isinstance(obj, str):
|
||||
return os.path.expandvars(obj)
|
||||
return obj
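To illustrate what `_expand_env_in_config` does to a loaded YAML config, here is a minimal standalone sketch. The function body follows the diff; the name `expand_env`, the environment variable `DEMO_DATA_ROOT`, and the demo config are assumptions made up for the example. Note that `os.path.expandvars` leaves references to unset variables unchanged rather than raising.

```python
import os
from typing import Any


def expand_env(obj: Any) -> Any:
    """Recursively apply os.path.expandvars to every string in a nested config."""
    if isinstance(obj, dict):
        return {k: expand_env(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [expand_env(v) for v in obj]
    if isinstance(obj, str):
        return os.path.expandvars(obj)
    return obj  # ints, floats, None, etc. pass through untouched


# Hypothetical variable names and config, for illustration only.
os.environ["DEMO_DATA_ROOT"] = "/tmp/demo"
os.environ.pop("DEMO_UNSET_VARIABLE", None)

demo_config = {
    "data_path": "$DEMO_DATA_ROOT/ppgbp",
    "models": ["${DEMO_DATA_ROOT}/model"],
    "missing": "$DEMO_UNSET_VARIABLE/x",
    "seed": 42,
}
expanded = expand_env(demo_config)
```

Because unset variables are silently kept as literal text, a typo in a variable name would surface later as a bad path, so a caller may want to validate `data_path` after expansion.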
|
||||
|
||||
|
||||
def main(config_path: str) -> None:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f)
|
||||
config = _expand_env_in_config(config)
|
||||
if not config.get("model_path") and config.get("models"):
|
||||
config["model_path"] = config["models"][0] if isinstance(config["models"], list) else config["models"]
|
||||
model_path = config.get("model_path") or ""
|
||||
if isinstance(model_path, str):
|
||||
model_path = model_path.strip()
|
||||
use_gt_only = not model_path
|
||||
if use_gt_only:
|
||||
config["use_gt_only"] = True
|
||||
print("[run_recruiter] model_path is null: using ground-truth scorer (arm_id -> score) for MAB convergence")
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
if "log_path" in config:
|
||||
config["log_path"] = f"{config['log_path']}_{timestamp}"
|
||||
|
||||
pool = None
|
||||
if not use_gt_only:
|
||||
pool = load_models(
|
||||
[model_path],
|
||||
temperature=config.get("temperature", 0.7),
|
||||
max_new_tokens=config.get("max_new_tokens", 256),
|
||||
)
|
||||
|
||||
data_path = config["data_path"]
|
||||
train_loader, test_loader = get_loaders(data_path, seed=config.get("seed", 42))
|
||||
if use_gt_only:
|
||||
test_loader = None # not used in GT loop
|
||||
seeds = range(config.get("num_seeds", 1))
|
||||
all_results = []
|
||||
for seed in seeds:
|
||||
res = asyncio.run(run_single_user(config, pool, test_loader, train_loader, seed))
|
||||
all_results.extend(res)
|
||||
|
||||
if use_gt_only:
|
||||
print(f"GT mode: completed {len(all_results)} steps")
|
||||
else:
|
||||
correct = sum(1 for r in all_results if r.get("is_correct"))
|
||||
total = len(all_results)
|
||||
acc = correct / total if total else 0
|
||||
print(f"Accuracy: {correct}/{total} = {acc:.4f}")
|
||||
log_path = config.get("log_path", "./logs/ppgbp_recruiter")
|
||||
os.makedirs(log_path, exist_ok=True)
|
||||
with open(os.path.join(log_path, "results.json"), "w") as f:
|
||||
json.dump({
|
||||
"use_gt_only": use_gt_only,
|
||||
"accuracy": (sum(1 for r in all_results if r.get("is_correct")) / len(all_results)) if all_results and not use_gt_only else None,
|
||||
"correct": sum(1 for r in all_results if r.get("is_correct")) if not use_gt_only else None,
|
||||
"total": len(all_results),
|
||||
"results": all_results,
|
||||
}, f, indent=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(main)
|
||||
@@ -29,7 +29,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sc.core.data_loader import DataLoader
|
||||
from sc.core.scagent import SCAgent
|
||||
from sc.core.model import load_models
|
||||
-from sc.core.queue import Queue
+from sc.core.example_queue import Queue
|
||||
from sc.core.agent_pool import AgentPool
|
||||
from sc import debug_log
|
||||
|
||||
@@ -426,7 +426,7 @@ def main(config_path: str) -> None:
|
||||
model_pool = load_models(
|
||||
config["models"],
|
||||
temperature=config.get("temperature", 0.0),
|
||||
-num_ctx=config.get("num_ctx", 15000),
+max_new_tokens=config.get("max_new_tokens", 1024),
|
||||
)
|
||||
|
||||
# Run experiments
|
||||
|
||||
@@ -518,7 +518,7 @@ def main(
|
||||
model_pool = load_models(
|
||||
config["models"],
|
||||
temperature=config.get("temperature", 0.0),
|
||||
-num_ctx=config.get("num_ctx", 15000),
+max_new_tokens=config.get("max_new_tokens", 1024),
|
||||
)
|
||||
|
||||
# Run experiment
|
||||
|
||||
275
sc/run_usc.py
@@ -30,10 +30,12 @@ from fire import Fire
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
|
||||
from sc.core.data_loader import DataLoader
|
||||
from sc.core.scagent import SCAgent
|
||||
from sc.core.model import load_models
|
||||
from sc.core.queue import Queue
|
||||
from sc.core.example_queue import Queue
|
||||
from sc.core.agent_pool import AgentPool
|
||||
from sc.logger import Logger
|
||||
from sc import debug_log
|
||||
|
||||
log = debug_log.log
|
||||
@@ -48,33 +50,33 @@ async def run_consistency_experiment(
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Run consistency-based queue policy experiment.
|
||||
|
||||
|
||||
Queue Update Policy:
|
||||
- After each inference, calculate per-agent consistency
|
||||
- Consistency = how many other agents agree with this agent's answer
|
||||
- Rank queue elements by consistency score
|
||||
- Evict the lowest consistency element
|
||||
- Add a new random ICL example set
|
||||
|
||||
|
||||
Args:
|
||||
dataloader: ShuffledDataLoader with pre-shuffled test data
|
||||
model_pool: Async model pool for LLM inference
|
||||
config: Experiment configuration
|
||||
user_id: User ID being processed
|
||||
shuffle_seed: Shuffle seed for reproducibility
|
||||
|
||||
|
||||
Returns:
|
||||
List of result dictionaries
|
||||
"""
|
||||
# Set seeds for reproducibility within experiment
|
||||
random.seed(shuffle_seed)
|
||||
np.random.seed(shuffle_seed)
|
||||
|
||||
|
||||
example_dataset = dataloader.get_examples()
|
||||
if len(example_dataset) == 0:
|
||||
log(f"[ERROR] No examples found for user {user_id}")
|
||||
return []
|
||||
|
||||
|
||||
# Build class_indices for Queue initialization
|
||||
class_indices = {}
|
||||
for idx, example in enumerate(example_dataset):
|
||||
@@ -82,11 +84,11 @@ async def run_consistency_experiment(
|
||||
if label not in class_indices:
|
||||
class_indices[label] = []
|
||||
class_indices[label].append(idx)
|
||||
|
||||
|
||||
# Initialize Queue
|
||||
queue_size = config.get("queue_size", 5)
|
||||
ex_queue = Queue(class_indices, queue_size)
|
||||
|
||||
|
||||
# Tracking variables
|
||||
results = []
|
||||
cumulative_correct = 0
|
||||
@@ -94,10 +96,10 @@ async def run_consistency_experiment(
|
||||
recent_results = []
|
||||
confidence_history = []
|
||||
consistency_history = []
|
||||
|
||||
|
||||
all_predictions = []
|
||||
all_ground_truths = []
|
||||
|
||||
|
||||
debug_log.log_policy_experiment_header(
|
||||
"CONSISTENCY-BASED",
|
||||
user_id,
|
||||
@@ -105,20 +107,22 @@ async def run_consistency_experiment(
|
||||
len(dataloader),
|
||||
queue_size,
|
||||
)
|
||||
|
||||
|
||||
for processed_count, sample in enumerate(dataloader):
|
||||
ex_queue.set_current_time(processed_count)
|
||||
|
||||
|
||||
# Log queue state before processing
|
||||
-debug_log.log_policy_queue_state("CONSISTENCY", processed_count, user_id, ex_queue)
+debug_log.log_policy_queue_state(
+"CONSISTENCY", processed_count, user_id, ex_queue
+)
|
||||
|
||||
# Create agent pool
|
||||
agent_pool = AgentPool(log_path=config["log_path"])
|
||||
|
||||
|
||||
try:
|
||||
for queue_idx, ex_idcs in enumerate(ex_queue):
|
||||
examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]
|
||||
|
||||
|
||||
agent = SCAgent(
|
||||
name="EEG sensing",
|
||||
index=queue_idx,
|
||||
@@ -134,46 +138,50 @@ async def run_consistency_experiment(
|
||||
except Exception as e:
|
||||
log(f"[ERROR] Failed to create agents: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
|
||||
if len(agent_pool.agents) == 0:
|
||||
log(f"[WARN] No agents created for sample {processed_count}")
|
||||
continue
|
||||
|
||||
|
||||
# Run parallel interpretation
|
||||
try:
|
||||
interpretation_result = await agent_pool.run_parallel_interpretation()
|
||||
except Exception as e:
|
||||
log(f"[ERROR] Interpretation failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
|
||||
if interpretation_result is None:
|
||||
log(f"[WARN] Interpretation failed for sample {processed_count}")
|
||||
continue
|
||||
|
||||
-answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
+answer, queue_idcs, avg_confidence, consistency, responses = (
+interpretation_result
+)
|
||||
|
||||
# Evaluate result
|
||||
ground_truth = sample["label"]
|
||||
is_correct = answer == ground_truth
|
||||
|
||||
|
||||
cumulative_correct += 1 if is_correct else 0
|
||||
cumulative_accuracy = cumulative_correct / (processed_count + 1)
|
||||
|
||||
|
||||
recent_results.append(1 if is_correct else 0)
|
||||
if len(recent_results) > window_size:
|
||||
recent_results.pop(0)
|
||||
window_accuracy = sum(recent_results) / len(recent_results)
|
||||
|
||||
|
||||
confidence_history.append(avg_confidence)
|
||||
consistency_history.append(consistency)
|
||||
|
||||
|
||||
all_predictions.append(answer)
|
||||
all_ground_truths.append(ground_truth)
|
||||
|
||||
|
||||
# Performance logging
|
||||
debug_log.log_policy_result(
|
||||
"CONSISTENCY",
|
||||
@@ -188,28 +196,32 @@ async def run_consistency_experiment(
|
||||
window_accuracy,
|
||||
recent_results,
|
||||
)
|
||||
|
||||
|
||||
# CONSISTENCY-BASED Queue Update
|
||||
if responses:
|
||||
# Calculate per-agent consistency: how many other agents agree
|
||||
all_answers = [r.get("ANSWER") for r in responses.values()]
|
||||
consistency_map = {}
|
||||
|
||||
|
||||
for idx, response in responses.items():
|
||||
agent_answer = response.get("ANSWER")
|
||||
# Consistency = ratio of agents (including self) that agree
|
||||
agreement_count = all_answers.count(agent_answer)
|
||||
-agent_consistency = agreement_count / len(all_answers) if all_answers else 0
+agent_consistency = (
+agreement_count / len(all_answers) if all_answers else 0
+)
|
||||
consistency_map[idx] = agent_consistency
|
||||
|
||||
|
||||
debug_log.log_consistency_map(responses, consistency_map)
|
||||
|
||||
|
||||
# Update queue by consistency (reuses confidence method with consistency scores)
|
||||
ex_queue.update_by_confidence(consistency_map)
|
||||
ex_queue.increment_usage(list(responses.keys()))
|
||||
|
||||
-debug_log.log_policy_queue_state_after("Consistency", processed_count, ex_queue)
+debug_log.log_policy_queue_state_after(
+"Consistency", processed_count, ex_queue
+)
|
||||
|
||||
# Store result
|
||||
result = {
|
||||
"sample_idx": processed_count,
|
||||
@@ -227,34 +239,36 @@ async def run_consistency_experiment(
|
||||
"shuffle_seed": shuffle_seed,
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
|
||||
# Final statistics
|
||||
survival_summary = ex_queue.get_survival_summary()
|
||||
debug_log.log_final_policy_survival("usc", user_id, shuffle_seed, survival_summary)
|
||||
|
||||
|
||||
if results:
|
||||
results[-1]["queue_survival_stats"] = survival_summary
|
||||
results[-1]["queue_survival_details"] = ex_queue.get_survival_stats()
|
||||
|
||||
|
||||
return results
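The per-agent consistency score used for the queue update above is the fraction of agents (including the agent itself) that gave the same answer. A minimal sketch of that computation, assuming `responses` maps a queue slot index to a dict with an `"ANSWER"` key as in the diff; the helper name `per_agent_consistency` and the demo responses are hypothetical, and how `Queue.update_by_confidence` consumes the map is not shown in this hunk.

```python
from typing import Any, Dict


def per_agent_consistency(responses: Dict[int, Dict[str, Any]]) -> Dict[int, float]:
    """Map each slot index to the share of agents that agree with its answer."""
    all_answers = [r.get("ANSWER") for r in responses.values()]
    if not all_answers:
        return {}
    return {
        idx: all_answers.count(r.get("ANSWER")) / len(all_answers)
        for idx, r in responses.items()
    }


# Hypothetical responses from four agents, for illustration only.
demo_responses = {
    0: {"ANSWER": "W"},
    1: {"ANSWER": "W"},
    2: {"ANSWER": "N2"},
    3: {"ANSWER": "W"},
}
demo_consistency = per_agent_consistency(demo_responses)

# Per the docstring's policy, the lowest-consistency slot is the eviction candidate.
lowest_slot = min(demo_consistency, key=demo_consistency.get)
```

Counting the agent itself means a lone dissenter still scores `1 / n` rather than 0, which keeps every score strictly positive.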
|
||||
|
||||
|
||||
-def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict[str, Any]:
+def compute_statistics(
+results: List[Dict[str, Any]], stages: List[str]
+) -> Dict[str, Any]:
|
||||
"""Compute comprehensive experiment statistics."""
|
||||
if not results:
|
||||
return {}
|
||||
|
||||
|
||||
# Overall metrics
|
||||
correct = sum(1 for r in results if r.get("is_correct", False))
|
||||
total = len(results)
|
||||
accuracy = correct / total if total > 0 else 0
|
||||
|
||||
|
||||
# Confidence and consistency averages
|
||||
confidences = [r.get("confidence", 0) for r in results]
|
||||
consistencies = [r.get("consistency", 0) for r in results]
|
||||
avg_confidence = np.mean(confidences) if confidences else 0
|
||||
avg_consistency = np.mean(consistencies) if consistencies else 0
|
||||
|
||||
|
||||
# Per-stage accuracy
|
||||
stage_correct = {}
|
||||
stage_total = {}
|
||||
@@ -263,47 +277,59 @@ def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict
|
||||
stage_total[gt] = stage_total.get(gt, 0) + 1
|
||||
if r.get("is_correct", False):
|
||||
stage_correct[gt] = stage_correct.get(gt, 0) + 1
|
||||
|
||||
|
||||
stage_accuracy = {}
|
||||
for stage in stages:
|
||||
if stage in stage_total:
|
||||
stage_accuracy[stage] = stage_correct.get(stage, 0) / stage_total[stage]
|
||||
else:
|
||||
stage_accuracy[stage] = 0.0
|
||||
|
||||
|
||||
# F1 Score and Macro metrics
|
||||
predictions = [r.get("answer", "") for r in results]
|
||||
ground_truths = [r.get("ground_truth", "") for r in results]
|
||||
|
||||
|
||||
try:
|
||||
-macro_f1 = f1_score(ground_truths, predictions, average='macro', zero_division=0)
-macro_precision = precision_score(ground_truths, predictions, average='macro', zero_division=0)
-macro_recall = recall_score(ground_truths, predictions, average='macro', zero_division=0)
+macro_f1 = f1_score(
+ground_truths, predictions, average="macro", zero_division=0
+)
+macro_precision = precision_score(
+ground_truths, predictions, average="macro", zero_division=0
+)
+macro_recall = recall_score(
+ground_truths, predictions, average="macro", zero_division=0
+)
|
||||
except Exception:
|
||||
macro_f1 = macro_precision = macro_recall = 0.0
|
||||
|
||||
|
||||
# Temporal analysis
|
||||
mid_point = len(results) // 2
|
||||
if mid_point > 0:
|
||||
first_half = results[:mid_point]
|
||||
second_half = results[mid_point:]
|
||||
-first_half_acc = sum(1 for r in first_half if r.get("is_correct", False)) / len(first_half)
-second_half_acc = sum(1 for r in second_half if r.get("is_correct", False)) / len(second_half)
+first_half_acc = sum(1 for r in first_half if r.get("is_correct", False)) / len(
+first_half
+)
+second_half_acc = sum(
+1 for r in second_half if r.get("is_correct", False)
+) / len(second_half)
|
||||
improvement = second_half_acc - first_half_acc
|
||||
else:
|
||||
first_half_acc = second_half_acc = accuracy
|
||||
improvement = 0
|
||||
|
||||
|
||||
# Learning curve (every 10 samples)
|
||||
learning_curve = []
|
||||
for i in range(0, len(results), 10):
|
||||
-chunk = results[:i+10]
+chunk = results[: i + 10]
|
||||
chunk_acc = sum(1 for r in chunk if r.get("is_correct", False)) / len(chunk)
|
||||
-learning_curve.append({
-"sample_idx": min(i+10, len(results)),
-"cumulative_accuracy": chunk_acc,
-})
+learning_curve.append(
+{
+"sample_idx": min(i + 10, len(results)),
+"cumulative_accuracy": chunk_acc,
+}
+)
|
||||
|
||||
# Convergence speed (90% of final accuracy)
|
||||
final_accuracy = accuracy
|
||||
convergence_threshold = 0.9 * final_accuracy
|
||||
@@ -315,14 +341,14 @@ def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict
|
||||
if running_acc >= convergence_threshold:
|
||||
convergence_idx = i
|
||||
break
|
||||
|
||||
|
||||
# Queue statistics
|
||||
queue_stats = {}
|
||||
for r in results:
|
||||
if "queue_survival_stats" in r:
|
||||
queue_stats = r["queue_survival_stats"]
|
||||
break
|
||||
|
||||
|
||||
return {
|
||||
"experiment_type": "consistency",
|
||||
"user_id": results[0].get("user_id") if results else None,
|
||||
@@ -348,125 +374,38 @@ def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict
|
||||
}
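The temporal analysis and learning curve in `compute_statistics` can be checked on synthetic data with a small standalone sketch. The logic follows the diff (cumulative accuracy over `results[: i + 10]`, and second-half minus first-half accuracy); the function names `learning_curve` and `half_split_improvement`, the `step` parameter, and the demo results are assumptions introduced for illustration.

```python
from typing import Any, Dict, List


def learning_curve(results: List[Dict[str, Any]], step: int = 10) -> List[Dict[str, Any]]:
    """Cumulative accuracy sampled every `step` results (last point covers all)."""
    curve = []
    for i in range(0, len(results), step):
        chunk = results[: i + step]
        acc = sum(1 for r in chunk if r.get("is_correct", False)) / len(chunk)
        curve.append(
            {"sample_idx": min(i + step, len(results)), "cumulative_accuracy": acc}
        )
    return curve


def half_split_improvement(results: List[Dict[str, Any]]) -> float:
    """Second-half accuracy minus first-half accuracy; 0.0 for fewer than 2 results."""
    mid = len(results) // 2
    if mid == 0:
        return 0.0
    first, second = results[:mid], results[mid:]
    first_acc = sum(1 for r in first if r.get("is_correct", False)) / len(first)
    second_acc = sum(1 for r in second if r.get("is_correct", False)) / len(second)
    return second_acc - first_acc


# Hypothetical run: 10 wrong predictions followed by 15 correct ones.
demo_results = [{"is_correct": False}] * 10 + [{"is_correct": True}] * 15
demo_curve = learning_curve(demo_results)
demo_improvement = half_split_improvement(demo_results)
```

Each curve point is cumulative from the first sample, so early mistakes keep weighing on later points; a sliding window (like `window_accuracy` in the experiment loop) would show recent behavior instead.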
|
||||
|
||||
|
||||
def save_results(
|
||||
results: List[Dict[str, Any]],
|
||||
stats: Dict[str, Any],
|
||||
config: Dict[str, Any],
|
||||
user_id: int,
|
||||
shuffle_seed: int,
|
||||
) -> None:
|
||||
"""Save experiment results and statistics."""
|
||||
log_path = config["log_path"]
|
||||
|
||||
# Create user/seed specific directory
|
||||
output_dir = os.path.join(log_path, f"user{user_id:02d}_seed{shuffle_seed}")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Save statistics
|
||||
stats_path = os.path.join(output_dir, "statistics.json")
|
||||
with open(stats_path, "w", encoding="utf-8") as f:
|
||||
json.dump(stats, f, indent=2, ensure_ascii=False)
|
||||
log(f"[SAVE] Statistics: {stats_path}")
|
||||
|
||||
# Save results
|
||||
results_path = os.path.join(output_dir, "results.json")
|
||||
with open(results_path, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||
log(f"[SAVE] Results: {results_path}")
|
||||
|
||||
# Save config
|
||||
config_path = os.path.join(output_dir, "config.yaml")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
|
||||
log(f"[SAVE] Config: {config_path}")
|
||||
|
||||
|
||||
def main(
|
||||
config_path: str = "sc/config/experiment_consistency.yaml",
|
||||
user_id: int = 5,
|
||||
shuffle_seed: int = 42,
|
||||
) -> None:
|
||||
def main(config_path: str) -> None:
|
||||
"""
|
||||
Run Consistency-based Queue Policy experiment.
|
||||
|
||||
Args:
|
||||
config_path: Path to YAML configuration file
|
||||
user_id: User ID to process (5 or 10)
|
||||
shuffle_seed: Shuffle seed for data order (42, 123, or 456)
|
||||
|
||||
Example:
|
||||
python -m sc.run_consistency --user_id=5 --shuffle_seed=42
|
||||
python -m sc.run_consistency --user_id=10 --shuffle_seed=123
|
||||
"""
|
||||
log(f"[MAIN] Loading config: {config_path}")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = yaml.load(f, Loader=yaml.SafeLoader)
|
||||
|
||||
# Override with CLI arguments
|
||||
config["user_id"] = user_id
|
||||
config["shuffle_seed"] = shuffle_seed
|
||||
|
||||
# Create unique log path
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
config["log_path"] = f"{config['log_path']}/user{user_id:02d}_seed{shuffle_seed}_{timestamp}"
|
||||
os.makedirs(config["log_path"], exist_ok=True)
|
||||
|
||||
# Print experiment info
|
||||
debug_log.log_policy_main_header(
|
||||
"CONSISTENCY-BASED",
|
||||
user_id,
|
||||
shuffle_seed,
|
||||
config.get("queue_size", 5),
|
||||
len(config.get("models", [])),
|
||||
config["log_path"],
|
||||
|
||||
logger = Logger(config["log_path"])
|
||||
logger.log_config(config)
|
||||
|
||||
dataloader = DataLoader(
|
||||
config["data_path"],
|
||||
config["user_id"],
|
||||
shuffle=config["shuffle"],
|
||||
seed=config["seed"],
|
||||
)
|
||||
|
||||
# Load shuffled data
|
||||
debug_log.log_policy_loading_data(user_id, shuffle_seed)
|
||||
dataloader = ShuffledDataLoader(
|
||||
data_path=config["data_path"],
|
||||
user_id=user_id,
|
||||
seed=shuffle_seed,
|
||||
example_pool=config.get("example_pool", "out"),
|
||||
)
|
||||
|
||||
# Load models
|
||||
debug_log.log_policy_loading_models()
|
||||
model_pool = load_models(
|
||||
config["models"],
|
||||
temperature=config.get("temperature", 0.0),
|
||||
num_ctx=config.get("num_ctx", 15000),
|
||||
temperature=config["temperature"],
|
||||
max_new_tokens=config.get("max_new_tokens", 1024),
|
||||
)
|
||||
|
||||
# Run experiment
|
||||
debug_log.log_policy_start(label="experiment")
|
||||
results = asyncio.run(run_consistency_experiment(
|
||||
dataloader=dataloader,
|
||||
model_pool=model_pool,
|
||||
config=config,
|
||||
user_id=user_id,
|
||||
shuffle_seed=shuffle_seed,
|
||||
))
|
||||
|
||||
# Compute statistics
|
||||
stages = config.get("stages", ["W", "N1", "N2", "N3", "REM"])
|
||||
stats = compute_statistics(results, stages)
|
||||
|
||||
# Print final summary
|
||||
temporal = stats.get("temporal_analysis", {})
|
||||
debug_log.log_policy_complete_summary(
|
||||
"CONSISTENCY",
|
||||
user_id,
|
||||
shuffle_seed,
|
||||
stats,
|
||||
stats.get("stage_accuracy", {}),
|
||||
stats.get("stage_sample_counts", {}),
|
||||
temporal,
|
||||
|
||||
results = asyncio.run(
|
||||
run_consistency_experiment(
|
||||
dataloader=dataloader,
|
||||
model_pool=model_pool,
|
||||
config=config,
|
||||
user_id=user_id,
|
||||
shuffle_seed=shuffle_seed,
|
||||
)
|
||||
)
|
||||
|
||||
# Save results
|
||||
save_results(results, stats, config, user_id, shuffle_seed)
|
||||
log(f"\n[MAIN] Results saved to: {config['log_path']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||