1 Commits

Author SHA1 Message Date
AadharshAadhithya
216bd2ecc3 case init 2026-03-10 20:31:05 -05:00
12 changed files with 2887 additions and 16 deletions

198
.claude/CLAUDE.md Normal file
View File

@@ -0,0 +1,198 @@
# Project Memory: tsllm_personalization_icl
## Project Overview
Research project on **personalized time-series classification using LLMs with In-Context Learning (ICL)**. The core idea is to use LLMs (via Ollama or HuggingFace) to classify sensor/physiological time-series data (e.g., EEG sleep stages, PPG blood pressure), where ICL examples are selected via different strategies (random vs. similarity-based) to study personalization.
Current branch: `case` (main branch: `main`)
---
## Repository Structure
```
tsllm_personalization_icl/
├── run.py # Main entry point (ICL experiment runner)
├── config/
│ └── sleepedf.yaml # Config: data path, models, selection criteria, log path
├── core/ # Base pipeline (ICL experiments)
│ ├── model.py # Model/AsyncModelPool (Ollama or LangChain init_chat_model)
│ ├── agent.py # Base Agent class (memory, logging, JSON parsing)
│ ├── sensing_agent.py # SensingAgent: ICL classify + evaluate + reflect
│ ├── data_loader.py # DataLoader: test/example splits + similarity-based selection
│ └── embedding_index.py # EmbeddingIndex: cosine similarity over Chronos-2 embeddings
├── sc/ # Self-Consistency (SC) variant pipeline
│ ├── run_sc.py # SC experiment runner (main entry: `python -m sc.run_sc config.yaml`)
│ ├── run_confidence_based.py
│ ├── run_consistency_based.py
│ ├── run_sc_queue_random.py
│ ├── run_usc.py
│ ├── core/
│ │ ├── model.py # HuggingFace CausalLM wrapper (returns text + logits)
│ │ ├── agent.py # SC base agent
│ │ ├── scagent.py # SCAgent: single-pass interpret() with REASON/CONFIDENCE/ANSWER
│ │ ├── agent_pool.py # AgentPool: parallel interpret + majority voting
│ │ ├── example_queue.py # Queue: priority queue of example sets, updated by confidence
│ │ ├── data_loader.py # DataLoader + InMemoryDataLoader + PPGBPLoader
│ │ ├── judge_agent.py # JudgeAgent
│ │ ├── majority_voting.py # MajorityVoting utilities
│ │ └── model_utils.py
│ ├── analysis/
│ │ └── analyze_sc_results.py
│ ├── preprocess/
│ │ └── shuffle_data.py
│ ├── debug_log.py # Structured debug logging helpers
│ ├── logger.py
│ ├── hf_api.py
│ └── ollama_test.py
├── analysis/
│ ├── ppgbp_loader.py
│ ├── user_similarity/ # Embedding analysis scripts
│ │ ├── chronos2/ # Chronos-2 embeddings for SleepEDF
│ │ ├── chronos2_ppgbp/ # Chronos-2 embeddings for PPGBP
│ │ ├── labram/ # LaBraM embeddings
│ │ ├── sbert/
│ │ ├── sbert_metadata/
│ │ └── sbert_metadata_ppgbp/ # SBERT metadata embeddings for PPGBP
│ ├── analyze_data.ipynb
│ └── analyze_preliminary.ipynb
├── preprocess/
│ ├── dhedfreader.py
│ └── preprocess_SleepEDF.py
├── utils/
│ ├── kill_ollamas.sh
│ └── launch_ollamas.sh
└── requirements.txt
```
---
## Key Concepts
### Datasets
- **SleepEDF**: EEG-based sleep stage classification (W, N1, N2, N3, REM).
- **PPGBP**: PPG-based blood pressure prediction.
- Data format: `<data_path>/<user_id>/1/` (train/examples, HuggingFace dataset on disk) and `<data_path>/<user_id>/2/` (test split). Metadata in `info.json` (keys: `task`, `class`, `feature`).
### ICL Selection Strategies (core/data_loader.py)
- `out_random`: Random examples from OTHER users (cross-user, random)
- `in_random`: Random examples from SAME user (personalized, random)
- `out_similar`: Most similar examples from OTHER users (Chronos-2 embedding similarity)
- `in_similar`: Most similar examples from SAME user (embedding similarity)
Similarity uses cosine similarity over Chronos-2 embeddings, one example per class (balanced).
### Models
- **core/model.py**: Supports Ollama (local/remote via `ollama:url:<host>/<model>`) and any LangChain-supported model (OpenAI, Together, etc.)
- **sc/core/model.py**: HuggingFace CausalLM loaded via `transformers`, returns text + logits. Used for SC experiments.
### Self-Consistency (sc/) Pipeline
The SC pipeline uses a **priority queue of example sets**:
- `Queue` (sc/core/example_queue.py): holds `capacity` example-sets (one set = one example per class). Updated each step by agent confidence scores — highest-confidence sets are kept, lowest evicted and replaced with a random new set.
- `AgentPool` (sc/core/agent_pool.py): runs one `SCAgent` per queue slot in parallel, aggregates via majority vote, tracks confidence and consistency.
- `SCAgent` (sc/core/scagent.py): single LLM call, returns `{REASON, CONFIDENCE, ANSWER}` JSON.
### Base Agent (core/agent.py)
- Three memory tiers: `long_term_memory`, `short_term_memory`, `volatile_memory`
- `invoke()`: async, appends messages to memory, calls model pool
- `safe_parse_json()` / `safe_parse_json_list()`: robust JSON parsing with cleanup
- Token counting via `tiktoken` (gpt-3.5-turbo encoding as proxy)
### SensingAgent (core/sensing_agent.py)
Extends Agent. Methods:
- `solve()`: classify a sample with ICL examples → `{REASON, ANSWER}`
- `interpret()`: same as solve but without logging ground truth
- `evaluate()`: evaluate another agent's answer (for multi-agent debate)
- `reflect()`: refine answer based on peer evaluations
---
## Running Experiments
### Basic ICL run
```bash
python run.py run config/sleepedf.yaml
```
### Compare multiple selection criteria
```bash
python run.py compare config/sleepedf.yaml \
--criteria_list="out_random,in_random,out_similar,in_similar" \
--embedding_path="./embeddings_full"
```
### Self-Consistency run
```bash
python -m sc.run_sc sc/config/sleepedf_sc.yaml
```
---
## Config (sleepedf.yaml) Key Fields
- `data_path`: root of dataset (contains `info.json` and per-user dirs)
- `log_path`: output directory for results and logs
- `num_seeds`: number of random seeds to run
- `num_examples`: ICL examples per class (default 1)
- `selection_criteria`: `out_random` | `in_random` | `out_similar` | `in_similar`
- `embedding_path`: path to pre-computed Chronos-2 embeddings (required for `*_similar`)
- `models`: list of model specs (Ollama URLs or model IDs)
In SC configs, additional fields:
- `queue_size`: number of example sets to maintain in the priority queue
- `temperature`: LLM sampling temperature
- `max_new_tokens`: max tokens for generation
- `example_pool`: `"out"` (other users) or `"in"` (same user)
- `continuous`: whether queue persists across test samples
---
## Data Format Details
Each HuggingFace dataset sample:
```python
{
"user_id": str,
"session_id": str, # "1" = train, "2" = test
"idx": int,
"label": str, # class name from info.json["class"]
"features": dict, # str -> float/str (sensor features formatted for prompt)
"data": dict, # optional raw data
}
```
`info.json`:
```json
{
"task": "Sleep stage classification from EEG",
"class": {"W": "Wake", "N1": "...", ...},
"feature": "EEG channel description for the LLM..."
}
```
---
## PPGBP Dataset (sc/core/data_loader.py)
`PPGBPLoader`: reads xlsx metadata + signal `.txt` files from `0_subject/`. 80/20 train/test split at subject level. Metadata embeddings (SBERT) auto-generated and cached per subject as `.npy` files under `PPGBP_METADATA_EMBEDDINGS_ROOT` (set via env var or `.env`).
---
## Dependencies (requirements.txt)
- `langchain`, `langchain_ollama`, `langchain_openai`, `langchain_together` — LLM backends
- `datasets` — HuggingFace datasets (on-disk storage)
- `chronos` — Chronos-2 time-series embeddings
- `sentence_transformers` — SBERT for metadata embeddings
- `transformers`, `torch` — HuggingFace model loading (SC pipeline)
- `tiktoken` — token counting
- `fire` — CLI argument parsing
- `mne`, `neurokit2` — EEG/biosignal preprocessing
- `numpy`, `pandas`, `scikit_learn`, `scipy`, `matplotlib`
---
## Notes / Patterns
- Experiments sample every 10th test item (`if idx % 10 != 0: continue` in `run.py`).
- SC runner currently hardcoded to first user only for testing (`users[:1]`).
- Log structure: `<log_path>/<user_id>/<sample_idx>/<seed>/` with `summary.txt`, `log.txt`, `tokens.txt`.
- Embedding index uses cosine similarity (L2-normalized dot product), filtered by user and session.
- `DataLoader` in `sc/` is simpler (no similarity selection); similarity lives in `core/`.
- `InMemoryDataLoader` and `prepare_dataset_for_sc()` allow integrating new datasets without writing to disk.

0
.claude/plan.md Normal file
View File

View File

@@ -630,10 +630,10 @@ class SBERT_Metadata:
def load_embeddings(embedding_dir: str) -> Dataset:
"""
Load saved embeddings dataset from disk.
Args:
embedding_dir: Directory path containing the saved HuggingFace dataset
Returns:
HuggingFace Dataset with embeddings and metadata
"""
@@ -642,6 +642,94 @@ class SBERT_Metadata:
return dataset
# =============================================================================
# Save metadata embeddings by user ID (for PPGBP dataloader)
# =============================================================================
def _normalize_sex_for_sbert(sex: Any) -> Optional[int]:
"""Convert metadata sex to int for SBERT textualize (1=Female, 0=Male)."""
if sex is None or (isinstance(sex, float) and np.isnan(sex)):
return None
s = str(sex).strip().upper()
if s in ("F", "1"):
return 1
if s in ("M", "0"):
return 0
return None
def _normalize_hypertension_for_sbert(ht: Any) -> Optional[int]:
"""Convert metadata hypertension to int if numeric."""
if ht is None or (isinstance(ht, float) and np.isnan(ht)):
return None
try:
return int(float(ht))
except (ValueError, TypeError):
return None
def save_metadata_embeddings_by_userid(
data_root: str,
subject_path: str = SUBJECT_PATH,
embeddings_root: Optional[str] = None,
) -> None:
"""
Compute SBERT metadata embeddings for each user and save one file per user
under embeddings_root as {user_id}.npy (e.g. 100.npy).
Uses PPGBP_METADATA_EMBEDDINGS_ROOT from environment if embeddings_root
is not provided. Load .env from project root if present so the env var is set.
Args:
data_root: Root directory containing 0_subject/ with PPG-BP data.
subject_path: Path to PPGBP metadata xlsx (subject_ID, age, sex, etc.).
embeddings_root: Directory to write {user_id}.npy files. Defaults to
os.environ["PPGBP_METADATA_EMBEDDINGS_ROOT"].
"""
try:
from dotenv import load_dotenv
# Load .env from repo root (parent of analysis/)
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
load_dotenv(os.path.join(_repo_root, ".env"))
except ImportError:
pass
root = embeddings_root or os.environ.get("PPGBP_METADATA_EMBEDDINGS_ROOT")
if not root:
raise ValueError(
"embeddings_root not provided and PPGBP_METADATA_EMBEDDINGS_ROOT not set. "
"Set the env var or pass embeddings_root."
)
os.makedirs(root, exist_ok=True)
subject_metadata = load_subject_metadata(subject_path)
embedder = SBERT_Metadata()
user_ids = sorted(subject_metadata.keys())
if not user_ids:
raise ValueError(f"No subjects found in {subject_path}")
for user_id in user_ids:
meta = subject_metadata[user_id]
sex = _normalize_sex_for_sbert(meta.get("sex"))
age = meta.get("age")
height = meta.get("height")
weight = meta.get("weight")
sbp = meta.get("sbp")
dbp = meta.get("dbp")
hr = meta.get("hr")
bmi = meta.get("bmi")
hypertension = _normalize_hypertension_for_sbert(meta.get("hypertension"))
if age is not None and isinstance(age, float) and np.isnan(age):
age = None
emb = embedder.compute_embedding_from_metadata(
[sex], [age], [height], [weight], [sbp], [dbp], [hr], [bmi], [hypertension]
)
path = os.path.join(root, f"{user_id}.npy")
np.save(path, emb[0])
print(f"[DONE] Saved {len(user_ids)} metadata embeddings under {root}")
# =============================================================================
# Data Loading Utilities
# =============================================================================

1029
sample_case.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,32 @@
# PPGBP Recruiter (MAB-based example set selection) config
# Usage: python -m sc.run_recruiter sc/config/ppgbp_recruiter.yaml
# When model_path is null: use ground-truth scorer (arm_id -> score), no LLM; run gt_steps to test MAB convergence.
data_path: $PPGBP_DATA_ROOT
feature_root: $DATA_INDEX_ROOT/metadata_onehot # index.json + embeddings.npy
log_path: ./logs/ppgbp_recruiter
model_path: null # set to a HF model path to use real LLM + self_certainty
queue_size: 5
recruit_size: 2
n_way: 2
k_shot: 3
num_arms: 50
# When model_path is null: deterministic reward = Gaussian over arm index, peak at gt_best_arm_id
gt_steps: 200
gt_best_arm_id: 20
gt_sigma: 5.0
n_das: 5
delta: 0.05
epsilon: 0.11
temperature: 0.7
max_new_tokens: 256
num_seeds: 1
seed: 42
task_info: "Classify the subject from PPG-BP metadata."
classes_info: ["Normal", "Prehypertension", "Stage 1 hypertension", "Stage 2 hypertension" ]

View File

@@ -7,17 +7,44 @@ Expects the following directory structure:
<path>/
info.json # dataset metadata (task, classes, features)
<user_id>/
1/ # examples (train split)
2/ # test split
1/ # examples (train split) - HuggingFace dataset on disk
2/ # test split - HuggingFace dataset on disk
<other_user>/
1/
...
--------------------------------------------------------------------------------
Integrating a new dataset
--------------------------------------------------------------------------------
The pipeline expects:
1) Directory layout (for DataLoader / ShuffledDataLoader):
- <data_path>/info.json
- <data_path>/<user_id>/1/ and <data_path>/<user_id>/2/ (HuggingFace datasets saved with datasets.Dataset.save_to_disk)
2) info.json schema:
- "task": str (e.g. "Sleep stage classification from EEG")
- "class": dict mapping label -> description (e.g. {"W": "Wake", "N1": "Non-REM 1", ...})
- "feature": str (sensor/feature description for the LLM)
3) Each sample (in test and example datasets) must be a dict with:
- "label": str (class name, must be one of the keys in info.json["class"])
- "features": dict of str -> value (e.g. {"Fpz-Cz": "0.12, -0.05, ...", "Pz-Oz": "..."}); values are formatted for the prompt
Option A - Convert your dataset to this format:
- Create info.json and per-user dirs; build HuggingFace Dataset with columns "label" and "features", then save_to_disk to <user_id>/1 and <user_id>/2.
- Use prepare_dataset_for_sc() in this module to write from in-memory lists.
Option B - Use in-memory data without changing directory layout:
- Use InMemoryDataLoader(metadata, test_samples, example_samples) which implements the same interface as DataLoader.
- Use it in run_sc by constructing it and passing to run_single_task; for run_confidence_based / run_consistency_based / run_sc_queue_random you currently need the directory layout (or extend those runners to accept a loader instance).
"""
import os
import json
from glob import glob
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, Union
import datasets
@@ -26,10 +53,20 @@ class DataLoader:
"""Loads a target user's test set and ICL examples from other users."""
def __init__(
self, path: str, user_id: str, shuffle: bool = False, seed: int = 0
self,
path: Optional[str] = None,
user_id: Optional[str] = None,
shuffle: bool = False,
seed: int = 0,
data_path: Optional[str] = None,
**kwargs: Any,
) -> None:
metadata_path = os.path.join(path, "info.json")
target_user_path = os.path.join(path, user_id, "2")
# Support both path= and data_path= for compatibility with run_sc
root = path or data_path
if root is None or user_id is None:
raise ValueError("path (or data_path) and user_id are required")
metadata_path = os.path.join(root, "info.json")
target_user_path = os.path.join(root, user_id, "2")
if not os.path.exists(metadata_path) or not os.path.exists(target_user_path):
return
@@ -44,7 +81,7 @@ class DataLoader:
# Example set: train splits from all *other* users
example_paths = [
os.path.join(user_path, "1")
for user_path in glob(os.path.join(path, "*"))
for user_path in glob(os.path.join(root, "*"))
if os.path.isdir(user_path) and os.path.basename(user_path) != user_id
]
example_parts = [datasets.load_from_disk(p) for p in example_paths]
@@ -87,6 +124,88 @@ class DataLoader:
return list(self.metadata["class"].keys())
class InMemoryDataLoader:
"""
DataLoader interface backed by in-memory lists. Use this to run the SC
pipeline on a new dataset without writing the directory layout to disk.
Implements the same interface as DataLoader: __len__, __getitem__, __iter__,
get_examples(), get_metadata(), get_sensor_info(), get_task_info(), get_classes_info().
"""
def __init__(
self,
metadata: Dict[str, Any],
test_samples: List[Dict[str, Any]],
example_samples: List[Dict[str, Any]],
) -> None:
"""
Args:
metadata: Must have "task", "class" (dict label -> description), "feature" (str).
test_samples: List of dicts with "label" and "features" (dict).
example_samples: Same schema; used as ICL example pool.
"""
self.metadata = metadata
self.test_dataset = datasets.Dataset.from_list(test_samples)
self.example_dataset = datasets.Dataset.from_list(example_samples)
def __len__(self) -> int:
return len(self.test_dataset)
def __getitem__(self, idx: int) -> Dict[str, Any]:
return self.test_dataset[idx]
def __iter__(self):
yield from self.test_dataset
def get_examples(self) -> datasets.Dataset:
return self.example_dataset
def get_metadata(self) -> Dict[str, Any]:
return self.metadata
def get_sensor_info(self) -> str:
return self.metadata["feature"]
def get_task_info(self) -> str:
classes_info = "\n".join(
f" - {k}: {v}" for k, v in self.metadata["class"].items()
)
return f"**Task**:\n{self.metadata['task']}\n\n**Classes**:\n{classes_info}"
def get_classes_info(self) -> List[str]:
return list(self.metadata["class"].keys())
def prepare_dataset_for_sc(
output_path: str,
metadata: Dict[str, Any],
per_user_splits: Dict[str, Dict[str, List[Dict[str, Any]]]],
) -> None:
"""
Write a new dataset to the directory layout expected by DataLoader and
ShuffledDataLoader. Call this once; then point config data_path to output_path.
Args:
output_path: Root directory to create (e.g. /path/to/MyDataset).
metadata: info.json contents: "task", "class" (dict), "feature" (str).
per_user_splits: { user_id: { "train": [samples], "test": [samples] } }.
Each sample is a dict with "label" and "features".
"""
os.makedirs(output_path, exist_ok=True)
with open(os.path.join(output_path, "info.json"), "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
for user_id, splits in per_user_splits.items():
user_dir = os.path.join(output_path, str(user_id))
os.makedirs(os.path.join(user_dir, "1"), exist_ok=True)
os.makedirs(os.path.join(user_dir, "2"), exist_ok=True)
train_ds = datasets.Dataset.from_list(splits.get("train", []))
test_ds = datasets.Dataset.from_list(splits.get("test", []))
train_ds.save_to_disk(os.path.join(user_dir, "1"))
test_ds.save_to_disk(os.path.join(user_dir, "2"))
# PPGBP data loader
"""
@@ -94,13 +213,160 @@ PPGBP Dataset Loader
A data loader/iterator for the PPG-BP dataset that reads xlsx metadata
and corresponding signal text files, with optional batching support.
Metadata embeddings live under PPGBP_METADATA_EMBEDDINGS_ROOT (one .npy per
subject ID). If missing, they are generated automatically in PPGBPLoader.__init__
using SBERT. Set PPGBP_METADATA_EMBEDDINGS_ROOT in .env or environment.
"""
import os
import sys
import numpy as np
import pandas as pd
from typing import Iterator, Tuple, Dict, List, Optional, Union
try:
from dotenv import load_dotenv
_ppgbp_env_loaded = False
def _ensure_ppgbp_env():
global _ppgbp_env_loaded
if not _ppgbp_env_loaded:
load_dotenv()
_ppgbp_env_loaded = True
except ImportError:
def _ensure_ppgbp_env():
pass
def _get_ppgbp_embedding_root() -> str:
_ensure_ppgbp_env()
root = os.environ.get("PPGBP_METADATA_EMBEDDINGS_ROOT")
if not root:
raise ValueError(
"PPGBP_METADATA_EMBEDDINGS_ROOT is not set. Set it in .env or environment "
"(e.g. PPGBP_METADATA_EMBEDDINGS_ROOT=/path/to/metadata_embeddings)."
)
return root
def _get_ppgbp_embedding_root_optional() -> Optional[str]:
"""Return PPGBP_METADATA_EMBEDDINGS_ROOT if set, else None. Used when embeddings are not required (e.g. recruiter with feature_root)."""
_ensure_ppgbp_env()
root = os.environ.get("PPGBP_METADATA_EMBEDDINGS_ROOT") or ""
return root.strip() or None
def _get_missing_embedding_ids(embedding_root: str, subject_ids: np.ndarray) -> Optional[List[int]]:
"""
Return list of subject_ids that do not have a .npy file in embedding_root.
Return None if embedding_root is not a directory (so caller can create and generate).
"""
if not os.path.isdir(embedding_root):
return None
missing = []
for sid in subject_ids:
path = os.path.join(embedding_root, f"{int(sid)}.npy")
if not os.path.isfile(path):
missing.append(int(sid))
return missing if missing else []
def _check_embedding_root_populated(embedding_root: str, subject_ids: np.ndarray) -> None:
"""Raise if embedding_root is missing or does not contain .npy for every subject_id."""
missing_or_none = _get_missing_embedding_ids(embedding_root, subject_ids)
if missing_or_none is None:
raise FileNotFoundError(
f"Metadata embeddings root is not a directory or does not exist: {embedding_root}."
)
if missing_or_none:
raise FileNotFoundError(
f"Metadata embeddings root {embedding_root} is missing .npy files for subject IDs: "
f"{missing_or_none[:10]}{'...' if len(missing_or_none) > 10 else ''}."
)
def _generate_and_save_ppgbp_metadata_embeddings(
embedding_root: str,
subject_path: str,
subject_ids: np.ndarray,
) -> None:
"""
Generate SBERT metadata embeddings for the given subject IDs and save as
{subject_id}.npy under embedding_root. Uses tqdm for progress.
"""
try:
from tqdm import tqdm
except ImportError:
tqdm = lambda x, **kw: x # noqa: E731
# Lazy import to keep SBERT/gen_plot deps out of normal import path
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from analysis.user_similarity.sbert_metadata_ppgbp.gen_plot import (
load_subject_metadata,
SBERT_Metadata,
_normalize_sex_for_sbert,
_normalize_hypertension_for_sbert,
)
os.makedirs(embedding_root, exist_ok=True)
subject_metadata = load_subject_metadata(subject_path)
embedder = SBERT_Metadata()
ids_to_generate = [int(sid) for sid in subject_ids]
for user_id in tqdm(ids_to_generate, desc="PPGBP metadata embeddings", unit="subject"):
user_id_str = str(user_id)
meta = subject_metadata.get(user_id_str, {})
sex = _normalize_sex_for_sbert(meta.get("sex"))
age = meta.get("age")
height = meta.get("height")
weight = meta.get("weight")
sbp = meta.get("sbp")
dbp = meta.get("dbp")
hr = meta.get("hr")
bmi = meta.get("bmi")
hypertension = _normalize_hypertension_for_sbert(meta.get("hypertension"))
if age is not None and isinstance(age, float) and np.isnan(age):
age = None
emb = embedder.compute_embedding_from_metadata(
[sex], [age], [height], [weight], [sbp], [dbp], [hr], [bmi], [hypertension]
)
path = os.path.join(embedding_root, f"{user_id}.npy")
np.save(path, emb[0])
def load_ppgbp_metadata_embedding(embedding_root: str, subject_id: int) -> np.ndarray:
"""Load a single metadata embedding for subject_id from embedding_root (utility)."""
path = os.path.join(embedding_root, f"{int(subject_id)}.npy")
if not os.path.isfile(path):
raise FileNotFoundError(f"Metadata embedding not found: {path}")
return np.load(path)
def save_ppgbp_metadata_embeddings(
embedding_root: str,
subject_path: str,
subject_ids: Optional[np.ndarray] = None,
) -> None:
"""
Generate and save PPGBP metadata embeddings under embedding_root (one .npy per subject).
If subject_ids is None, uses all subject IDs present in the metadata file at subject_path.
Uses tqdm for progress.
"""
if subject_ids is not None and len(subject_ids) == 0:
return
# If subject_ids is None we need to get all from metadata; _generate_* expects an array
if subject_ids is None:
_repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from analysis.user_similarity.sbert_metadata_ppgbp.gen_plot import load_subject_metadata
subject_metadata = load_subject_metadata(subject_path)
subject_ids = np.array(sorted(int(k) for k in subject_metadata.keys()))
_generate_and_save_ppgbp_metadata_embeddings(embedding_root, subject_path, subject_ids)
class PPGBPLoader:
"""
@@ -128,7 +394,8 @@ class PPGBPLoader:
batch_size: Optional[int] = None,
shuffle: bool = False,
num_segments: int = 3,
seed: int = 42
seed: int = 42,
return_metadata_embeddings: bool = False,
):
self.base_dir = base_dir
self.split = split
@@ -136,6 +403,7 @@ class PPGBPLoader:
self.shuffle = shuffle
self.num_segments = num_segments
self.seed = seed
self.return_metadata_embeddings = return_metadata_embeddings
self.rng = np.random.default_rng(seed)
self.signal_dir = os.path.join(base_dir, "0_subject")
@@ -147,6 +415,15 @@ class PPGBPLoader:
self.subject_ids = self._get_split_ids()
self.metadata = self.metadata[self.metadata['subject_ID'].isin(self.subject_ids)]
self._embedding_root = _get_ppgbp_embedding_root_optional()
if self._embedding_root is not None:
missing = _get_missing_embedding_ids(self._embedding_root, self._all_subject_ids)
if len(self._all_subject_ids) > 0 and (missing is None or len(missing) > 0):
_generate_and_save_ppgbp_metadata_embeddings(
self._embedding_root, self.xlsx_path, self._all_subject_ids
)
_check_embedding_root_populated(self._embedding_root, self.subject_ids)
self._current_idx = 0
self._indices = np.arange(len(self.subject_ids))
@@ -229,6 +506,9 @@ class PPGBPLoader:
row = self.metadata[self.metadata['subject_ID'] == subject_id].iloc[0]
return row.to_dict()
def _load_metadata_embedding(self, subject_id: int) -> np.ndarray:
return load_ppgbp_metadata_embedding(self._embedding_root, subject_id)
def __len__(self) -> int:
return len(self.subject_ids)
@@ -238,7 +518,12 @@ class PPGBPLoader:
self.rng.shuffle(self._indices)
return self
def __next__(self) -> Union[Tuple[Dict, np.ndarray], Tuple[List[Dict], List[np.ndarray]]]:
def __next__(self) -> Union[
Tuple[Dict, np.ndarray],
Tuple[Dict, np.ndarray, np.ndarray],
Tuple[List[Dict], List[np.ndarray]],
Tuple[List[Dict], List[np.ndarray], List[np.ndarray]],
]:
if self._current_idx >= len(self.subject_ids):
raise StopIteration
@@ -248,6 +533,8 @@ class PPGBPLoader:
metadata = self._get_metadata_dict(subject_id)
signal = self._load_signal(subject_id)
self._current_idx += 1
if self.return_metadata_embeddings and self._embedding_root is not None:
return metadata, signal, self._load_metadata_embedding(subject_id)
return metadata, signal
else:
end_idx = min(self._current_idx + self.batch_size, len(self.subject_ids))
@@ -255,28 +542,37 @@ class PPGBPLoader:
metadata_list = []
signal_list = []
embedding_list = [] if (self.return_metadata_embeddings and self._embedding_root is not None) else None
for idx in batch_indices:
subject_id = self.subject_ids[idx]
metadata_list.append(self._get_metadata_dict(subject_id))
signal_list.append(self._load_signal(subject_id))
if self.return_metadata_embeddings and self._embedding_root is not None:
embedding_list.append(self._load_metadata_embedding(subject_id))
self._current_idx = end_idx
if embedding_list is not None:
return metadata_list, signal_list, embedding_list
return metadata_list, signal_list
def __getitem__(self, idx: int) -> Tuple[Dict, np.ndarray]:
def __getitem__(self, idx: int) -> Union[Tuple[Dict, np.ndarray], Tuple[Dict, np.ndarray, np.ndarray]]:
if idx < 0 or idx >= len(self.subject_ids):
raise IndexError(f"Index {idx} out of range")
subject_id = self.subject_ids[idx]
metadata = self._get_metadata_dict(subject_id)
signal = self._load_signal(subject_id)
if self.return_metadata_embeddings and self._embedding_root is not None:
return metadata, signal, self._load_metadata_embedding(subject_id)
return metadata, signal
def get_by_subject_id(self, subject_id: int) -> Tuple[Dict, np.ndarray]:
def get_by_subject_id(self, subject_id: int) -> Union[Tuple[Dict, np.ndarray], Tuple[Dict, np.ndarray, np.ndarray]]:
if subject_id not in self.subject_ids:
raise ValueError(f"Subject ID {subject_id} not found in this split.")
metadata = self._get_metadata_dict(subject_id)
signal = self._load_signal(subject_id)
if self.return_metadata_embeddings and self._embedding_root is not None:
return metadata, signal, self._load_metadata_embedding(subject_id)
return metadata, signal
def reset(self):
@@ -284,7 +580,9 @@ class PPGBPLoader:
if self.shuffle:
self.rng.shuffle(self._indices)
def iter_batches(self, batch_size: int) -> Iterator[Tuple[List[Dict], List[np.ndarray]]]:
def iter_batches(
self, batch_size: int
) -> Iterator[Union[Tuple[List[Dict], List[np.ndarray]], Tuple[List[Dict], List[np.ndarray], List[np.ndarray]]]]:
indices = np.arange(len(self.subject_ids))
if self.shuffle:
self.rng.shuffle(indices)
@@ -295,13 +593,19 @@ class PPGBPLoader:
metadata_list = []
signal_list = []
embedding_list = [] if (self.return_metadata_embeddings and self._embedding_root is not None) else None
for idx in batch_indices:
subject_id = self.subject_ids[idx]
metadata_list.append(self._get_metadata_dict(subject_id))
signal_list.append(self._load_signal(subject_id))
if self.return_metadata_embeddings and self._embedding_root is not None:
embedding_list.append(self._load_metadata_embedding(subject_id))
yield metadata_list, signal_list
if embedding_list is not None:
yield metadata_list, signal_list, embedding_list
else:
yield metadata_list, signal_list
def get_loaders(

View File

@@ -151,4 +151,38 @@ class Queue:
"avg_survival": sum(durations) / len(durations),
"max_survival": max(durations),
"avg_usage": sum(usages) / len(usages) if usages else 0
}
}
def update_with_recruiter(self, results, new_example_sets, L: int = 1):
"""
Evict L lowest-scoring sets (by self-certainty from results) and push L new sets.
results: list of (case, response_dict_or_text, self_certainty_score) in queue order.
new_example_sets: list of L new cases (each case = list of indices, same as queue items).
Lower self_certainty = better; we evict the L with highest (worst) score.
"""
if not results or L <= 0:
return
current = list(self._queue)
if len(current) == 0:
for case in new_example_sets[: self._queue.maxlen]:
self._queue.append(list(case))
self._register_stats(case)
return
scores = []
for i, item in enumerate(current):
sc = -1.0
if i < len(results) and len(results[i]) >= 3:
sc = float(results[i][2])
scores.append((i, sc))
scores.sort(key=lambda x: x[1], reverse=True)
to_evict_idx = {scores[j][0] for j in range(min(L, len(scores)))}
new_queue = [current[i] for i in range(len(current)) if i not in to_evict_idx]
for idx in sorted(to_evict_idx, reverse=True):
if idx < len(current):
self._record_eviction_stats(current[idx])
for case in new_example_sets[:L]:
if len(new_queue) >= self._queue.maxlen:
break
new_queue.append(list(case))
self._register_stats(case)
self._queue = deque(new_queue, maxlen=self._queue.maxlen)

487
sc/core/recruiter_agent.py Normal file
View File

@@ -0,0 +1,487 @@
"""
RecruiterAgent: MAB-based example set selection for personalized ICL.
Components:
- self_certainty(logits): KL-div of LLM output from uniform (reward signal)
- borda_vote(results): aggregate answers across example sets
- FeatureIndex: loads index.json + embeddings.npy for O(1) lookup
- MABSelector: contextual bandit adapted from CASE/sample_case.py top_m_arm
- ArmPool: manages candidate example sets and feature vectors
- RecruiterAgent: orchestrates everything
Feature file schema:
<feature_root>/index.json + <feature_root>/embeddings.npy
index.json:
{
"version": "1.0",
"embedding_type": "sbert_metadata",
"embedding_dim": 384,
"samples": [
{"row": 0, "user_id": "001", "sample_id": 42, "label": "Hypertension", "session": "1"},
...
]
}
embeddings.npy: float32 array shape (N_samples, embedding_dim)
Switch embedding type by pointing to a different feature_root directory.
"""
from __future__ import annotations
import collections, json, math, os, random
from random import sample as random_sample
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
try:
import torch
import torch.nn.functional as F
_TORCH = True
except ImportError:
_TORCH = False
# ── Utilities ──────────────────────────────────────────────────────────────────
def self_certainty(logits) -> float:
"""KL divergence of LLM logits from uniform. Lower = more certain. Negate for MAB."""
if logits is None or not _TORCH:
return 0.0
logits = logits.squeeze(0).float()
log_probs = F.log_softmax(logits, dim=-1)
score = (-1.0 / logits.shape[-1]) * log_probs.sum(dim=-1) - math.log(logits.shape[-1])
return score.mean().item()
def borda_vote(
results: List[Tuple[int, Optional[Dict[str, Any]], float]],
valid_classes: Optional[List[str]] = None,
) -> Optional[str]:
"""
Borda-count vote over (arm_idx, response_dict, certainty_score).
Lower certainty_score = more certain = higher Borda rank.
"""
valid = [
(r[1]["ANSWER"], r[2]) for r in results
if r[1] is not None and "ANSWER" in r[1]
and (valid_classes is None or r[1]["ANSWER"] in valid_classes)
]
if not valid:
return None
sorted_v = sorted(valid, key=lambda x: x[1])
M = len(sorted_v)
borda: Dict[str, float] = {}
for rank, (ans, _) in enumerate(sorted_v):
borda[ans] = borda.get(ans, 0.0) + (M - rank)
top = max(borda.values())
cands = [a for a, s in borda.items() if s == top]
if len(cands) == 1:
return cands[0]
means = {a: float(np.mean([sc for x, sc in valid if x == a])) for a in cands}
return min(means, key=means.get)
# ── FeatureIndex ───────────────────────────────────────────────────────────────
class FeatureIndex:
"""Loads embeddings from disk; O(1) lookup by (user_id, sample_id)."""
def __init__(self, feature_root: str) -> None:
idx_p = os.path.join(feature_root, "index.json")
emb_p = os.path.join(feature_root, "embeddings.npy")
if not os.path.isfile(idx_p):
raise FileNotFoundError(f"index.json not found: {idx_p}")
if not os.path.isfile(emb_p):
raise FileNotFoundError(f"embeddings.npy not found: {emb_p}")
with open(idx_p, "r", encoding="utf-8") as f:
meta = json.load(f)
self.embedding_type: str = meta.get("embedding_type", "unknown")
self.embedding_dim: int = meta["embedding_dim"]
self.embeddings: np.ndarray = np.load(emb_p).astype(np.float32)
self.samples: List[Dict[str, Any]] = meta["samples"]
self._lookup: Dict[Tuple[str, int], int] = {
(str(e["user_id"]), int(e["sample_id"])): int(e["row"])
for e in self.samples
}
print(f"[FeatureIndex] {len(self.embeddings)} embeddings "
f"(dim={self.embedding_dim}, type={self.embedding_type})")
def get(self, user_id: str, sample_id: int) -> Optional[np.ndarray]:
row = self._lookup.get((str(user_id), int(sample_id)))
return self.embeddings[row] if row is not None else None
def get_user_embedding(self, user_id: str) -> Optional[np.ndarray]:
"""Mean embedding across all samples for a user."""
rows = [self.embeddings[e["row"]] for e in self.samples
if str(e["user_id"]) == str(user_id)]
return np.mean(rows, axis=0).astype(np.float32) if rows else None
# ── MABSelector ────────────────────────────────────────────────────────────────
class MABSelector:
"""
Contextual bandit for top-m arm identification.
Adapted from CASE / sample_case.py top_m_arm.
Key differences from original:
- Rewards injected via observe(arm, reward) — no LLM calls, no file IO
- update() does one rank-1 theta update per call
- No initialization() warm-up loop required
X: np.ndarray shape (embedding_dim, num_arms)
m: number of top arms to track (= queue_size)
"""
def __init__(
self, X: np.ndarray, m: int,
n_das: int = 5, delta: float = 0.05,
epsilon: float = 0.11, sigma: float = 0.5,
) -> None:
assert X.ndim == 2, "X must be (embedding_dim, num_arms)"
self.X = X.astype(np.float64)
self.N, self.K = X.shape
self.m = min(m, self.K)
self.n_das = n_das
self.delta = delta
self.epsilon = epsilon
self.sigma = sigma
self.arms = list(range(self.K))
self._it = -1
self._Bdi: Dict[Tuple[int, int], float] = {}
self._reset()
def _reset(self) -> None:
self.t = 0
self.rewards: List[float] = []
self.pulled_arms: List[int] = []
self.means = np.zeros(self.K)
self.na = np.zeros(self.K)
self.B_inv = np.eye(self.N, dtype=np.float64)
self.b = np.zeros(self.N, dtype=np.float64)
self.theta = np.random.normal(0, 1, size=(1, self.N))
self.J: List[int] = []
self.notJ: List[int] = list(range(self.K))
self.N_t: List[int] = []
self.best_arm: Optional[int] = None
self.worst_arm: Optional[int] = None
self.challenger: Optional[int] = None
self.c_t: Optional[int] = None
self.patience = 0
self.previous_J: List[int] = []
# ── math helpers ────────────────────────────────────────────────────────────
def _mn(self, x: np.ndarray, A: np.ndarray) -> float:
x = np.asarray(x).flatten()
return float(np.sqrt(max(float(x @ A @ x), 0.0)))
def _iu(self, A: np.ndarray, x: np.ndarray) -> np.ndarray:
"""Rank-1 inverse update: A - A x xT A / (1 + xT A x)."""
x = np.asarray(x).flatten()
d = 1.0 + self._mn(x, A) ** 2
return A - np.outer(A @ x, x @ A) / d
def _beta(self) -> float:
return math.log((math.log(max(self.t, 2)) + 1) / max(self.delta, 1e-12))
def _gap(self, i: int, j: Optional[int] = None) -> float:
xi = self.X[:, i] - (self.X[:, j] if j is not None else 0.0)
return float(self.theta.flatten() @ xi)
def _var(self, i: int, j: Optional[int] = None) -> float:
S = (self.sigma ** 2) * self.B_inv
xi = self.X[:, i] - (self.X[:, j] if j is not None else 0.0)
return self._mn(xi, S) * math.sqrt(2 * self._beta())
def _Bij(self, i: int, j: int) -> float:
t = max(self.t, 1)
if t != self._it:
self._Bdi = {}
self._it = t
k = (i, j)
if k not in self._Bdi:
self._Bdi[k] = self._gap(i, j) + self._var(i, j)
return self._Bdi[k]
def _randf(self, x: List[float], f) -> int:
arr = np.array(x, dtype=float)
val = f(arr)
c = np.argwhere(arr == val).flatten().tolist()
return random_sample(c, 1)[0]
def _mmax(self, x: List[float], m: int) -> List[int]:
x = list(x)
ids = []
for _ in range(min(m, len(x))):
idx = self._randf(x, np.max)
ids.append(idx)
x[idx] = -float("inf")
return ids
# ── sample step ─────────────────────────────────────────────────────────────
def _sel(self, nt: Optional[int]) -> Optional[int]:
if nt is None or self.c_t is None:
return None
if self.means[self.c_t] > self.means[nt]:
s = self.c_t
if s in self.N_t:
self.N_t.remove(s)
return s
return None
def _Jt(self) -> List[int]:
if self.t == 0:
return self._mmax(self.means.tolist(), self.m)
nt = self.worst_arm
sel = self._sel(nt)
self.previous_J = list(self.J)
if sel is not None:
self.J = [a for a in self.J if a != nt] + [sel]
if nt is not None and nt not in self.N_t:
self.N_t.append(nt)
order = np.argsort(self.means[self.J])[::-1].tolist()
self.J = [self.J[i] for i in order]
self.patience = 0 if self.J != self.previous_J else self.patience + 1
return self.J
def sample(self) -> List[int]:
"""Select next arm(s) to evaluate. Returns list of arm indices."""
self.J = self._Jt()
self.notJ = [a for a in self.arms if a not in self.J]
if self.t == 0 or not self.N_t:
n = min(self.n_das, len(self.notJ))
self.N_t = list(np.random.choice(self.notJ, n, replace=False)) if self.notJ else []
else:
Qt = list(np.random.choice(self.notJ, min(self.n_das, len(self.notJ)), replace=False)) if self.notJ else []
self.N_t = list({*self.N_t, *Qt} - set(self.J))
if self.N_t:
top_n = min(self.n_das, len(self.N_t))
tv = np.sort(self.means[self.N_t])[::-1][:top_n]
picked: List[int] = []
for mv in tv:
c = [a for a in self.N_t if self.means[a] == mv and a not in picked]
if c:
picked.append(random.choice(c))
self.N_t = picked
if self.N_t and self.J:
jm = self.means[self.J]
ntc = [a for a in self.J if self.means[a] == jm.min()]
self.worst_arm = random.choice(ntc)
bti = [self._Bij(a, self.worst_arm) for a in self.J]
self.best_arm = self.J[self._randf(bti, np.max)]
chi = [self._Bij(a, self.best_arm) for a in self.N_t]
self.challenger = self.N_t[self._randf(chi, np.max)]
nm = self.means[self.N_t]
c = [a for a in self.N_t if self.means[a] == nm.max()]
self.c_t = random.choice(c)
else:
self.worst_arm = self.J[0] if self.J else None
self.best_arm = self.challenger = self.c_t = None
# Greedy arm pull
if self.best_arm is not None and self.challenger is not None:
d = self.X[:, self.best_arm] - self.X[:, self.challenger]
u = [self._mn(d, self._iu(self.B_inv, self.X[:, i])) for i in self.arms]
return [self.arms[self._randf(u, np.min)]]
return [random.choice(self.arms)]
def observe(self, arm: int, reward: float) -> None:
"""Record reward for arm."""
self.rewards.append(reward)
self.pulled_arms.append(arm)
self.na[arm] += 1
self.t += 1
def update(self) -> None:
"""Rank-1 theta/B_inv update from the last observation."""
if not self.pulled_arms:
return
x = self.X[:, self.pulled_arms[-1]].flatten()
self.B_inv = self._iu(self.B_inv, x)
self.b += self.rewards[-1] * x
self.theta = (self.B_inv @ self.b).reshape(1, self.N)
self.means = (self.theta @ self.X).flatten()
def recommend(self, n: int) -> List[int]:
"""Return n arm indices with highest estimated mean rewards."""
return self._mmax(self.means.tolist(), min(n, self.K))
def stopping_rule(self) -> bool:
if None in (self.best_arm, self.worst_arm, self.challenger):
return False
return round(self._Bij(self.challenger, self.best_arm), 2) <= self.epsilon
# ── ArmPool ────────────────────────────────────────────────────────────────────
#[TODO] Add way to sample arms from the larger pool with replacement, instead of fixing on to num_arms arms.
class ArmPool:
"""
Manages candidate example sets (arms).
Each arm = list of sample dicts (N-way x k_shot).
Arm feature = mean embedding of its constituent samples.
"""
def __init__(
self,
source_samples: List[Dict[str, Any]],
classes: List[str],
k_shot: int,
feature_index: Optional[FeatureIndex],
num_arms: int,
rng: np.random.Generator,
) -> None:
self.classes = classes
self.k_shot = k_shot
self.feature_index = feature_index
self._rng = rng
self._by: Dict[str, List[Dict[str, Any]]] = {c: [] for c in classes}
for s in source_samples:
lbl = s.get("label")
if lbl in self._by:
self._by[lbl].append(s)
self.arms: List[List[Dict[str, Any]]] = []
self.arm_feats: List[Optional[np.ndarray]] = []
for _ in range(num_arms):
arm, feat = self._make()
self.arms.append(arm)
self.arm_feats.append(feat)
print(f"[ArmPool] {len(self.arms)} arms ({len(classes)}-way {k_shot}-shot)")
def _make(self) -> Tuple[List[Dict[str, Any]], Optional[np.ndarray]]:
s: List[Dict[str, Any]] = []
for c in self.classes:
pool = self._by.get(c, [])
if pool:
n = min(self.k_shot, len(pool))
idxs = self._rng.choice(len(pool), size=n, replace=False)
s.extend([pool[i] for i in idxs])
return s, self._feat(s)
def _feat(self, samples: List[Dict[str, Any]]) -> Optional[np.ndarray]:
if self.feature_index is None:
return None
vecs = []
for s in samples:
v = self.feature_index.get(
str(s.get("user_id", "")),
int(s.get("sample_id", s.get("idx", -1))),
)
if v is not None:
vecs.append(v)
return np.mean(vecs, axis=0).astype(np.float32) if vecs else None
def get_arm(self, i: int) -> List[Dict[str, Any]]:
return self.arms[i]
def build_X(self) -> np.ndarray:
"""Feature matrix X of shape (embedding_dim, num_arms). Falls back to identity."""
feats = [f for f in self.arm_feats if f is not None]
if not feats:
return np.eye(len(self.arms), dtype=np.float32)
X = np.zeros((feats[0].shape[0], len(self.arms)), dtype=np.float32)
for i, f in enumerate(self.arm_feats):
if f is not None:
X[:, i] = f
return X
# ── RecruiterAgent ────────────────────────────────────────────────────────────
class RecruiterAgent:
"""
MAB-driven recruiter for personalized ICL example set selection.
Usage:
recruiter = RecruiterAgent(source_samples, classes, ...)
initial = recruiter.recruit(queue_size) # cold start (random)
for test_sample in target_stream:
results = [(arm_idx, response_dict, score), ...]
recruiter.update(results) # MAB update lives here
new_sets = recruiter.recruit(L) # MAB-guided new sets
queue.update_with_recruiter(results, new_sets)
MAB update logic inside update():
1. reward = -score (negate: lower certainty = better = higher reward)
2. mab.observe(arm, reward)
3. mab.update() -> rank-1 theta/B_inv update
4. mab.sample() -> select next arms to explore (stored for next recruit())
"""
def __init__(
self,
source_samples: List[Dict[str, Any]],
classes: List[str],
feature_index: Optional[FeatureIndex] = None,
target_user_id: Optional[str] = None,
k_shot: int = 1,
queue_size: int = 5,
num_arms: int = 50,
n_das: int = 5,
delta: float = 0.05,
epsilon: float = 0.11,
sigma: float = 0.5,
seed: int = 42,
) -> None:
self.classes = classes
self.k_shot = k_shot
self.queue_size = queue_size
self.target_user_id = target_user_id
rng = np.random.default_rng(seed)
random.seed(seed)
np.random.seed(seed)
self.pool = ArmPool(source_samples, classes, k_shot, feature_index, num_arms, rng)
self.mab = MABSelector(
self.pool.build_X(), m=queue_size,
n_das=n_das, delta=delta, epsilon=epsilon, sigma=sigma,
)
self._next: Optional[List[int]] = None
self._initialized = False
print(f"[RecruiterAgent] {len(self.pool.arms)} arms, "
f"queue={queue_size}, classes={list(set(classes))}")
def recruit(self, n: int) -> List[Tuple[int, List[Dict[str, Any]]]]:
"""
Return n (arm_idx, example_set) pairs for insertion into the queue.
Cold start: random. Subsequent calls: MAB-recommended.
"""
if not self._initialized:
indices = random.sample(range(len(self.pool.arms)), min(n, len(self.pool.arms)))
self._initialized = True
elif self._next is not None:
indices = self._next[:n]
if len(indices) < n:
rest = [i for i in range(len(self.pool.arms)) if i not in indices]
indices += random.sample(rest, min(n - len(indices), len(rest)))
else:
indices = self.mab.recommend(n)
return [(i, self.pool.get_arm(i)) for i in indices]
def update(self, results: List[Tuple[int, Optional[Dict[str, Any]], float]]) -> None:
"""
Update MAB from (arm_idx, response_dict, self_certainty_score).
Steps:
1. reward = -score (lower certainty = better = higher reward)
2. mab.observe(arm, reward)
3. mab.update() -> rank-1 theta/B_inv update
4. mab.sample() -> selects next arms to explore
"""
if not results:
return
for arm_idx, _, score in results:
if 0 <= arm_idx < len(self.pool.arms):
self.mab.observe(arm_idx, -float(score))
self.mab.update()
try:
next_cands = self.mab.sample()
top = self.mab.recommend(self.queue_size)
self._next = list(dict.fromkeys(next_cands + top))
except Exception as e:
print(f"[RecruiterAgent] MAB sample error ({e}), using recommend()")
self._next = self.mab.recommend(self.queue_size)

View File

@@ -0,0 +1,99 @@
"""
Build unified feature index from PPGBP per-subject .npy embeddings.
Output: <output_dir>/index.json + <output_dir>/embeddings.npy
Usage:
python -m sc.preprocess.build_feature_index --embedding_root $ROOT --output_dir ./features/sbert_metadata
"""
from __future__ import annotations
import argparse
import json
import os
from typing import Any, Dict, List, Optional
import numpy as np
def _load_labels(base_dir: Optional[str]) -> Dict[int, str]:
"""Load subject_id -> label from xlsx. Prefers Hypertension column (PPGBP), else class/Class."""
if not base_dir:
return {}
try:
import pandas as pd
except ImportError:
return {}
xlsx = None
for f in os.listdir(base_dir):
if f.endswith(".xlsx"):
xlsx = os.path.join(base_dir, f)
break
if not xlsx or not os.path.isfile(xlsx):
return {}
df = pd.read_excel(xlsx, header=1)
if "subject_ID" not in df.columns:
return {}
label_col = None
for c in ("Hypertension", "class", "Class"):
if c in df.columns:
label_col = c
break
out = {}
for _, row in df.iterrows():
sid = int(row["subject_ID"])
out[sid] = str(row.get(label_col, "unknown")).strip() if label_col else "unknown"
return out
def build_index(
embedding_root: str,
output_dir: str,
embedding_type: str = "sbert_metadata",
base_dir: Optional[str] = None,
) -> None:
if not os.path.isdir(embedding_root):
raise FileNotFoundError(embedding_root)
subject_ids = []
for f in sorted(os.listdir(embedding_root)):
if not f.endswith(".npy"):
continue
try:
subject_ids.append(int(f.replace(".npy", "")))
except ValueError:
continue
subject_ids = sorted(subject_ids)
if not subject_ids:
raise ValueError("No .npy files in " + embedding_root)
labels = _load_labels(base_dir)
embeddings_list = []
samples: List[Dict[str, Any]] = []
for row, sid in enumerate(subject_ids):
emb = np.load(os.path.join(embedding_root, f"{sid}.npy")).astype(np.float32)
embeddings_list.append(emb)
samples.append({
"row": row,
"user_id": str(sid),
"sample_id": sid,
"label": labels.get(sid, "unknown"),
"session": "1",
})
embeddings = np.stack(embeddings_list, axis=0)
dim = int(embeddings.shape[1])
meta = {"version": "1.0", "embedding_type": embedding_type, "embedding_dim": dim, "samples": samples}
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, "index.json"), "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
np.save(os.path.join(output_dir, "embeddings.npy"), embeddings)
print(f"[build_feature_index] Wrote {len(samples)} samples, dim={dim}")
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--embedding_root", required=True, help="Directory of per-subject .npy files")
ap.add_argument("--output_dir", required=True, help="Output dir for index.json and embeddings.npy")
ap.add_argument("--embedding_type", default="sbert_metadata")
ap.add_argument("--base_dir", default=None, help="Optional PPGBP data dir (xlsx) for labels")
a = ap.parse_args()
build_index(a.embedding_root, a.output_dir, a.embedding_type, a.base_dir)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,134 @@
"""
Build PPGBP embeddings under PPGBP_EMBEDDINGS_ROOT with two top-level dirs:
metadata_embs/<emb_type>/<subject_id>.npy
sig_embs/<emb_type>/<subject_id>.npy
For now only metadata embeddings are produced: for each subject, take the xlsx row,
vectorize (one-hot for Sex and Hypertension, numeric for the rest) and save.
Usage:
python -m sc.preprocess.build_ppgbp_embeddings \\
--embeddings_root /path/to/ppgbp_embeddings \\
--data_dir /path/to/ppgbp_data \\
--metadata_emb_type onehot
Then build the feature index (for downstream RecruiterAgent etc.):
python -m sc.preprocess.build_feature_index \\
--embedding_root $PPGBP_EMBEDDINGS_ROOT/metadata_embs/onehot \\
--output_dir $DATA_INDEX_ROOT/metadata_onehot \\
--embedding_type metadata_onehot \\
--base_dir $PPGBP_DATA_DIR
"""
from __future__ import annotations
import argparse
import os
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
# Xlsx column names (header=1)
SUBJECT_ID_COL = "subject_ID"
SEX_COL = "Sex(M/F)"
AGE_COL = "Age(year)"
HEIGHT_COL = "Height(cm)"
WEIGHT_COL = "Weight(kg)"
SBP_COL = "Systolic Blood Pressure(mmHg)"
DBP_COL = "Diastolic Blood Pressure(mmHg)"
HR_COL = "Heart Rate(b/m)"
BMI_COL = "BMI(kg/m^2)"
HYP_COL = "Hypertension"
NUMERIC_COLS = [AGE_COL, HEIGHT_COL, WEIGHT_COL, SBP_COL, DBP_COL, HR_COL, BMI_COL]
CATEGORICAL_COLS = [SEX_COL, HYP_COL]
def _find_xlsx(data_dir: str) -> str:
for f in os.listdir(data_dir):
if f.endswith(".xlsx"):
return os.path.join(data_dir, f)
raise FileNotFoundError(f"No .xlsx in {data_dir}")
def _onehot_and_numeric(df: pd.DataFrame) -> Tuple[np.ndarray, List[str], Dict[str, List[str]]]:
"""
One-hot encode categoricals (Sex, Hypertension) and stack with numeric columns.
Returns (matrix, column_names, categories_used).
"""
out_cols: List[str] = []
pieces: List[np.ndarray] = []
categories_used: Dict[str, List[str]] = {}
for col in CATEGORICAL_COLS:
if col not in df.columns:
continue
vals = df[col].fillna("unknown").astype(str).str.strip()
uniq = sorted(vals.unique())
categories_used[col] = uniq
for u in uniq:
out_cols.append(f"{col}__{u}")
onehot = (vals.values.reshape(-1, 1) == np.array(uniq)).astype(np.float32)
pieces.append(onehot)
for col in NUMERIC_COLS:
if col not in df.columns:
continue
out_cols.append(col)
vec = df[col].fillna(0).astype(np.float32).values.reshape(-1, 1)
pieces.append(vec)
if not pieces:
raise ValueError("No columns found; check xlsx has expected column names")
X = np.hstack(pieces)
return X, out_cols, categories_used
def create_embedding_dirs(embeddings_root: str, metadata_emb_type: str, sig_emb_type: str = "placeholder") -> None:
"""Create metadata_embs/<emb_type>/ and sig_embs/<emb_type>/."""
os.makedirs(embeddings_root, exist_ok=True)
meta_dir = os.path.join(embeddings_root, "metadata_embs", metadata_emb_type)
sig_dir = os.path.join(embeddings_root, "sig_embs", sig_emb_type)
os.makedirs(meta_dir, exist_ok=True)
os.makedirs(sig_dir, exist_ok=True)
def build_metadata_embeddings_onehot(
data_dir: str,
embeddings_root: str,
metadata_emb_type: str = "onehot",
) -> None:
"""
Read xlsx from data_dir, vectorize each subject row (one-hot + numeric), save as
metadata_embs/<metadata_emb_type>/<subject_id>.npy.
"""
xlsx_path = _find_xlsx(data_dir)
df = pd.read_excel(xlsx_path, header=1)
if SUBJECT_ID_COL not in df.columns:
raise ValueError(f"xlsx must have column {SUBJECT_ID_COL}")
X, _cols, _cats = _onehot_and_numeric(df)
subject_ids = df[SUBJECT_ID_COL].astype(int).values
out_dir = os.path.join(embeddings_root, "metadata_embs", metadata_emb_type)
os.makedirs(out_dir, exist_ok=True)
for i, sid in enumerate(subject_ids):
path = os.path.join(out_dir, f"{int(sid)}.npy")
np.save(path, X[i].astype(np.float32))
print(f"[build_ppgbp_embeddings] Wrote {len(subject_ids)} metadata embeddings to {out_dir} (dim={X.shape[1]})")
def main() -> None:
ap = argparse.ArgumentParser(description="Build PPGBP metadata (and optionally signal) embeddings under embeddings_root.")
ap.add_argument("--embeddings_root", required=True, help="Root for metadata_embs/ and sig_embs/ (e.g. PPGBP_EMBEDDINGS_ROOT)")
ap.add_argument("--data_dir", required=True, help="PPGBP data dir containing the xlsx file")
ap.add_argument("--metadata_emb_type", default="onehot", help="Subdir name under metadata_embs/ (e.g. onehot)")
ap.add_argument("--sig_emb_type", default="placeholder", help="Subdir under sig_embs/ (created empty for now)")
args = ap.parse_args()
create_embedding_dirs(args.embeddings_root, args.metadata_emb_type, args.sig_emb_type)
build_metadata_embeddings_onehot(args.data_dir, args.embeddings_root, args.metadata_emb_type)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,76 @@
"""
One-shot: build PPGBP one-hot metadata embeddings and then package them into
the feature index at DATA_INDEX_ROOT for downstream use (RecruiterAgent, etc.).
Step 1: Create metadata_embs/onehot/ and sig_embs/ under PPGBP_EMBEDDINGS_ROOT,
write one-hot metadata vectors per subject.
Step 2: Run build_feature_index to produce index.json + embeddings.npy at
DATA_INDEX_ROOT (compatible with FeatureIndex and RecruiterAgent).
Usage:
export PPGBP_EMBEDDINGS_ROOT=/scratch/.../embeddings_full/ppgbp
export PPGBP_DATA_DIR=/path/to/ppgbp_data # dir with xlsx and 0_subject/
# or: export PPGBP_DATA_ROOT=/path/to/data/tsllm (same as data dir if xlsx lives there)
export DATA_INDEX_ROOT=/path/to/feature_index_output
python -m sc.preprocess.package_ppgbp_embeddings
Or with explicit args:
python -m sc.preprocess.package_ppgbp_embeddings \\
--embeddings_root $PPGBP_EMBEDDINGS_ROOT \\
--data_dir $PPGBP_DATA_DIR \\
--index_output_dir $DATA_INDEX_ROOT \\
--metadata_emb_type onehot \\
--embedding_type metadata_onehot
"""
from __future__ import annotations
import argparse
import os
import sys
def main() -> None:
ap = argparse.ArgumentParser(description="Build PPGBP metadata embeddings and package feature index.")
ap.add_argument("--embeddings_root", default=None, help="Default: PPGBP_EMBEDDINGS_ROOT env")
ap.add_argument("--data_dir", default=None, help="PPGBP data dir (xlsx + 0_subject/); default: PPGBP_DATA_DIR or PPGBP_DATA_ROOT env")
ap.add_argument("--index_output_dir", default=None, help="Where to write index.json + embeddings.npy; default: DATA_INDEX_ROOT env")
ap.add_argument("--metadata_emb_type", default="onehot")
ap.add_argument("--embedding_type", default="metadata_onehot", help="embedding_type string in index.json")
args = ap.parse_args()
embeddings_root = args.embeddings_root or os.environ.get("PPGBP_EMBEDDINGS_ROOT")
data_dir = args.data_dir or os.environ.get("PPGBP_DATA_DIR") or os.environ.get("PPGBP_DATA_ROOT") or embeddings_root
index_output_dir = args.index_output_dir or os.environ.get("DATA_INDEX_ROOT")
if not embeddings_root:
print("Set PPGBP_EMBEDDINGS_ROOT or pass --embeddings_root", file=sys.stderr)
sys.exit(1)
if not os.path.isdir(data_dir):
print(f"Data dir not found: {data_dir}", file=sys.stderr)
sys.exit(1)
if not index_output_dir:
print("Set DATA_INDEX_ROOT or pass --index_output_dir", file=sys.stderr)
sys.exit(1)
from sc.preprocess.build_ppgbp_embeddings import (
create_embedding_dirs,
build_metadata_embeddings_onehot,
)
create_embedding_dirs(embeddings_root, args.metadata_emb_type)
build_metadata_embeddings_onehot(data_dir, embeddings_root, args.metadata_emb_type)
embedding_root_for_index = os.path.join(embeddings_root, "metadata_embs", args.metadata_emb_type)
from sc.preprocess.build_feature_index import build_index
build_index(
embedding_root=embedding_root_for_index,
output_dir=index_output_dir,
embedding_type=args.embedding_type,
base_dir=data_dir,
)
print(f"[package_ppgbp_embeddings] Feature index written to {index_output_dir}")
if __name__ == "__main__":
main()

390
sc/run_recruiter.py Normal file
View File

@@ -0,0 +1,390 @@
"""
Recruiter-based Self-Consistency runner for PPGBP.
Uses RecruiterAgent (MAB) to select example sets, self_certainty(logits) as reward,
and borda_vote for aggregation. Requires PPGBPLoader and optional feature index.
Usage:
python -m sc.run_recruiter sc/config/ppgbp_recruiter.yaml
"""
from __future__ import annotations
import asyncio
import json
import os
import sys
from collections import deque
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import yaml
from fire import Fire
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sc.core.data_loader import PPGBPLoader, get_loaders
from sc.core.example_queue import Queue
from sc.core.model import Model, load_models
from sc.core.recruiter_agent import RecruiterAgent, borda_vote, self_certainty
from sc.core.recruiter_agent import FeatureIndex
def _ppgbp_sample_to_dict(metadata: Dict, signal, subject_id: int, label_key: str = "label") -> Dict[str, Any]:
"""Turn (metadata, signal) into a sample dict with user_id, sample_id, label, features."""
# Try to find label in metadata using label_key, then "hypertension", then "Hypertension", then "class", then "Class"
label = str(metadata.get(label_key))
if label == "None" or label == "unknown":
for key in ["hypertension", "Hypertension", "class", "Class"]:
val = metadata.get(key)
if val is not None:
label = str(val)
break
if label == "None":
label = "unknown"
features = {k: str(v) for k, v in metadata.items() if k != label_key and k.lower() != "label"}
return {
"user_id": str(subject_id),
"sample_id": subject_id,
"idx": subject_id,
"label": label,
"features": features,
}
def _build_prompt(system_msg: str, task_info: str, classes_info: List[str], test_sample: Dict, examples: List[Dict]) -> List[Dict[str, str]]:
"""Build chat messages for the model."""
parts = [f"**Task**: {task_info}\n\n**Classes**: {', '.join(classes_info)}\n\n"]
parts.append("**Examples**\n")
for ex in examples:
parts.append(f"*Example of {ex['label']}*:\n")
for k, v in (ex.get("features") or {}).items():
parts.append(f" - {k}: {v}\n")
parts.append("\n**Current sample**\n")
for k, v in (test_sample.get("features") or {}).items():
parts.append(f" - {k}: {v}\n")
parts.append(f"\nRespond in JSON: {{\"REASON\": \"...\", \"CONFIDENCE\": 0.0-1.0, \"ANSWER\": \"<class>\"}}")
content = "".join(parts)
return [
{"role": "system", "content": system_msg},
{"role": "user", "content": content},
]
def _parse_answer(text: str, classes: List[str]) -> Optional[str]:
"""Extract ANSWER from JSON in text."""
try:
start = text.find("{")
end = text.rfind("}") + 1
if start >= 0 and end > start:
obj = json.loads(text[start:end])
ans = (obj.get("ANSWER") or "").strip()
if ans in classes:
return ans
except Exception:
pass
return None
def make_gt_scorer(
num_arms: int,
best_arm_id: int = 20,
sigma: float = 5.0,
) -> Callable[[int], float]:
"""
Ground-truth: deterministic. The distribution *over arms* is normal: reward(arm_id)
is a Gaussian in arm index with peak at best_arm_id. So reward(i) = exp(-(i - best)^2 / (2 sigma^2)).
Score = -reward so lower = better. No randomness.
"""
best_arm_id = max(0, min(best_arm_id, num_arms - 1))
sigma = max(1e-6, float(sigma))
def scorer(arm_id: int) -> float:
if arm_id < 0 or arm_id >= num_arms:
return 0.0
reward = np.exp(-((arm_id - best_arm_id) ** 2) / (2 * sigma**2))
return -float(reward)
return scorer
async def run_single_user(
config: Dict[str, Any],
model_pool: Optional[Any],
test_loader: Optional[PPGBPLoader],
train_loader: PPGBPLoader,
seed: int,
) -> List[Dict[str, Any]]:
"""Run recruiter pipeline. If model_pool is None, use ground-truth scorer (arm_id -> score) for MAB convergence."""
queue_size = config.get("queue_size", 5)
recruit_size = config.get("recruit_size", 2)
n_way = config.get("n_way", 2)
k_shot = config.get("k_shot", 1)
task_info = config.get("task_info", "Classify the subject.")
classes_info = config.get("classes_info", ["Normal", "Hypertension"])
system_msg = config.get("system_message", "You are a helpful assistant. Respond in JSON with REASON, CONFIDENCE, ANSWER.")
use_gt_only = model_pool is None
np.random.seed(seed)
train_loader._indices = np.arange(len(train_loader.subject_ids))
if train_loader.shuffle:
train_loader.rng.shuffle(train_loader._indices)
example_dataset = []
for idx in train_loader._indices:
sid = train_loader.subject_ids[idx]
meta = train_loader._get_metadata_dict(sid)
signal = train_loader._load_signal(sid)
example_dataset.append(_ppgbp_sample_to_dict(meta, signal, int(sid)))
if len(example_dataset) < n_way * k_shot:
print("[run_recruiter] Not enough train samples")
return []
classes = list({s["label"] for s in example_dataset})
if not classes:
classes = classes_info
class_indices = {c: [i for i, s in enumerate(example_dataset) if s["label"] == c] for c in classes}
for c in classes:
if not class_indices[c]:
class_indices[c] = [0]
dataset_idx = {(str(s["user_id"]), s["sample_id"]): i for i, s in enumerate(example_dataset)}
feature_index = None
if config.get("feature_root") and os.path.isdir(config["feature_root"]):
try:
feature_index = FeatureIndex(config["feature_root"])
except Exception as e:
print(f"[run_recruiter] Feature index load failed: {e}")
recruiter = RecruiterAgent(
source_samples=example_dataset,
classes=classes,
feature_index=feature_index,
target_user_id=None,
k_shot=k_shot,
queue_size=queue_size,
num_arms=config.get("num_arms", 50),
n_das=config.get("n_das", 5),
delta=config.get("delta", 0.05),
epsilon=config.get("epsilon", 0.11),
seed=seed,
)
num_arms = len(recruiter.pool.arms)
if use_gt_only:
gt_steps = config.get("gt_steps", 200)
gt_best_arm_id = config.get("gt_best_arm_id", 20)
gt_sigma = config.get("gt_sigma", 5.0)
gt_scorer = make_gt_scorer(num_arms, best_arm_id=gt_best_arm_id, sigma=gt_sigma)
print(f"[run_recruiter] GT-only mode: {gt_steps} steps, best_arm_id={gt_best_arm_id}, sigma={gt_sigma}")
initial_sets = recruiter.recruit(queue_size)
initial_cases = []
slot_arms = []
for arm_idx, ex_set in initial_sets:
case = []
for d in ex_set:
key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
if key in dataset_idx:
case.append(dataset_idx[key])
else:
for i, s in enumerate(example_dataset):
if s.get("user_id") == str(d.get("user_id")) and s.get("sample_id") == d.get("sample_id"):
case.append(i)
break
if len(case) >= n_way:
initial_cases.append(case[:n_way])
slot_arms.append(arm_idx)
while len(initial_cases) < queue_size and len(initial_sets) > len(initial_cases):
arm_idx, ex_set = initial_sets[len(initial_cases)]
case = []
for d in ex_set:
key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
case.append(dataset_idx.get(key, 0))
if len(case) < n_way:
case.extend([0] * (n_way - len(case)))
initial_cases.append(case[:n_way])
slot_arms.append(arm_idx)
ex_queue = Queue(class_indices, queue_size)
ex_queue._queue = deque(initial_cases[:queue_size], maxlen=queue_size)
for case in ex_queue._queue:
ex_queue._register_stats(case)
slot_arms = slot_arms[:queue_size]
results = []
processed = 0
cumulative_correct = 0
if use_gt_only:
# Loop: query gt_scorer(arm_id) for each slot, update MAB, recruit, update queue
for step in range(gt_steps):
ex_queue.set_current_time(step)
queue_cases = list(ex_queue._queue)
run_results = []
for slot_i in range(len(slot_arms)):
arm_idx = slot_arms[slot_i]
score = gt_scorer(arm_idx)
response_dict = {"ANSWER": None, "REASON": "gt", "CONFIDENCE": 0.5}
run_results.append((arm_idx, response_dict, score))
recruiter.update(run_results)
new_sets = recruiter.recruit(recruit_size)
new_cases = []
new_arm_ids = []
for arm_idx, ex_set in new_sets:
case = []
for d in ex_set:
key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
case.append(dataset_idx.get(key, 0))
if len(case) < n_way:
case.extend([0] * (n_way - len(case)))
new_cases.append(case[:n_way])
new_arm_ids.append(arm_idx)
ex_queue.update_with_recruiter(run_results, new_cases, recruit_size)
scores_with_idx = [(i, run_results[i][2]) for i in range(len(run_results))]
scores_with_idx.sort(key=lambda x: x[1], reverse=True)
to_evict = {scores_with_idx[j][0] for j in range(min(recruit_size, len(scores_with_idx)))}
kept = [i for i in range(len(slot_arms)) if i not in to_evict]
slot_arms = [slot_arms[i] for i in kept] + new_arm_ids[:recruit_size]
results.append({
"sample_idx": step,
"answer": None,
"ground_truth": None,
"is_correct": None,
"gt_step": step,
"queue_arms": list(slot_arms),
"step_rewards": [r[2] for r in run_results],
})
print(f"[GT Summary] Final queue arms: {slot_arms}")
return results
# Model mode: iterate test loader, call LLM, self_certainty, borda_vote
for step, (metadata, signal) in enumerate(test_loader):
test_sample = _ppgbp_sample_to_dict(metadata, signal, int(metadata.get("subject_ID", step)))
ex_queue.set_current_time(processed)
queue_cases = list(ex_queue._queue)
run_results = []
for slot_i, case in enumerate(queue_cases):
examples = [example_dataset[i] for i in case if i < len(example_dataset)]
if len(examples) < n_way:
continue
messages = _build_prompt(system_msg, task_info, classes_info, test_sample, examples)
text, logits = await model_pool.invoke(messages)
score = self_certainty(logits)
parsed = _parse_answer(text, classes_info)
response_dict = {"ANSWER": parsed, "REASON": text[:200], "CONFIDENCE": 0.5}
arm_idx = slot_arms[slot_i] if slot_i < len(slot_arms) else slot_i
run_results.append((arm_idx, response_dict, score))
if not run_results:
processed += 1
continue
answer = borda_vote(run_results, valid_classes=classes_info)
gt = test_sample.get("label", "")
is_correct = answer == gt
cumulative_correct += 1 if is_correct else 0
acc = cumulative_correct / (processed + 1)
recruiter.update(run_results)
new_sets = recruiter.recruit(recruit_size)
new_cases = []
new_arm_ids = []
for arm_idx, ex_set in new_sets:
case = []
for d in ex_set:
key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
case.append(dataset_idx.get(key, 0))
if len(case) < n_way:
case.extend([0] * (n_way - len(case)))
new_cases.append(case[:n_way])
new_arm_ids.append(arm_idx)
ex_queue.update_with_recruiter(run_results, new_cases, recruit_size)
scores_with_idx = [(i, run_results[i][2]) for i in range(len(run_results))]
scores_with_idx.sort(key=lambda x: x[1], reverse=True)
to_evict = {scores_with_idx[j][0] for j in range(min(recruit_size, len(scores_with_idx)))}
kept = [i for i in range(len(slot_arms)) if i not in to_evict]
slot_arms = [slot_arms[i] for i in kept] + new_arm_ids[:recruit_size]
results.append({
"sample_idx": processed,
"answer": answer,
"ground_truth": gt,
"is_correct": is_correct,
"cumulative_accuracy": acc,
})
processed += 1
return results
def _expand_env_in_config(obj: Any) -> Any:
"""Recursively expand $VAR and ${VAR} in config strings."""
if isinstance(obj, dict):
return {k: _expand_env_in_config(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_expand_env_in_config(v) for v in obj]
if isinstance(obj, str):
return os.path.expandvars(obj)
return obj
def main(config_path: str) -> None:
with open(config_path, "r", encoding="utf-8") as f:
config = yaml.safe_load(f)
config = _expand_env_in_config(config)
if not config.get("model_path") and config.get("models"):
config["model_path"] = config["models"][0] if isinstance(config["models"], list) else config["models"]
model_path = config.get("model_path") or ""
if isinstance(model_path, str):
model_path = model_path.strip()
use_gt_only = not model_path
if use_gt_only:
config["use_gt_only"] = True
print("[run_recruiter] model_path is null: using ground-truth scorer (arm_id -> score) for MAB convergence")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
if "log_path" in config:
config["log_path"] = f"{config['log_path']}_{timestamp}"
pool = None
if not use_gt_only:
pool = load_models(
[model_path],
temperature=config.get("temperature", 0.7),
max_new_tokens=config.get("max_new_tokens", 256),
)
data_path = config["data_path"]
train_loader, test_loader = get_loaders(data_path, seed=config.get("seed", 42))
if use_gt_only:
test_loader = None # not used in GT loop
seeds = range(config.get("num_seeds", 1))
all_results = []
for seed in seeds:
res = asyncio.run(run_single_user(config, pool, test_loader, train_loader, seed))
all_results.extend(res)
if use_gt_only:
print(f"GT mode: completed {len(all_results)} steps")
else:
correct = sum(1 for r in all_results if r.get("is_correct"))
total = len(all_results)
acc = correct / total if total else 0
print(f"Accuracy: {correct}/{total} = {acc:.4f}")
log_path = config.get("log_path", "./logs/ppgbp_recruiter")
os.makedirs(log_path, exist_ok=True)
with open(os.path.join(log_path, "results.json"), "w") as f:
json.dump({
"use_gt_only": use_gt_only,
"accuracy": (sum(1 for r in all_results if r.get("is_correct")) / len(all_results)) if all_results and not use_gt_only else None,
"correct": sum(1 for r in all_results if r.get("is_correct")) if not use_gt_only else None,
"total": len(all_results),
"results": all_results,
}, f, indent=2)
if __name__ == "__main__":
Fire(main)