Compare commits
1 commit

| Author | SHA1 | Date |
|---|---|---|
| | 216bd2ecc3 | |

198
.claude/CLAUDE.md
Normal file
@@ -0,0 +1,198 @@
# Project Memory: tsllm_personalization_icl

## Project Overview

Research project on **personalized time-series classification using LLMs with In-Context Learning (ICL)**. The core idea is to use LLMs (via Ollama or HuggingFace) to classify sensor/physiological time-series data (e.g., EEG sleep stages, PPG blood pressure), where ICL examples are selected via different strategies (random vs. similarity-based) to study personalization.

Current branch: `case` (main branch: `main`)

---

## Repository Structure

```
tsllm_personalization_icl/
├── run.py                          # Main entry point (ICL experiment runner)
├── config/
│   └── sleepedf.yaml               # Config: data path, models, selection criteria, log path
├── core/                           # Base pipeline (ICL experiments)
│   ├── model.py                    # Model/AsyncModelPool (Ollama or LangChain init_chat_model)
│   ├── agent.py                    # Base Agent class (memory, logging, JSON parsing)
│   ├── sensing_agent.py            # SensingAgent: ICL classify + evaluate + reflect
│   ├── data_loader.py              # DataLoader: test/example splits + similarity-based selection
│   └── embedding_index.py          # EmbeddingIndex: cosine similarity over Chronos-2 embeddings
├── sc/                             # Self-Consistency (SC) variant pipeline
│   ├── run_sc.py                   # SC experiment runner (main entry: `python -m sc.run_sc config.yaml`)
│   ├── run_confidence_based.py
│   ├── run_consistency_based.py
│   ├── run_sc_queue_random.py
│   ├── run_usc.py
│   ├── core/
│   │   ├── model.py                # HuggingFace CausalLM wrapper (returns text + logits)
│   │   ├── agent.py                # SC base agent
│   │   ├── scagent.py              # SCAgent: single-pass interpret() with REASON/CONFIDENCE/ANSWER
│   │   ├── agent_pool.py           # AgentPool: parallel interpret + majority voting
│   │   ├── example_queue.py        # Queue: priority queue of example sets, updated by confidence
│   │   ├── data_loader.py          # DataLoader + InMemoryDataLoader + PPGBPLoader
│   │   ├── judge_agent.py          # JudgeAgent
│   │   ├── majority_voting.py      # MajorityVoting utilities
│   │   └── model_utils.py
│   ├── analysis/
│   │   └── analyze_sc_results.py
│   ├── preprocess/
│   │   └── shuffle_data.py
│   ├── debug_log.py                # Structured debug logging helpers
│   ├── logger.py
│   ├── hf_api.py
│   └── ollama_test.py
├── analysis/
│   ├── ppgbp_loader.py
│   ├── user_similarity/            # Embedding analysis scripts
│   │   ├── chronos2/               # Chronos-2 embeddings for SleepEDF
│   │   ├── chronos2_ppgbp/         # Chronos-2 embeddings for PPGBP
│   │   ├── labram/                 # LaBraM embeddings
│   │   ├── sbert/
│   │   ├── sbert_metadata/
│   │   └── sbert_metadata_ppgbp/   # SBERT metadata embeddings for PPGBP
│   ├── analyze_data.ipynb
│   └── analyze_preliminary.ipynb
├── preprocess/
│   ├── dhedfreader.py
│   └── preprocess_SleepEDF.py
├── utils/
│   ├── kill_ollamas.sh
│   └── launch_ollamas.sh
└── requirements.txt
```

---

## Key Concepts

### Datasets
- **SleepEDF**: EEG-based sleep stage classification (W, N1, N2, N3, REM).
- **PPGBP**: PPG-based blood pressure prediction.
- Data format: `<data_path>/<user_id>/1/` (train/examples, HuggingFace dataset on disk) and `<data_path>/<user_id>/2/` (test split). Metadata in `info.json` (keys: `task`, `class`, `feature`).
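
The layout above can be resolved with plain path handling. A minimal sketch, assuming nothing beyond the directory convention described in the list: `split_dirs` and `load_info` are hypothetical helpers for illustration, not functions from this repository, and in practice each split directory would be opened with `datasets.load_from_disk`.

```python
import json
import os
from typing import Any, Dict


def split_dirs(data_path: str, user_id: str) -> Dict[str, str]:
    """Return the example ("1") and test ("2") split directories for one user."""
    user_dir = os.path.join(data_path, str(user_id))
    return {
        "examples": os.path.join(user_dir, "1"),  # train split / ICL example pool
        "test": os.path.join(user_dir, "2"),      # held-out test split
    }


def load_info(data_path: str) -> Dict[str, Any]:
    """Read info.json and check that the three documented keys are present."""
    with open(os.path.join(data_path, "info.json"), encoding="utf-8") as f:
        info = json.load(f)
    missing = {"task", "class", "feature"} - set(info)
    if missing:
        raise KeyError(f"info.json is missing keys: {sorted(missing)}")
    return info
```

Failing early on a missing key is a design choice for the sketch; the repository's loaders may handle incomplete metadata differently.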

### ICL Selection Strategies (core/data_loader.py)
- `out_random`: Random examples from OTHER users (cross-user, random)
- `in_random`: Random examples from SAME user (personalized, random)
- `out_similar`: Most similar examples from OTHER users (Chronos-2 embedding similarity)
- `in_similar`: Most similar examples from SAME user (embedding similarity)

Similarity uses cosine similarity over Chronos-2 embeddings, one example per class (balanced).
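
As a rough illustration of balanced, similarity-based selection, the sketch below picks the single most similar candidate per class using cosine similarity. It is a pure-Python stand-in, not the code in `core/embedding_index.py`: the function names and the `(label, embedding)` input shape are assumptions, and real embeddings would come from Chronos-2.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def most_similar_per_class(
    query: Sequence[float],
    candidates: List[Tuple[str, Sequence[float]]],  # (label, embedding) pairs
) -> Dict[str, int]:
    """For each label, return the index of the candidate most similar to query."""
    best: Dict[str, Tuple[float, int]] = {}
    for idx, (label, emb) in enumerate(candidates):
        score = cosine(query, emb)
        if label not in best or score > best[label][0]:
            best[label] = (score, idx)
    return {label: idx for label, (_, idx) in best.items()}
```

For the `in_*` and `out_*` variants, the candidate list would first be restricted to the same user or to other users before this step.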

### Models
- **core/model.py**: Supports Ollama (local/remote via `ollama:url:<host>/<model>`) and any LangChain-supported model (OpenAI, Together, etc.)
- **sc/core/model.py**: HuggingFace CausalLM loaded via `transformers`, returns text + logits. Used for SC experiments.

### Self-Consistency (sc/) Pipeline
The SC pipeline uses a **priority queue of example sets**:
- `Queue` (sc/core/example_queue.py): holds `capacity` example-sets (one set = one example per class). Updated each step by agent confidence scores: the highest-confidence sets are kept, and the lowest are evicted and replaced with a random new set.
- `AgentPool` (sc/core/agent_pool.py): runs one `SCAgent` per queue slot in parallel, aggregates via majority vote, tracks confidence and consistency.
- `SCAgent` (sc/core/scagent.py): single LLM call, returns `{REASON, CONFIDENCE, ANSWER}` JSON.
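
The queue update and the vote can be sketched as follows. This is an illustration of the behaviour described above, not the code in `sc/core/`: `update_queue`, `majority_vote`, the `n_evict` parameter, and the string set IDs are all assumptions made for the example.

```python
from collections import Counter
from typing import Callable, List, Tuple


def majority_vote(answers: List[str]) -> str:
    """Return the most common answer; ties go to the answer seen first."""
    return Counter(answers).most_common(1)[0][0]


def update_queue(
    scored_sets: List[Tuple[float, str]],  # (confidence, example-set id)
    new_set: Callable[[], str],            # draws a fresh random example set
    n_evict: int = 1,
) -> List[str]:
    """Keep the highest-confidence sets and replace the lowest n_evict with new ones."""
    ranked = sorted(scored_sets, key=lambda pair: pair[0], reverse=True)
    kept = [set_id for _, set_id in ranked[: len(ranked) - n_evict]]
    return kept + [new_set() for _ in range(n_evict)]
```

The queue length stays constant across updates, which matches the idea of a fixed `capacity`.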

### Base Agent (core/agent.py)
- Three memory tiers: `long_term_memory`, `short_term_memory`, `volatile_memory`
- `invoke()`: async, appends messages to memory, calls model pool
- `safe_parse_json()` / `safe_parse_json_list()`: robust JSON parsing with cleanup
- Token counting via `tiktoken` (gpt-3.5-turbo encoding as proxy)
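
To illustrate what "robust JSON parsing with cleanup" can involve, here is a minimal sketch. `parse_json_object` is a hypothetical helper, not the repository's `safe_parse_json()`; the cleanup steps (dropping Markdown fence markers, trimming to the outermost braces) are assumptions about typical LLM output.

```python
import json
import re
from typing import Any, Dict, Optional

FENCE = "`" * 3  # Markdown code-fence marker, built this way to keep the example readable


def parse_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Best-effort extraction of one JSON object from an LLM reply."""
    # Drop fence markers together with an optional language tag (e.g. "json").
    cleaned = re.sub(re.escape(FENCE) + r"[a-zA-Z]*", "", text).strip()
    # Keep only the span from the first "{" to the last "}".
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        parsed = json.loads(cleaned[start : end + 1])
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None
```

Returning `None` instead of raising lets a caller decide whether to retry the model call or skip the sample.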

### SensingAgent (core/sensing_agent.py)
Extends Agent. Methods:
- `solve()`: classify a sample with ICL examples → `{REASON, ANSWER}`
- `interpret()`: same as solve but without logging ground truth
- `evaluate()`: evaluate another agent's answer (for multi-agent debate)
- `reflect()`: refine answer based on peer evaluations

---

## Running Experiments

### Basic ICL run
```bash
python run.py run config/sleepedf.yaml
```

### Compare multiple selection criteria
```bash
python run.py compare config/sleepedf.yaml \
  --criteria_list="out_random,in_random,out_similar,in_similar" \
  --embedding_path="./embeddings_full"
```

### Self-Consistency run
```bash
python -m sc.run_sc sc/config/sleepedf_sc.yaml
```

---

## Config (sleepedf.yaml) Key Fields
- `data_path`: root of dataset (contains `info.json` and per-user dirs)
- `log_path`: output directory for results and logs
- `num_seeds`: number of random seeds to run
- `num_examples`: ICL examples per class (default 1)
- `selection_criteria`: `out_random` | `in_random` | `out_similar` | `in_similar`
- `embedding_path`: path to pre-computed Chronos-2 embeddings (required for `*_similar`)
- `models`: list of model specs (Ollama URLs or model IDs)

In SC configs, additional fields:
- `queue_size`: number of example sets to maintain in the priority queue
- `temperature`: LLM sampling temperature
- `max_new_tokens`: max tokens for generation
- `example_pool`: `"out"` (other users) or `"in"` (same user)
- `continuous`: whether queue persists across test samples
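
The constraint that `*_similar` criteria need `embedding_path` can be checked before a run. A hedged sketch: `check_config` is a hypothetical helper operating on an already-parsed config dict, and treating `data_path`, `log_path`, and `models` as required is an assumption rather than documented behaviour. It returns a list of problems, which is empty when the config passes.

```python
from typing import Any, Dict, List

VALID_CRITERIA = {"out_random", "in_random", "out_similar", "in_similar"}


def check_config(cfg: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in a parsed config dict (empty if none)."""
    problems = []
    # Assumption: these three fields are needed by every run.
    for key in ("data_path", "log_path", "models"):
        if not cfg.get(key):
            problems.append(f"missing required field: {key}")
    criterion = cfg.get("selection_criteria")
    if criterion not in VALID_CRITERIA:
        problems.append(f"unknown selection_criteria: {criterion!r}")
    elif criterion.endswith("_similar") and not cfg.get("embedding_path"):
        # Documented above: embedding_path is required for *_similar.
        problems.append(f"{criterion} requires embedding_path")
    return problems
```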

---

## Data Format Details

Each HuggingFace dataset sample:
```python
{
    "user_id": str,
    "session_id": str,   # "1" = train, "2" = test
    "idx": int,
    "label": str,        # class name from info.json["class"]
    "features": dict,    # str -> float/str (sensor features formatted for prompt)
    "data": dict,        # optional raw data
}
```

`info.json`:
```json
{
  "task": "Sleep stage classification from EEG",
  "class": {"W": "Wake", "N1": "...", ...},
  "feature": "EEG channel description for the LLM..."
}
```
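
A small check along these lines can catch schema mismatches before samples reach the prompt. `validate_sample` is a hypothetical helper, and requiring a non-empty `features` dict is an assumption; the label rule follows the comment in the sample schema above.

```python
from typing import Any, Dict


def validate_sample(sample: Dict[str, Any], info: Dict[str, Any]) -> None:
    """Raise ValueError if a sample does not match the documented schema."""
    label = sample.get("label")
    if label not in info["class"]:
        raise ValueError(f"label {label!r} is not a key of info['class']")
    features = sample.get("features")
    # Assumption: an empty features dict is treated as an error.
    if not isinstance(features, dict) or not features:
        raise ValueError("features must be a non-empty dict")
    if not all(isinstance(name, str) for name in features):
        raise ValueError("feature names must be strings")
```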

---

## PPGBP Dataset (sc/core/data_loader.py)
`PPGBPLoader`: reads xlsx metadata + signal `.txt` files from `0_subject/`. 80/20 train/test split at subject level. Metadata embeddings (SBERT) auto-generated and cached per subject as `.npy` files under `PPGBP_METADATA_EMBEDDINGS_ROOT` (set via env var or `.env`).
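
A subject-level split keeps all data from one subject on the same side of the split. The sketch below shows one way to do that; it is not `PPGBPLoader._get_split_ids()`, whose implementation is not shown here. `split_subjects` is hypothetical, the shuffle-then-cut approach is an assumption, and it uses the standard-library `random.Random` where the loader uses a NumPy generator.

```python
import random
from typing import List, Sequence, Tuple


def split_subjects(
    subject_ids: Sequence[int], train_fraction: float = 0.8, seed: int = 42
) -> Tuple[List[int], List[int]]:
    """Shuffle subject IDs with a fixed seed and cut them into train/test lists."""
    ids = sorted(subject_ids)         # start from a deterministic order
    random.Random(seed).shuffle(ids)  # seeded shuffle, reproducible across runs
    cut = int(len(ids) * train_fraction)
    return ids[:cut], ids[cut:]
```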

---

## Dependencies (requirements.txt)
- `langchain`, `langchain_ollama`, `langchain_openai`, `langchain_together`: LLM backends
- `datasets`: HuggingFace datasets (on-disk storage)
- `chronos`: Chronos-2 time-series embeddings
- `sentence_transformers`: SBERT for metadata embeddings
- `transformers`, `torch`: HuggingFace model loading (SC pipeline)
- `tiktoken`: token counting
- `fire`: CLI argument parsing
- `mne`, `neurokit2`: EEG/biosignal preprocessing
- `numpy`, `pandas`, `scikit_learn`, `scipy`, `matplotlib`

---

## Notes / Patterns
- Experiments sample every 10th test item (`if idx % 10 != 0: continue` in `run.py`).
- SC runner currently hardcoded to first user only for testing (`users[:1]`).
- Log structure: `<log_path>/<user_id>/<sample_idx>/<seed>/` with `summary.txt`, `log.txt`, `tokens.txt`.
- Embedding index uses cosine similarity (L2-normalized dot product), filtered by user and session.
- `DataLoader` in `sc/` is simpler (no similarity selection); similarity lives in `core/`.
- `InMemoryDataLoader` and `prepare_dataset_for_sc()` allow integrating new datasets without writing to disk.
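
The sampling rule and log layout from the notes above, written out as a sketch. `every_nth` and `log_dir` are hypothetical helper names; only the `idx % 10` filter and the directory pattern come from the notes.

```python
import os
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def every_nth(items: Iterable[T], step: int = 10) -> Iterator[Tuple[int, T]]:
    """Yield (idx, item) only for indices divisible by step."""
    for idx, item in enumerate(items):
        if idx % step != 0:
            continue
        yield idx, item


def log_dir(log_path: str, user_id: str, sample_idx: int, seed: int) -> str:
    """Build <log_path>/<user_id>/<sample_idx>/<seed> as described in the notes."""
    return os.path.join(log_path, str(user_id), str(sample_idx), str(seed))
```

Keeping the original index in the yielded pair means the log directory can still be named after the sample's position in the full test set.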

0
.claude/plan.md
Normal file

@@ -630,10 +630,10 @@ class SBERT_Metadata:
 def load_embeddings(embedding_dir: str) -> Dataset:
     """
     Load saved embeddings dataset from disk.
-
+
     Args:
         embedding_dir: Directory path containing the saved HuggingFace dataset
-
+
     Returns:
         HuggingFace Dataset with embeddings and metadata
     """
@@ -642,6 +642,94 @@ class SBERT_Metadata:
     return dataset


+# =============================================================================
+# Save metadata embeddings by user ID (for PPGBP dataloader)
+# =============================================================================
+
+def _normalize_sex_for_sbert(sex: Any) -> Optional[int]:
+    """Convert metadata sex to int for SBERT textualize (1=Female, 0=Male)."""
+    if sex is None or (isinstance(sex, float) and np.isnan(sex)):
+        return None
+    s = str(sex).strip().upper()
+    if s in ("F", "1"):
+        return 1
+    if s in ("M", "0"):
+        return 0
+    return None
+
+
+def _normalize_hypertension_for_sbert(ht: Any) -> Optional[int]:
+    """Convert metadata hypertension to int if numeric."""
+    if ht is None or (isinstance(ht, float) and np.isnan(ht)):
+        return None
+    try:
+        return int(float(ht))
+    except (ValueError, TypeError):
+        return None
+
+
+def save_metadata_embeddings_by_userid(
+    data_root: str,
+    subject_path: str = SUBJECT_PATH,
+    embeddings_root: Optional[str] = None,
+) -> None:
+    """
+    Compute SBERT metadata embeddings for each user and save one file per user
+    under embeddings_root as {user_id}.npy (e.g. 100.npy).
+
+    Uses PPGBP_METADATA_EMBEDDINGS_ROOT from environment if embeddings_root
+    is not provided. Load .env from project root if present so the env var is set.
+
+    Args:
+        data_root: Root directory containing 0_subject/ with PPG-BP data.
+        subject_path: Path to PPGBP metadata xlsx (subject_ID, age, sex, etc.).
+        embeddings_root: Directory to write {user_id}.npy files. Defaults to
+            os.environ["PPGBP_METADATA_EMBEDDINGS_ROOT"].
+    """
+    try:
+        from dotenv import load_dotenv
+        # Load .env from repo root (parent of analysis/)
+        _repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
+        load_dotenv(os.path.join(_repo_root, ".env"))
+    except ImportError:
+        pass
+    root = embeddings_root or os.environ.get("PPGBP_METADATA_EMBEDDINGS_ROOT")
+    if not root:
+        raise ValueError(
+            "embeddings_root not provided and PPGBP_METADATA_EMBEDDINGS_ROOT not set. "
+            "Set the env var or pass embeddings_root."
+        )
+    os.makedirs(root, exist_ok=True)
+
+    subject_metadata = load_subject_metadata(subject_path)
+    embedder = SBERT_Metadata()
+
+    user_ids = sorted(subject_metadata.keys())
+    if not user_ids:
+        raise ValueError(f"No subjects found in {subject_path}")
+
+    for user_id in user_ids:
+        meta = subject_metadata[user_id]
+        sex = _normalize_sex_for_sbert(meta.get("sex"))
+        age = meta.get("age")
+        height = meta.get("height")
+        weight = meta.get("weight")
+        sbp = meta.get("sbp")
+        dbp = meta.get("dbp")
+        hr = meta.get("hr")
+        bmi = meta.get("bmi")
+        hypertension = _normalize_hypertension_for_sbert(meta.get("hypertension"))
+        if age is not None and isinstance(age, float) and np.isnan(age):
+            age = None
+
+        emb = embedder.compute_embedding_from_metadata(
+            [sex], [age], [height], [weight], [sbp], [dbp], [hr], [bmi], [hypertension]
+        )
+        path = os.path.join(root, f"{user_id}.npy")
+        np.save(path, emb[0])
+    print(f"[DONE] Saved {len(user_ids)} metadata embeddings under {root}")
+
+
 # =============================================================================
 # Data Loading Utilities
 # =============================================================================

1029
sample_case.py
Normal file
File diff suppressed because it is too large

32
sc/config/ppgbp_recruiter.yaml
Normal file
@@ -0,0 +1,32 @@
# PPGBP Recruiter (MAB-based example set selection) config
# Usage: python -m sc.run_recruiter sc/config/ppgbp_recruiter.yaml
# When model_path is null: use ground-truth scorer (arm_id -> score), no LLM; run gt_steps to test MAB convergence.


data_path: $PPGBP_DATA_ROOT
feature_root: $DATA_INDEX_ROOT/metadata_onehot  # index.json + embeddings.npy
log_path: ./logs/ppgbp_recruiter
model_path: null  # set to a HF model path to use real LLM + self_certainty

queue_size: 5
recruit_size: 2
n_way: 2
k_shot: 3
num_arms: 50

# When model_path is null: deterministic reward = Gaussian over arm index, peak at gt_best_arm_id
gt_steps: 200
gt_best_arm_id: 20
gt_sigma: 5.0

n_das: 5
delta: 0.05
epsilon: 0.11

temperature: 0.7
max_new_tokens: 256
num_seeds: 1
seed: 42

task_info: "Classify the subject from PPG-BP metadata."
classes_info: ["Normal", "Prehypertension", "Stage 1 hypertension", "Stage 2 hypertension" ]

@@ -7,17 +7,44 @@ Expects the following directory structure:
     <path>/
         info.json           # dataset metadata (task, classes, features)
         <user_id>/
-            1/              # examples (train split)
-            2/              # test split
+            1/              # examples (train split) - HuggingFace dataset on disk
+            2/              # test split - HuggingFace dataset on disk
         <other_user>/
             1/
             ...
+
+--------------------------------------------------------------------------------
+Integrating a new dataset
+--------------------------------------------------------------------------------
+
+The pipeline expects:
+
+1) Directory layout (for DataLoader / ShuffledDataLoader):
+   - <data_path>/info.json
+   - <data_path>/<user_id>/1/ and <data_path>/<user_id>/2/ (HuggingFace datasets saved with datasets.Dataset.save_to_disk)
+
+2) info.json schema:
+   - "task": str (e.g. "Sleep stage classification from EEG")
+   - "class": dict mapping label -> description (e.g. {"W": "Wake", "N1": "Non-REM 1", ...})
+   - "feature": str (sensor/feature description for the LLM)
+
+3) Each sample (in test and example datasets) must be a dict with:
+   - "label": str (class name, must be one of the keys in info.json["class"])
+   - "features": dict of str -> value (e.g. {"Fpz-Cz": "0.12, -0.05, ...", "Pz-Oz": "..."}); values are formatted for the prompt
+
+Option A - Convert your dataset to this format:
+   - Create info.json and per-user dirs; build HuggingFace Dataset with columns "label" and "features", then save_to_disk to <user_id>/1 and <user_id>/2.
+   - Use prepare_dataset_for_sc() in this module to write from in-memory lists.
+
+Option B - Use in-memory data without changing directory layout:
+   - Use InMemoryDataLoader(metadata, test_samples, example_samples) which implements the same interface as DataLoader.
+   - Use it in run_sc by constructing it and passing to run_single_task; for run_confidence_based / run_consistency_based / run_sc_queue_random you currently need the directory layout (or extend those runners to accept a loader instance).
 """

 import os
 import json
 from glob import glob
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Union

 import datasets

@@ -26,10 +53,20 @@ class DataLoader:
     """Loads a target user's test set and ICL examples from other users."""

     def __init__(
-        self, path: str, user_id: str, shuffle: bool = False, seed: int = 0
+        self,
+        path: Optional[str] = None,
+        user_id: Optional[str] = None,
+        shuffle: bool = False,
+        seed: int = 0,
+        data_path: Optional[str] = None,
+        **kwargs: Any,
     ) -> None:
-        metadata_path = os.path.join(path, "info.json")
-        target_user_path = os.path.join(path, user_id, "2")
+        # Support both path= and data_path= for compatibility with run_sc
+        root = path or data_path
+        if root is None or user_id is None:
+            raise ValueError("path (or data_path) and user_id are required")
+        metadata_path = os.path.join(root, "info.json")
+        target_user_path = os.path.join(root, user_id, "2")

         if not os.path.exists(metadata_path) or not os.path.exists(target_user_path):
             return
@@ -44,7 +81,7 @@ class DataLoader:
         # Example set: train splits from all *other* users
         example_paths = [
             os.path.join(user_path, "1")
-            for user_path in glob(os.path.join(path, "*"))
+            for user_path in glob(os.path.join(root, "*"))
             if os.path.isdir(user_path) and os.path.basename(user_path) != user_id
         ]
         example_parts = [datasets.load_from_disk(p) for p in example_paths]
@@ -87,6 +124,88 @@ class DataLoader:
         return list(self.metadata["class"].keys())


+class InMemoryDataLoader:
+    """
+    DataLoader interface backed by in-memory lists. Use this to run the SC
+    pipeline on a new dataset without writing the directory layout to disk.
+
+    Implements the same interface as DataLoader: __len__, __getitem__, __iter__,
+    get_examples(), get_metadata(), get_sensor_info(), get_task_info(), get_classes_info().
+    """
+
+    def __init__(
+        self,
+        metadata: Dict[str, Any],
+        test_samples: List[Dict[str, Any]],
+        example_samples: List[Dict[str, Any]],
+    ) -> None:
+        """
+        Args:
+            metadata: Must have "task", "class" (dict label -> description), "feature" (str).
+            test_samples: List of dicts with "label" and "features" (dict).
+            example_samples: Same schema; used as ICL example pool.
+        """
+        self.metadata = metadata
+        self.test_dataset = datasets.Dataset.from_list(test_samples)
+        self.example_dataset = datasets.Dataset.from_list(example_samples)
+
+    def __len__(self) -> int:
+        return len(self.test_dataset)
+
+    def __getitem__(self, idx: int) -> Dict[str, Any]:
+        return self.test_dataset[idx]
+
+    def __iter__(self):
+        yield from self.test_dataset
+
+    def get_examples(self) -> datasets.Dataset:
+        return self.example_dataset
+
+    def get_metadata(self) -> Dict[str, Any]:
+        return self.metadata
+
+    def get_sensor_info(self) -> str:
+        return self.metadata["feature"]
+
+    def get_task_info(self) -> str:
+        classes_info = "\n".join(
+            f" - {k}: {v}" for k, v in self.metadata["class"].items()
+        )
+        return f"**Task**:\n{self.metadata['task']}\n\n**Classes**:\n{classes_info}"
+
+    def get_classes_info(self) -> List[str]:
+        return list(self.metadata["class"].keys())
+
+
+def prepare_dataset_for_sc(
+    output_path: str,
+    metadata: Dict[str, Any],
+    per_user_splits: Dict[str, Dict[str, List[Dict[str, Any]]]],
+) -> None:
+    """
+    Write a new dataset to the directory layout expected by DataLoader and
+    ShuffledDataLoader. Call this once; then point config data_path to output_path.
+
+    Args:
+        output_path: Root directory to create (e.g. /path/to/MyDataset).
+        metadata: info.json contents: "task", "class" (dict), "feature" (str).
+        per_user_splits: { user_id: { "train": [samples], "test": [samples] } }.
+            Each sample is a dict with "label" and "features".
+    """
+    os.makedirs(output_path, exist_ok=True)
+    with open(os.path.join(output_path, "info.json"), "w", encoding="utf-8") as f:
+        json.dump(metadata, f, indent=2, ensure_ascii=False)
+
+    for user_id, splits in per_user_splits.items():
+        user_dir = os.path.join(output_path, str(user_id))
+        os.makedirs(os.path.join(user_dir, "1"), exist_ok=True)
+        os.makedirs(os.path.join(user_dir, "2"), exist_ok=True)
+        train_ds = datasets.Dataset.from_list(splits.get("train", []))
+        test_ds = datasets.Dataset.from_list(splits.get("test", []))
+        train_ds.save_to_disk(os.path.join(user_dir, "1"))
+        test_ds.save_to_disk(os.path.join(user_dir, "2"))
+
+
 # PPGBP data loader

 """
@@ -94,13 +213,160 @@ PPGBP Dataset Loader

 A data loader/iterator for the PPG-BP dataset that reads xlsx metadata
 and corresponding signal text files, with optional batching support.
+
+Metadata embeddings live under PPGBP_METADATA_EMBEDDINGS_ROOT (one .npy per
+subject ID). If missing, they are generated automatically in PPGBPLoader.__init__
+using SBERT. Set PPGBP_METADATA_EMBEDDINGS_ROOT in .env or environment.
 """

 import os
+import sys
 import numpy as np
 import pandas as pd
 from typing import Iterator, Tuple, Dict, List, Optional, Union

+try:
+    from dotenv import load_dotenv
+    _ppgbp_env_loaded = False
+    def _ensure_ppgbp_env():
+        global _ppgbp_env_loaded
+        if not _ppgbp_env_loaded:
+            load_dotenv()
+            _ppgbp_env_loaded = True
+except ImportError:
+    def _ensure_ppgbp_env():
+        pass
+
+
+def _get_ppgbp_embedding_root() -> str:
+    _ensure_ppgbp_env()
+    root = os.environ.get("PPGBP_METADATA_EMBEDDINGS_ROOT")
+    if not root:
+        raise ValueError(
+            "PPGBP_METADATA_EMBEDDINGS_ROOT is not set. Set it in .env or environment "
+            "(e.g. PPGBP_METADATA_EMBEDDINGS_ROOT=/path/to/metadata_embeddings)."
+        )
+    return root
+
+
+def _get_ppgbp_embedding_root_optional() -> Optional[str]:
+    """Return PPGBP_METADATA_EMBEDDINGS_ROOT if set, else None. Used when embeddings are not required (e.g. recruiter with feature_root)."""
+    _ensure_ppgbp_env()
+    root = os.environ.get("PPGBP_METADATA_EMBEDDINGS_ROOT") or ""
+    return root.strip() or None
+
+
+def _get_missing_embedding_ids(embedding_root: str, subject_ids: np.ndarray) -> Optional[List[int]]:
+    """
+    Return list of subject_ids that do not have a .npy file in embedding_root.
+    Return None if embedding_root is not a directory (so caller can create and generate).
+    """
+    if not os.path.isdir(embedding_root):
+        return None
+    missing = []
+    for sid in subject_ids:
+        path = os.path.join(embedding_root, f"{int(sid)}.npy")
+        if not os.path.isfile(path):
+            missing.append(int(sid))
+    return missing if missing else []
+
+
+def _check_embedding_root_populated(embedding_root: str, subject_ids: np.ndarray) -> None:
+    """Raise if embedding_root is missing or does not contain .npy for every subject_id."""
+    missing_or_none = _get_missing_embedding_ids(embedding_root, subject_ids)
+    if missing_or_none is None:
+        raise FileNotFoundError(
+            f"Metadata embeddings root is not a directory or does not exist: {embedding_root}."
+        )
+    if missing_or_none:
+        raise FileNotFoundError(
+            f"Metadata embeddings root {embedding_root} is missing .npy files for subject IDs: "
+            f"{missing_or_none[:10]}{'...' if len(missing_or_none) > 10 else ''}."
+        )
+
+
+def _generate_and_save_ppgbp_metadata_embeddings(
+    embedding_root: str,
+    subject_path: str,
+    subject_ids: np.ndarray,
+) -> None:
+    """
+    Generate SBERT metadata embeddings for the given subject IDs and save as
+    {subject_id}.npy under embedding_root. Uses tqdm for progress.
+    """
+    try:
+        from tqdm import tqdm
+    except ImportError:
+        tqdm = lambda x, **kw: x  # noqa: E731
+
+    # Lazy import to keep SBERT/gen_plot deps out of normal import path
+    _repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
+    if _repo_root not in sys.path:
+        sys.path.insert(0, _repo_root)
+    from analysis.user_similarity.sbert_metadata_ppgbp.gen_plot import (
+        load_subject_metadata,
+        SBERT_Metadata,
+        _normalize_sex_for_sbert,
+        _normalize_hypertension_for_sbert,
+    )
+
+    os.makedirs(embedding_root, exist_ok=True)
+    subject_metadata = load_subject_metadata(subject_path)
+    embedder = SBERT_Metadata()
+
+    ids_to_generate = [int(sid) for sid in subject_ids]
+    for user_id in tqdm(ids_to_generate, desc="PPGBP metadata embeddings", unit="subject"):
+        user_id_str = str(user_id)
+        meta = subject_metadata.get(user_id_str, {})
+        sex = _normalize_sex_for_sbert(meta.get("sex"))
+        age = meta.get("age")
+        height = meta.get("height")
+        weight = meta.get("weight")
+        sbp = meta.get("sbp")
+        dbp = meta.get("dbp")
+        hr = meta.get("hr")
+        bmi = meta.get("bmi")
+        hypertension = _normalize_hypertension_for_sbert(meta.get("hypertension"))
+        if age is not None and isinstance(age, float) and np.isnan(age):
+            age = None
+
+        emb = embedder.compute_embedding_from_metadata(
+            [sex], [age], [height], [weight], [sbp], [dbp], [hr], [bmi], [hypertension]
+        )
+        path = os.path.join(embedding_root, f"{user_id}.npy")
+        np.save(path, emb[0])
+
+
+def load_ppgbp_metadata_embedding(embedding_root: str, subject_id: int) -> np.ndarray:
+    """Load a single metadata embedding for subject_id from embedding_root (utility)."""
+    path = os.path.join(embedding_root, f"{int(subject_id)}.npy")
+    if not os.path.isfile(path):
+        raise FileNotFoundError(f"Metadata embedding not found: {path}")
+    return np.load(path)
+
+
+def save_ppgbp_metadata_embeddings(
+    embedding_root: str,
+    subject_path: str,
+    subject_ids: Optional[np.ndarray] = None,
+) -> None:
+    """
+    Generate and save PPGBP metadata embeddings under embedding_root (one .npy per subject).
+    If subject_ids is None, uses all subject IDs present in the metadata file at subject_path.
+    Uses tqdm for progress.
+    """
+    if subject_ids is not None and len(subject_ids) == 0:
+        return
+    # If subject_ids is None we need to get all from metadata; _generate_* expects an array
+    if subject_ids is None:
+        _repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
+        if _repo_root not in sys.path:
+            sys.path.insert(0, _repo_root)
+        from analysis.user_similarity.sbert_metadata_ppgbp.gen_plot import load_subject_metadata
+        subject_metadata = load_subject_metadata(subject_path)
+        subject_ids = np.array(sorted(int(k) for k in subject_metadata.keys()))
+    _generate_and_save_ppgbp_metadata_embeddings(embedding_root, subject_path, subject_ids)
+

 class PPGBPLoader:
     """
@@ -128,7 +394,8 @@ class PPGBPLoader:
         batch_size: Optional[int] = None,
         shuffle: bool = False,
         num_segments: int = 3,
-        seed: int = 42
+        seed: int = 42,
+        return_metadata_embeddings: bool = False,
     ):
         self.base_dir = base_dir
         self.split = split
@@ -136,6 +403,7 @@ class PPGBPLoader:
         self.shuffle = shuffle
         self.num_segments = num_segments
         self.seed = seed
+        self.return_metadata_embeddings = return_metadata_embeddings
         self.rng = np.random.default_rng(seed)

         self.signal_dir = os.path.join(base_dir, "0_subject")
@@ -147,6 +415,15 @@ class PPGBPLoader:
         self.subject_ids = self._get_split_ids()
         self.metadata = self.metadata[self.metadata['subject_ID'].isin(self.subject_ids)]

+        self._embedding_root = _get_ppgbp_embedding_root_optional()
+        if self._embedding_root is not None:
+            missing = _get_missing_embedding_ids(self._embedding_root, self._all_subject_ids)
+            if len(self._all_subject_ids) > 0 and (missing is None or len(missing) > 0):
+                _generate_and_save_ppgbp_metadata_embeddings(
+                    self._embedding_root, self.xlsx_path, self._all_subject_ids
+                )
+            _check_embedding_root_populated(self._embedding_root, self.subject_ids)
+
         self._current_idx = 0
         self._indices = np.arange(len(self.subject_ids))

@@ -229,6 +506,9 @@ class PPGBPLoader:
         row = self.metadata[self.metadata['subject_ID'] == subject_id].iloc[0]
         return row.to_dict()

+    def _load_metadata_embedding(self, subject_id: int) -> np.ndarray:
+        return load_ppgbp_metadata_embedding(self._embedding_root, subject_id)
+
     def __len__(self) -> int:
         return len(self.subject_ids)

@@ -238,7 +518,12 @@ class PPGBPLoader:
             self.rng.shuffle(self._indices)
         return self

-    def __next__(self) -> Union[Tuple[Dict, np.ndarray], Tuple[List[Dict], List[np.ndarray]]]:
+    def __next__(self) -> Union[
+        Tuple[Dict, np.ndarray],
+        Tuple[Dict, np.ndarray, np.ndarray],
+        Tuple[List[Dict], List[np.ndarray]],
+        Tuple[List[Dict], List[np.ndarray], List[np.ndarray]],
+    ]:
         if self._current_idx >= len(self.subject_ids):
             raise StopIteration

@@ -248,6 +533,8 @@ class PPGBPLoader:
             metadata = self._get_metadata_dict(subject_id)
             signal = self._load_signal(subject_id)
             self._current_idx += 1
+            if self.return_metadata_embeddings and self._embedding_root is not None:
+                return metadata, signal, self._load_metadata_embedding(subject_id)
             return metadata, signal
         else:
             end_idx = min(self._current_idx + self.batch_size, len(self.subject_ids))
@@ -255,28 +542,37 @@ class PPGBPLoader:

             metadata_list = []
             signal_list = []
+            embedding_list = [] if (self.return_metadata_embeddings and self._embedding_root is not None) else None

             for idx in batch_indices:
                 subject_id = self.subject_ids[idx]
                 metadata_list.append(self._get_metadata_dict(subject_id))
                 signal_list.append(self._load_signal(subject_id))
+                if self.return_metadata_embeddings and self._embedding_root is not None:
+                    embedding_list.append(self._load_metadata_embedding(subject_id))

             self._current_idx = end_idx
+            if embedding_list is not None:
+                return metadata_list, signal_list, embedding_list
             return metadata_list, signal_list

-    def __getitem__(self, idx: int) -> Tuple[Dict, np.ndarray]:
+    def __getitem__(self, idx: int) -> Union[Tuple[Dict, np.ndarray], Tuple[Dict, np.ndarray, np.ndarray]]:
         if idx < 0 or idx >= len(self.subject_ids):
             raise IndexError(f"Index {idx} out of range")
         subject_id = self.subject_ids[idx]
         metadata = self._get_metadata_dict(subject_id)
         signal = self._load_signal(subject_id)
+        if self.return_metadata_embeddings and self._embedding_root is not None:
+            return metadata, signal, self._load_metadata_embedding(subject_id)
         return metadata, signal

-    def get_by_subject_id(self, subject_id: int) -> Tuple[Dict, np.ndarray]:
+    def get_by_subject_id(self, subject_id: int) -> Union[Tuple[Dict, np.ndarray], Tuple[Dict, np.ndarray, np.ndarray]]:
         if subject_id not in self.subject_ids:
             raise ValueError(f"Subject ID {subject_id} not found in this split.")
         metadata = self._get_metadata_dict(subject_id)
         signal = self._load_signal(subject_id)
+        if self.return_metadata_embeddings and self._embedding_root is not None:
+            return metadata, signal, self._load_metadata_embedding(subject_id)
         return metadata, signal

     def reset(self):
@@ -284,7 +580,9 @@ class PPGBPLoader:
         if self.shuffle:
             self.rng.shuffle(self._indices)

-    def iter_batches(self, batch_size: int) -> Iterator[Tuple[List[Dict], List[np.ndarray]]]:
|
||||
def iter_batches(
|
||||
self, batch_size: int
|
||||
) -> Iterator[Union[Tuple[List[Dict], List[np.ndarray]], Tuple[List[Dict], List[np.ndarray], List[np.ndarray]]]]:
|
||||
indices = np.arange(len(self.subject_ids))
|
||||
if self.shuffle:
|
||||
self.rng.shuffle(indices)
|
||||
@@ -295,13 +593,19 @@ class PPGBPLoader:
|
||||
|
||||
metadata_list = []
|
||||
signal_list = []
|
||||
embedding_list = [] if (self.return_metadata_embeddings and self._embedding_root is not None) else None
|
||||
|
||||
for idx in batch_indices:
|
||||
subject_id = self.subject_ids[idx]
|
||||
metadata_list.append(self._get_metadata_dict(subject_id))
|
||||
signal_list.append(self._load_signal(subject_id))
|
||||
if self.return_metadata_embeddings and self._embedding_root is not None:
|
||||
embedding_list.append(self._load_metadata_embedding(subject_id))
|
||||
|
||||
yield metadata_list, signal_list
|
||||
if embedding_list is not None:
|
||||
yield metadata_list, signal_list, embedding_list
|
||||
else:
|
||||
yield metadata_list, signal_list
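With this change the loader returns either a 2-tuple or a 3-tuple, depending on whether metadata embeddings are enabled, so callers have to handle both shapes. The following is a minimal sketch of one way to normalize them. `unpack_batch` is a hypothetical helper that is not part of this diff, and plain lists stand in for the `np.ndarray` signals and embeddings.

```python
from typing import Any, List, Optional, Sequence, Tuple


def unpack_batch(batch: Sequence[Any]) -> Tuple[List[Any], List[Any], Optional[List[Any]]]:
    """Normalize a loader batch to (metadata, signals, embeddings_or_None).

    Hypothetical helper: accepts either (metadata, signals) or
    (metadata, signals, embeddings), the two shapes iter_batches can yield.
    """
    if len(batch) == 3:
        metadata, signals, embeddings = batch
        return list(metadata), list(signals), list(embeddings)
    if len(batch) == 2:
        metadata, signals = batch
        return list(metadata), list(signals), None
    raise ValueError(f"Expected 2 or 3 elements, got {len(batch)}")


# Plain lists stand in for np.ndarray signals and embeddings.
without_emb = unpack_batch(([{"subject_ID": 2}], [[0.1, 0.2]]))
with_emb = unpack_batch(([{"subject_ID": 2}], [[0.1, 0.2]], [[1.0, 0.0]]))
```

Normalizing once at the call site avoids repeating the `len(...) == 3` check in every consumer.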
|
||||
|
||||
|
||||
def get_loaders(
|
||||
|
||||
@@ -151,4 +151,38 @@ class Queue:
|
||||
"avg_survival": sum(durations) / len(durations),
|
||||
"max_survival": max(durations),
|
||||
"avg_usage": sum(usages) / len(usages) if usages else 0
|
||||
}
|
||||
}
|
||||
|
||||
    def update_with_recruiter(self, results, new_example_sets, L: int = 1):
        """
        Evict the L sets with the highest (worst) self-certainty score from results and push L new sets.
        results: list of (case, response_dict_or_text, self_certainty_score) in queue order.
        new_example_sets: list of L new cases (each case = list of indices, same as queue items).
        Lower self_certainty = better; we evict the L with highest (worst) score.
        """
        if not results or L <= 0:
            return
        current = list(self._queue)
        if len(current) == 0:
            for case in new_example_sets[: self._queue.maxlen]:
                self._queue.append(list(case))
                self._register_stats(case)
            return
        scores = []
        for i, item in enumerate(current):
            # Entries without a score default to -1.0, so they are evicted last.
            sc = -1.0
            if i < len(results) and len(results[i]) >= 3:
                sc = float(results[i][2])
            scores.append((i, sc))
        scores.sort(key=lambda x: x[1], reverse=True)
        to_evict_idx = {scores[j][0] for j in range(min(L, len(scores)))}
        new_queue = [current[i] for i in range(len(current)) if i not in to_evict_idx]
        for idx in sorted(to_evict_idx, reverse=True):
            if idx < len(current):
                self._record_eviction_stats(current[idx])
        for case in new_example_sets[:L]:
            if len(new_queue) >= self._queue.maxlen:
                break
            new_queue.append(list(case))
            self._register_stats(case)
        self._queue = deque(new_queue, maxlen=self._queue.maxlen)
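The eviction rule in `update_with_recruiter` can be tried in isolation. The sketch below is a simplified stand-in, not the `Queue` class itself: `evict_and_refill` is a hypothetical function, and the stats bookkeeping (`_register_stats`, `_record_eviction_stats`) is left out.

```python
from collections import deque


def evict_and_refill(queue, scores, new_sets, n_evict=1):
    """Drop the n_evict entries with the highest score, then append new sets.

    queue: deque of example sets; scores: one float per entry, in queue order.
    Hypothetical helper for illustration only.
    """
    ranked = sorted(range(len(queue)), key=lambda i: scores[i], reverse=True)
    evict = set(ranked[:n_evict])
    kept = [item for i, item in enumerate(queue) if i not in evict]
    for case in new_sets[:n_evict]:
        if queue.maxlen is not None and len(kept) >= queue.maxlen:
            break
        kept.append(list(case))
    return deque(kept, maxlen=queue.maxlen)


q = deque([[0, 1], [2, 3], [4, 5]], maxlen=3)
updated = evict_and_refill(q, scores=[0.2, 0.9, 0.1], new_sets=[[6, 7]], n_evict=1)
print(list(updated))  # [[0, 1], [4, 5], [6, 7]]
```

The entry with score 0.9 is dropped and the new set is appended at the end, so surviving entries keep their relative order.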
|
||||
487
sc/core/recruiter_agent.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""
|
||||
RecruiterAgent: MAB-based example set selection for personalized ICL.

Components:
  - self_certainty(logits): KL-div of LLM output from uniform (reward signal)
  - borda_vote(results): aggregate answers across example sets
  - FeatureIndex: loads index.json + embeddings.npy for O(1) lookup
  - MABSelector: contextual bandit adapted from CASE/sample_case.py top_m_arm
  - ArmPool: manages candidate example sets and feature vectors
  - RecruiterAgent: orchestrates everything

Feature file schema:
  <feature_root>/index.json + <feature_root>/embeddings.npy

  index.json:
    {
      "version": "1.0",
      "embedding_type": "sbert_metadata",
      "embedding_dim": 384,
      "samples": [
        {"row": 0, "user_id": "001", "sample_id": 42, "label": "Hypertension", "session": "1"},
        ...
      ]
    }
  embeddings.npy: float32 array shape (N_samples, embedding_dim)
  Switch embedding type by pointing to a different feature_root directory.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import collections, json, math, os, random
|
||||
from random import sample as random_sample
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
_TORCH = True
|
||||
except ImportError:
|
||||
_TORCH = False
|
||||
|
||||
|
||||
# ── Utilities ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def self_certainty(logits) -> float:
    """KL divergence KL(U || softmax(logits)) from uniform, averaged over positions.

    The value is 0 for a uniform distribution and grows as the distribution sharpens.
    Callers in this module treat lower as better and negate the score for the MAB.
    Returns 0.0 when logits is None or torch is not installed.
    """
    if logits is None or not _TORCH:
        return 0.0
    logits = logits.squeeze(0).float()
    log_probs = F.log_softmax(logits, dim=-1)
    score = (-1.0 / logits.shape[-1]) * log_probs.sum(dim=-1) - math.log(logits.shape[-1])
    return score.mean().item()
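The expression in `self_certainty` is the KL divergence from the uniform distribution U to the softmax distribution p, that is `-(1/V) * sum(log p_i) - log V`. The pure-Python sketch below evaluates it for a single position so the behavior can be checked by hand. `kl_uniform_to_softmax` is a hypothetical name, and the torch version above additionally averages over positions.

```python
import math


def kl_uniform_to_softmax(logits):
    """KL(U || softmax(logits)) for one position, in nats."""
    v = len(logits)
    m = max(logits)  # subtract the max for a numerically stable log-sum-exp
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    return (-1.0 / v) * sum(log_probs) - math.log(v)


flat = kl_uniform_to_softmax([0.0, 0.0, 0.0, 0.0])    # uniform output
peaked = kl_uniform_to_softmax([8.0, 0.0, 0.0, 0.0])  # one dominant token
```

Here `flat` is zero up to rounding and `peaked` is clearly positive: the quantity grows as the output distribution becomes more concentrated, which is worth keeping in mind when choosing the sign of the bandit reward.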
|
||||
|
||||
|
||||
def borda_vote(
    results: List[Tuple[int, Optional[Dict[str, Any]], float]],
    valid_classes: Optional[List[str]] = None,
) -> Optional[str]:
    """
    Borda-count vote over (arm_idx, response_dict, certainty_score).
    Lower certainty_score = more certain = higher Borda rank.
    """
    valid = [
        (r[1]["ANSWER"], r[2]) for r in results
        if r[1] is not None and "ANSWER" in r[1]
        and (valid_classes is None or r[1]["ANSWER"] in valid_classes)
    ]
    if not valid:
        return None
    sorted_v = sorted(valid, key=lambda x: x[1])
    M = len(sorted_v)
    borda: Dict[str, float] = {}
    for rank, (ans, _) in enumerate(sorted_v):
        borda[ans] = borda.get(ans, 0.0) + (M - rank)
    top = max(borda.values())
    cands = [a for a, s in borda.items() if s == top]
    if len(cands) == 1:
        return cands[0]
    means = {a: float(np.mean([sc for x, sc in valid if x == a])) for a in cands}
    return min(means, key=means.get)
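A small worked case shows how the Borda tally and the tie-break in `borda_vote` interact. This is a simplified sketch over `(answer, score)` pairs rather than the full `(arm_idx, response_dict, score)` tuples. `borda_tally` and `borda_winner` are hypothetical names and the scores are made up.

```python
def borda_tally(votes):
    """votes: list of (answer, score); a lower score ranks first and earns more points."""
    ordered = sorted(votes, key=lambda v: v[1])
    m = len(ordered)
    tally = {}
    for rank, (answer, _) in enumerate(ordered):
        tally[answer] = tally.get(answer, 0) + (m - rank)
    return tally


def borda_winner(votes):
    """Highest tally wins; ties go to the answer with the lowest mean score."""
    tally = borda_tally(votes)
    top = max(tally.values())
    tied = [a for a, s in tally.items() if s == top]
    if len(tied) == 1:
        return tied[0]

    def mean_score(answer):
        scores = [s for a, s in votes if a == answer]
        return sum(scores) / len(scores)

    return min(tied, key=mean_score)


votes = [("Normal", 0.3), ("Hypertension", 0.1), ("Normal", 0.2)]
tally = borda_tally(votes)    # {"Hypertension": 3, "Normal": 3}
winner = borda_winner(votes)  # "Hypertension" (mean 0.1 vs 0.25)
```

Two weaker votes for one answer can tie with a single top-ranked vote for another, which is the case the mean-score tie-break resolves.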
|
||||
|
||||
|
||||
# ── FeatureIndex ───────────────────────────────────────────────────────────────
|
||||
|
||||
class FeatureIndex:
|
||||
"""Loads embeddings from disk; O(1) lookup by (user_id, sample_id)."""
|
||||
|
||||
def __init__(self, feature_root: str) -> None:
|
||||
idx_p = os.path.join(feature_root, "index.json")
|
||||
emb_p = os.path.join(feature_root, "embeddings.npy")
|
||||
if not os.path.isfile(idx_p):
|
||||
raise FileNotFoundError(f"index.json not found: {idx_p}")
|
||||
if not os.path.isfile(emb_p):
|
||||
raise FileNotFoundError(f"embeddings.npy not found: {emb_p}")
|
||||
with open(idx_p, "r", encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
self.embedding_type: str = meta.get("embedding_type", "unknown")
|
||||
self.embedding_dim: int = meta["embedding_dim"]
|
||||
self.embeddings: np.ndarray = np.load(emb_p).astype(np.float32)
|
||||
self.samples: List[Dict[str, Any]] = meta["samples"]
|
||||
self._lookup: Dict[Tuple[str, int], int] = {
|
||||
(str(e["user_id"]), int(e["sample_id"])): int(e["row"])
|
||||
for e in self.samples
|
||||
}
|
||||
print(f"[FeatureIndex] {len(self.embeddings)} embeddings "
|
||||
f"(dim={self.embedding_dim}, type={self.embedding_type})")
|
||||
|
||||
def get(self, user_id: str, sample_id: int) -> Optional[np.ndarray]:
|
||||
row = self._lookup.get((str(user_id), int(sample_id)))
|
||||
return self.embeddings[row] if row is not None else None
|
||||
|
||||
def get_user_embedding(self, user_id: str) -> Optional[np.ndarray]:
|
||||
"""Mean embedding across all samples for a user."""
|
||||
rows = [self.embeddings[e["row"]] for e in self.samples
|
||||
if str(e["user_id"]) == str(user_id)]
|
||||
return np.mean(rows, axis=0).astype(np.float32) if rows else None
|
||||
|
||||
|
||||
# ── MABSelector ────────────────────────────────────────────────────────────────
|
||||
|
||||
class MABSelector:
|
||||
"""
|
||||
Contextual bandit for top-m arm identification.
|
||||
Adapted from CASE / sample_case.py top_m_arm.
|
||||
|
||||
Key differences from original:
|
||||
- Rewards injected via observe(arm, reward) — no LLM calls, no file IO
|
||||
- update() does one rank-1 theta update per call
|
||||
- No initialization() warm-up loop required
|
||||
|
||||
X: np.ndarray shape (embedding_dim, num_arms)
|
||||
m: number of top arms to track (= queue_size)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, X: np.ndarray, m: int,
|
||||
n_das: int = 5, delta: float = 0.05,
|
||||
epsilon: float = 0.11, sigma: float = 0.5,
|
||||
) -> None:
|
||||
assert X.ndim == 2, "X must be (embedding_dim, num_arms)"
|
||||
self.X = X.astype(np.float64)
|
||||
self.N, self.K = X.shape
|
||||
self.m = min(m, self.K)
|
||||
self.n_das = n_das
|
||||
self.delta = delta
|
||||
self.epsilon = epsilon
|
||||
self.sigma = sigma
|
||||
self.arms = list(range(self.K))
|
||||
self._it = -1
|
||||
self._Bdi: Dict[Tuple[int, int], float] = {}
|
||||
self._reset()
|
||||
|
||||
def _reset(self) -> None:
|
||||
self.t = 0
|
||||
self.rewards: List[float] = []
|
||||
self.pulled_arms: List[int] = []
|
||||
self.means = np.zeros(self.K)
|
||||
self.na = np.zeros(self.K)
|
||||
self.B_inv = np.eye(self.N, dtype=np.float64)
|
||||
self.b = np.zeros(self.N, dtype=np.float64)
|
||||
self.theta = np.random.normal(0, 1, size=(1, self.N))
|
||||
self.J: List[int] = []
|
||||
self.notJ: List[int] = list(range(self.K))
|
||||
self.N_t: List[int] = []
|
||||
self.best_arm: Optional[int] = None
|
||||
self.worst_arm: Optional[int] = None
|
||||
self.challenger: Optional[int] = None
|
||||
self.c_t: Optional[int] = None
|
||||
self.patience = 0
|
||||
self.previous_J: List[int] = []
|
||||
|
||||
# ── math helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
def _mn(self, x: np.ndarray, A: np.ndarray) -> float:
|
||||
x = np.asarray(x).flatten()
|
||||
return float(np.sqrt(max(float(x @ A @ x), 0.0)))
|
||||
|
||||
def _iu(self, A: np.ndarray, x: np.ndarray) -> np.ndarray:
|
||||
"""Rank-1 inverse update: A - A x xT A / (1 + xT A x)."""
|
||||
x = np.asarray(x).flatten()
|
||||
d = 1.0 + self._mn(x, A) ** 2
|
||||
return A - np.outer(A @ x, x @ A) / d
|
||||
|
||||
def _beta(self) -> float:
|
||||
return math.log((math.log(max(self.t, 2)) + 1) / max(self.delta, 1e-12))
|
||||
|
||||
def _gap(self, i: int, j: Optional[int] = None) -> float:
|
||||
xi = self.X[:, i] - (self.X[:, j] if j is not None else 0.0)
|
||||
return float(self.theta.flatten() @ xi)
|
||||
|
||||
def _var(self, i: int, j: Optional[int] = None) -> float:
|
||||
S = (self.sigma ** 2) * self.B_inv
|
||||
xi = self.X[:, i] - (self.X[:, j] if j is not None else 0.0)
|
||||
return self._mn(xi, S) * math.sqrt(2 * self._beta())
|
||||
|
||||
def _Bij(self, i: int, j: int) -> float:
|
||||
t = max(self.t, 1)
|
||||
if t != self._it:
|
||||
self._Bdi = {}
|
||||
self._it = t
|
||||
k = (i, j)
|
||||
if k not in self._Bdi:
|
||||
self._Bdi[k] = self._gap(i, j) + self._var(i, j)
|
||||
return self._Bdi[k]
|
||||
|
||||
def _randf(self, x: List[float], f) -> int:
|
||||
arr = np.array(x, dtype=float)
|
||||
val = f(arr)
|
||||
c = np.argwhere(arr == val).flatten().tolist()
|
||||
return random_sample(c, 1)[0]
|
||||
|
||||
def _mmax(self, x: List[float], m: int) -> List[int]:
|
||||
x = list(x)
|
||||
ids = []
|
||||
for _ in range(min(m, len(x))):
|
||||
idx = self._randf(x, np.max)
|
||||
ids.append(idx)
|
||||
x[idx] = -float("inf")
|
||||
return ids
|
||||
|
||||
# ── sample step ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _sel(self, nt: Optional[int]) -> Optional[int]:
|
||||
if nt is None or self.c_t is None:
|
||||
return None
|
||||
if self.means[self.c_t] > self.means[nt]:
|
||||
s = self.c_t
|
||||
if s in self.N_t:
|
||||
self.N_t.remove(s)
|
||||
return s
|
||||
return None
|
||||
|
||||
def _Jt(self) -> List[int]:
|
||||
if self.t == 0:
|
||||
return self._mmax(self.means.tolist(), self.m)
|
||||
nt = self.worst_arm
|
||||
sel = self._sel(nt)
|
||||
self.previous_J = list(self.J)
|
||||
if sel is not None:
|
||||
self.J = [a for a in self.J if a != nt] + [sel]
|
||||
if nt is not None and nt not in self.N_t:
|
||||
self.N_t.append(nt)
|
||||
order = np.argsort(self.means[self.J])[::-1].tolist()
|
||||
self.J = [self.J[i] for i in order]
|
||||
self.patience = 0 if self.J != self.previous_J else self.patience + 1
|
||||
return self.J
|
||||
|
||||
def sample(self) -> List[int]:
|
||||
"""Select next arm(s) to evaluate. Returns list of arm indices."""
|
||||
self.J = self._Jt()
|
||||
self.notJ = [a for a in self.arms if a not in self.J]
|
||||
|
||||
if self.t == 0 or not self.N_t:
|
||||
n = min(self.n_das, len(self.notJ))
|
||||
self.N_t = list(np.random.choice(self.notJ, n, replace=False)) if self.notJ else []
|
||||
else:
|
||||
Qt = list(np.random.choice(self.notJ, min(self.n_das, len(self.notJ)), replace=False)) if self.notJ else []
|
||||
self.N_t = list({*self.N_t, *Qt} - set(self.J))
|
||||
if self.N_t:
|
||||
top_n = min(self.n_das, len(self.N_t))
|
||||
tv = np.sort(self.means[self.N_t])[::-1][:top_n]
|
||||
picked: List[int] = []
|
||||
for mv in tv:
|
||||
c = [a for a in self.N_t if self.means[a] == mv and a not in picked]
|
||||
if c:
|
||||
picked.append(random.choice(c))
|
||||
self.N_t = picked
|
||||
|
||||
if self.N_t and self.J:
|
||||
jm = self.means[self.J]
|
||||
ntc = [a for a in self.J if self.means[a] == jm.min()]
|
||||
self.worst_arm = random.choice(ntc)
|
||||
bti = [self._Bij(a, self.worst_arm) for a in self.J]
|
||||
self.best_arm = self.J[self._randf(bti, np.max)]
|
||||
chi = [self._Bij(a, self.best_arm) for a in self.N_t]
|
||||
self.challenger = self.N_t[self._randf(chi, np.max)]
|
||||
nm = self.means[self.N_t]
|
||||
c = [a for a in self.N_t if self.means[a] == nm.max()]
|
||||
self.c_t = random.choice(c)
|
||||
else:
|
||||
self.worst_arm = self.J[0] if self.J else None
|
||||
self.best_arm = self.challenger = self.c_t = None
|
||||
|
||||
# Greedy arm pull
|
||||
if self.best_arm is not None and self.challenger is not None:
|
||||
d = self.X[:, self.best_arm] - self.X[:, self.challenger]
|
||||
u = [self._mn(d, self._iu(self.B_inv, self.X[:, i])) for i in self.arms]
|
||||
return [self.arms[self._randf(u, np.min)]]
|
||||
return [random.choice(self.arms)]
|
||||
|
||||
def observe(self, arm: int, reward: float) -> None:
|
||||
"""Record reward for arm."""
|
||||
self.rewards.append(reward)
|
||||
self.pulled_arms.append(arm)
|
||||
self.na[arm] += 1
|
||||
self.t += 1
|
||||
|
||||
def update(self) -> None:
|
||||
"""Rank-1 theta/B_inv update from the last observation."""
|
||||
if not self.pulled_arms:
|
||||
return
|
||||
x = self.X[:, self.pulled_arms[-1]].flatten()
|
||||
self.B_inv = self._iu(self.B_inv, x)
|
||||
self.b += self.rewards[-1] * x
|
||||
self.theta = (self.B_inv @ self.b).reshape(1, self.N)
|
||||
self.means = (self.theta @ self.X).flatten()
|
||||
|
||||
def recommend(self, n: int) -> List[int]:
|
||||
"""Return n arm indices with highest estimated mean rewards."""
|
||||
return self._mmax(self.means.tolist(), min(n, self.K))
|
||||
|
||||
def stopping_rule(self) -> bool:
|
||||
if None in (self.best_arm, self.worst_arm, self.challenger):
|
||||
return False
|
||||
return round(self._Bij(self.challenger, self.best_arm), 2) <= self.epsilon
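The `_iu` helper relies on the Sherman-Morrison identity: given the inverse of a matrix A, the inverse of A + x xᵀ is A⁻¹ − (A⁻¹x)(xᵀA⁻¹) / (1 + xᵀA⁻¹x). The sketch below checks the identity on a 2×2 case in pure Python. It assumes a symmetric A⁻¹ (true for `B_inv`, which starts at the identity), and `sherman_morrison` is a hypothetical name rather than a function in this diff.

```python
def mat_vec(a, x):
    """Matrix-vector product for nested lists."""
    return [sum(a[i][k] * x[k] for k in range(len(x))) for i in range(len(a))]


def sherman_morrison(a_inv, x):
    """Return (A + x x^T)^-1 given a symmetric A^-1, via a rank-1 correction."""
    ax = mat_vec(a_inv, x)
    denom = 1.0 + sum(xi * axi for xi, axi in zip(x, ax))
    n = len(x)
    return [[a_inv[i][j] - ax[i] * ax[j] / denom for j in range(n)] for i in range(n)]


identity = [[1.0, 0.0], [0.0, 1.0]]
x = [1.0, 2.0]
b_inv = sherman_morrison(identity, x)

# Direct check: B = I + x x^T = [[2, 2], [2, 5]], so B @ b_inv should be the identity.
b = [[2.0, 2.0], [2.0, 5.0]]
product = [[sum(b[i][k] * b_inv[k][j] for k in range(2)) for j in range(2)] for i in range(2)]
```

Each update costs O(d²) instead of the O(d³) of a fresh inversion, which is presumably why the selector keeps `B_inv` rather than `B`.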
|
||||
|
||||
|
||||
# ── ArmPool ────────────────────────────────────────────────────────────────────
|
||||
# TODO: Add a way to sample arms from the larger pool with replacement, instead of fixing num_arms arms up front.
|
||||
class ArmPool:
|
||||
"""
|
||||
Manages candidate example sets (arms).
|
||||
Each arm = list of sample dicts (N-way x k_shot).
|
||||
Arm feature = mean embedding of its constituent samples.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source_samples: List[Dict[str, Any]],
|
||||
classes: List[str],
|
||||
k_shot: int,
|
||||
feature_index: Optional[FeatureIndex],
|
||||
num_arms: int,
|
||||
rng: np.random.Generator,
|
||||
) -> None:
|
||||
self.classes = classes
|
||||
self.k_shot = k_shot
|
||||
self.feature_index = feature_index
|
||||
self._rng = rng
|
||||
self._by: Dict[str, List[Dict[str, Any]]] = {c: [] for c in classes}
|
||||
for s in source_samples:
|
||||
lbl = s.get("label")
|
||||
if lbl in self._by:
|
||||
self._by[lbl].append(s)
|
||||
self.arms: List[List[Dict[str, Any]]] = []
|
||||
self.arm_feats: List[Optional[np.ndarray]] = []
|
||||
for _ in range(num_arms):
|
||||
arm, feat = self._make()
|
||||
self.arms.append(arm)
|
||||
self.arm_feats.append(feat)
|
||||
print(f"[ArmPool] {len(self.arms)} arms ({len(classes)}-way {k_shot}-shot)")
|
||||
|
||||
def _make(self) -> Tuple[List[Dict[str, Any]], Optional[np.ndarray]]:
|
||||
s: List[Dict[str, Any]] = []
|
||||
for c in self.classes:
|
||||
pool = self._by.get(c, [])
|
||||
if pool:
|
||||
n = min(self.k_shot, len(pool))
|
||||
idxs = self._rng.choice(len(pool), size=n, replace=False)
|
||||
s.extend([pool[i] for i in idxs])
|
||||
return s, self._feat(s)
|
||||
|
||||
def _feat(self, samples: List[Dict[str, Any]]) -> Optional[np.ndarray]:
|
||||
if self.feature_index is None:
|
||||
return None
|
||||
vecs = []
|
||||
for s in samples:
|
||||
v = self.feature_index.get(
|
||||
str(s.get("user_id", "")),
|
||||
int(s.get("sample_id", s.get("idx", -1))),
|
||||
)
|
||||
if v is not None:
|
||||
vecs.append(v)
|
||||
return np.mean(vecs, axis=0).astype(np.float32) if vecs else None
|
||||
|
||||
def get_arm(self, i: int) -> List[Dict[str, Any]]:
|
||||
return self.arms[i]
|
||||
|
||||
def build_X(self) -> np.ndarray:
|
||||
"""Feature matrix X of shape (embedding_dim, num_arms). Falls back to identity."""
|
||||
feats = [f for f in self.arm_feats if f is not None]
|
||||
if not feats:
|
||||
return np.eye(len(self.arms), dtype=np.float32)
|
||||
X = np.zeros((feats[0].shape[0], len(self.arms)), dtype=np.float32)
|
||||
for i, f in enumerate(self.arm_feats):
|
||||
if f is not None:
|
||||
X[:, i] = f
|
||||
return X
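Each arm in `ArmPool` is built by drawing up to `k_shot` samples per class, and its feature is the mean of the sample embeddings. The sketch below reproduces both steps with the standard library only. `sample_arm` and `mean_vector` are hypothetical names, the labels and IDs are illustrative, and `random.Random` stands in for the `np.random.Generator` used in the diff.

```python
import random
from typing import Dict, List


def sample_arm(by_class: Dict[str, List[dict]], k_shot: int, rng: random.Random) -> List[dict]:
    """Draw up to k_shot samples per class without replacement."""
    arm: List[dict] = []
    for label in sorted(by_class):
        pool = by_class[label]
        arm.extend(rng.sample(pool, min(k_shot, len(pool))))
    return arm


def mean_vector(vectors: List[List[float]]) -> List[float]:
    """Element-wise mean of equally sized, non-empty vectors."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


by_class = {
    "Normal": [{"sample_id": 1, "label": "Normal"}, {"sample_id": 2, "label": "Normal"}],
    "Hypertension": [{"sample_id": 3, "label": "Hypertension"}],
}
arm = sample_arm(by_class, k_shot=1, rng=random.Random(0))
feature = mean_vector([[1.0, 0.0], [0.0, 1.0]])  # [0.5, 0.5]
```

Averaging keeps the arm feature at the same dimension as a single embedding, at the cost of hiding how spread out the constituent samples are.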
|
||||
|
||||
|
||||
# ── RecruiterAgent ────────────────────────────────────────────────────────────
|
||||
|
||||
class RecruiterAgent:
|
||||
"""
|
||||
MAB-driven recruiter for personalized ICL example set selection.
|
||||
|
||||
Usage:
|
||||
recruiter = RecruiterAgent(source_samples, classes, ...)
|
||||
initial = recruiter.recruit(queue_size) # cold start (random)
|
||||
|
||||
for test_sample in target_stream:
|
||||
results = [(arm_idx, response_dict, score), ...]
|
||||
recruiter.update(results) # MAB update lives here
|
||||
new_sets = recruiter.recruit(L) # MAB-guided new sets
|
||||
queue.update_with_recruiter(results, new_sets)
|
||||
|
||||
MAB update logic inside update():
|
||||
1. reward = -score (negate: lower certainty = better = higher reward)
|
||||
2. mab.observe(arm, reward)
|
||||
3. mab.update() -> rank-1 theta/B_inv update
|
||||
4. mab.sample() -> select next arms to explore (stored for next recruit())
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source_samples: List[Dict[str, Any]],
|
||||
classes: List[str],
|
||||
feature_index: Optional[FeatureIndex] = None,
|
||||
target_user_id: Optional[str] = None,
|
||||
k_shot: int = 1,
|
||||
queue_size: int = 5,
|
||||
num_arms: int = 50,
|
||||
n_das: int = 5,
|
||||
delta: float = 0.05,
|
||||
epsilon: float = 0.11,
|
||||
sigma: float = 0.5,
|
||||
seed: int = 42,
|
||||
) -> None:
|
||||
self.classes = classes
|
||||
self.k_shot = k_shot
|
||||
self.queue_size = queue_size
|
||||
self.target_user_id = target_user_id
|
||||
|
||||
rng = np.random.default_rng(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
self.pool = ArmPool(source_samples, classes, k_shot, feature_index, num_arms, rng)
|
||||
self.mab = MABSelector(
|
||||
self.pool.build_X(), m=queue_size,
|
||||
n_das=n_das, delta=delta, epsilon=epsilon, sigma=sigma,
|
||||
)
|
||||
self._next: Optional[List[int]] = None
|
||||
self._initialized = False
|
||||
print(f"[RecruiterAgent] {len(self.pool.arms)} arms, "
|
||||
f"queue={queue_size}, classes={list(set(classes))}")
|
||||
|
||||
def recruit(self, n: int) -> List[Tuple[int, List[Dict[str, Any]]]]:
|
||||
"""
|
||||
Return n (arm_idx, example_set) pairs for insertion into the queue.
|
||||
Cold start: random. Subsequent calls: MAB-recommended.
|
||||
"""
|
||||
if not self._initialized:
|
||||
indices = random.sample(range(len(self.pool.arms)), min(n, len(self.pool.arms)))
|
||||
self._initialized = True
|
||||
elif self._next is not None:
|
||||
|
||||
indices = self._next[:n]
|
||||
if len(indices) < n:
|
||||
rest = [i for i in range(len(self.pool.arms)) if i not in indices]
|
||||
indices += random.sample(rest, min(n - len(indices), len(rest)))
|
||||
else:
|
||||
indices = self.mab.recommend(n)
|
||||
return [(i, self.pool.get_arm(i)) for i in indices]
|
||||
|
||||
def update(self, results: List[Tuple[int, Optional[Dict[str, Any]], float]]) -> None:
|
||||
"""
|
||||
Update MAB from (arm_idx, response_dict, self_certainty_score).
|
||||
|
||||
Steps:
|
||||
1. reward = -score (lower certainty = better = higher reward)
|
||||
2. mab.observe(arm, reward)
|
||||
3. mab.update() -> rank-1 theta/B_inv update
|
||||
4. mab.sample() -> selects next arms to explore
|
||||
"""
|
||||
if not results:
|
||||
return
|
||||
for arm_idx, _, score in results:
|
||||
if 0 <= arm_idx < len(self.pool.arms):
|
||||
self.mab.observe(arm_idx, -float(score))
|
||||
self.mab.update()
|
||||
try:
|
||||
next_cands = self.mab.sample()
|
||||
top = self.mab.recommend(self.queue_size)
|
||||
self._next = list(dict.fromkeys(next_cands + top))
|
||||
except Exception as e:
|
||||
print(f"[RecruiterAgent] MAB sample error ({e}), using recommend()")
|
||||
self._next = self.mab.recommend(self.queue_size)
|
||||
99
sc/preprocess/build_feature_index.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Build unified feature index from PPGBP per-subject .npy embeddings.
|
||||
Output: <output_dir>/index.json + <output_dir>/embeddings.npy
|
||||
|
||||
Usage:
|
||||
python -m sc.preprocess.build_feature_index --embedding_root $ROOT --output_dir ./features/sbert_metadata
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _load_labels(base_dir: Optional[str]) -> Dict[int, str]:
|
||||
"""Load subject_id -> label from xlsx. Prefers Hypertension column (PPGBP), else class/Class."""
|
||||
if not base_dir:
|
||||
return {}
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
return {}
|
||||
xlsx = None
|
||||
for f in os.listdir(base_dir):
|
||||
if f.endswith(".xlsx"):
|
||||
xlsx = os.path.join(base_dir, f)
|
||||
break
|
||||
if not xlsx or not os.path.isfile(xlsx):
|
||||
return {}
|
||||
df = pd.read_excel(xlsx, header=1)
|
||||
if "subject_ID" not in df.columns:
|
||||
return {}
|
||||
label_col = None
|
||||
for c in ("Hypertension", "class", "Class"):
|
||||
if c in df.columns:
|
||||
label_col = c
|
||||
break
|
||||
out = {}
|
||||
for _, row in df.iterrows():
|
||||
sid = int(row["subject_ID"])
|
||||
out[sid] = str(row.get(label_col, "unknown")).strip() if label_col else "unknown"
|
||||
return out
|
||||
|
||||
|
||||
def build_index(
|
||||
embedding_root: str,
|
||||
output_dir: str,
|
||||
embedding_type: str = "sbert_metadata",
|
||||
base_dir: Optional[str] = None,
|
||||
) -> None:
|
||||
if not os.path.isdir(embedding_root):
|
||||
raise FileNotFoundError(embedding_root)
|
||||
subject_ids = []
|
||||
for f in sorted(os.listdir(embedding_root)):
|
||||
if not f.endswith(".npy"):
|
||||
continue
|
||||
try:
|
||||
subject_ids.append(int(f.replace(".npy", "")))
|
||||
except ValueError:
|
||||
continue
|
||||
subject_ids = sorted(subject_ids)
|
||||
if not subject_ids:
|
||||
raise ValueError("No .npy files in " + embedding_root)
|
||||
labels = _load_labels(base_dir)
|
||||
embeddings_list = []
|
||||
samples: List[Dict[str, Any]] = []
|
||||
for row, sid in enumerate(subject_ids):
|
||||
emb = np.load(os.path.join(embedding_root, f"{sid}.npy")).astype(np.float32)
|
||||
embeddings_list.append(emb)
|
||||
samples.append({
|
||||
"row": row,
|
||||
"user_id": str(sid),
|
||||
"sample_id": sid,
|
||||
"label": labels.get(sid, "unknown"),
|
||||
"session": "1",
|
||||
})
|
||||
embeddings = np.stack(embeddings_list, axis=0)
|
||||
dim = int(embeddings.shape[1])
|
||||
meta = {"version": "1.0", "embedding_type": embedding_type, "embedding_dim": dim, "samples": samples}
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
with open(os.path.join(output_dir, "index.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
np.save(os.path.join(output_dir, "embeddings.npy"), embeddings)
|
||||
print(f"[build_feature_index] Wrote {len(samples)} samples, dim={dim}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--embedding_root", required=True, help="Directory of per-subject .npy files")
|
||||
ap.add_argument("--output_dir", required=True, help="Output dir for index.json and embeddings.npy")
|
||||
ap.add_argument("--embedding_type", default="sbert_metadata")
|
||||
ap.add_argument("--base_dir", default=None, help="Optional PPGBP data dir (xlsx) for labels")
|
||||
a = ap.parse_args()
|
||||
build_index(a.embedding_root, a.output_dir, a.embedding_type, a.base_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
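The `index.json` written by `build_index` maps each `(user_id, sample_id)` pair to a row of `embeddings.npy`. The sketch below rebuilds that lookup from an in-memory JSON string, following the same key construction as `FeatureIndex`. The sample entries are illustrative, and loading the `.npy` array is omitted.

```python
import json

# Illustrative index.json content; IDs, labels, and dimension are made up.
index_text = json.dumps({
    "version": "1.0",
    "embedding_type": "metadata_onehot",
    "embedding_dim": 4,
    "samples": [
        {"row": 0, "user_id": "2", "sample_id": 2, "label": "unknown", "session": "1"},
        {"row": 1, "user_id": "3", "sample_id": 3, "label": "unknown", "session": "1"},
    ],
})

meta = json.loads(index_text)
# Same key shape as FeatureIndex._lookup: (str user_id, int sample_id) -> row.
lookup = {(str(e["user_id"]), int(e["sample_id"])): int(e["row"]) for e in meta["samples"]}

row = lookup.get(("3", 3))          # 1
missing = lookup.get(("999", 999))  # None
```

Because `build_index` writes `user_id` as a string and `sample_id` as an integer, callers need to cast their keys the same way or the lookup silently returns `None`.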
|
||||
134
sc/preprocess/build_ppgbp_embeddings.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
Build PPGBP embeddings under PPGBP_EMBEDDINGS_ROOT with two top-level dirs:
|
||||
metadata_embs/<emb_type>/<subject_id>.npy
|
||||
sig_embs/<emb_type>/<subject_id>.npy
|
||||
|
||||
For now only metadata embeddings are produced: for each subject, take the xlsx row,
|
||||
vectorize (one-hot for Sex and Hypertension, numeric for the rest) and save.
|
||||
|
||||
Usage:
|
||||
python -m sc.preprocess.build_ppgbp_embeddings \\
|
||||
--embeddings_root /path/to/ppgbp_embeddings \\
|
||||
--data_dir /path/to/ppgbp_data \\
|
||||
--metadata_emb_type onehot
|
||||
|
||||
Then build the feature index (for downstream RecruiterAgent etc.):
|
||||
python -m sc.preprocess.build_feature_index \\
|
||||
--embedding_root $PPGBP_EMBEDDINGS_ROOT/metadata_embs/onehot \\
|
||||
--output_dir $DATA_INDEX_ROOT/metadata_onehot \\
|
||||
--embedding_type metadata_onehot \\
|
||||
--base_dir $PPGBP_DATA_DIR
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
# Xlsx column names (header=1)
|
||||
SUBJECT_ID_COL = "subject_ID"
|
||||
SEX_COL = "Sex(M/F)"
|
||||
AGE_COL = "Age(year)"
|
||||
HEIGHT_COL = "Height(cm)"
|
||||
WEIGHT_COL = "Weight(kg)"
|
||||
SBP_COL = "Systolic Blood Pressure(mmHg)"
|
||||
DBP_COL = "Diastolic Blood Pressure(mmHg)"
|
||||
HR_COL = "Heart Rate(b/m)"
|
||||
BMI_COL = "BMI(kg/m^2)"
|
||||
HYP_COL = "Hypertension"
|
||||
|
||||
NUMERIC_COLS = [AGE_COL, HEIGHT_COL, WEIGHT_COL, SBP_COL, DBP_COL, HR_COL, BMI_COL]
|
||||
CATEGORICAL_COLS = [SEX_COL, HYP_COL]
|
||||
|
||||
|
||||
def _find_xlsx(data_dir: str) -> str:
|
||||
for f in os.listdir(data_dir):
|
||||
if f.endswith(".xlsx"):
|
||||
return os.path.join(data_dir, f)
|
||||
raise FileNotFoundError(f"No .xlsx in {data_dir}")
|
||||
|
||||
|
||||
def _onehot_and_numeric(df: pd.DataFrame) -> Tuple[np.ndarray, List[str], Dict[str, List[str]]]:
|
||||
"""
|
||||
One-hot encode categoricals (Sex, Hypertension) and stack with numeric columns.
|
||||
Returns (matrix, column_names, categories_used).
|
||||
"""
|
||||
out_cols: List[str] = []
|
||||
pieces: List[np.ndarray] = []
|
||||
categories_used: Dict[str, List[str]] = {}
|
||||
|
||||
for col in CATEGORICAL_COLS:
|
||||
if col not in df.columns:
|
||||
continue
|
||||
vals = df[col].fillna("unknown").astype(str).str.strip()
|
||||
uniq = sorted(vals.unique())
|
||||
categories_used[col] = uniq
|
||||
for u in uniq:
|
||||
out_cols.append(f"{col}__{u}")
|
||||
onehot = (vals.values.reshape(-1, 1) == np.array(uniq)).astype(np.float32)
|
||||
pieces.append(onehot)
|
||||
|
||||
for col in NUMERIC_COLS:
|
||||
if col not in df.columns:
|
||||
continue
|
||||
out_cols.append(col)
|
||||
vec = df[col].fillna(0).astype(np.float32).values.reshape(-1, 1)
|
||||
pieces.append(vec)
|
||||
|
||||
if not pieces:
|
||||
raise ValueError("No columns found; check xlsx has expected column names")
|
||||
X = np.hstack(pieces)
|
||||
return X, out_cols, categories_used
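The categorical branch of `_onehot_and_numeric` compares every value against the sorted list of unique categories. The sketch below shows the same encoding without pandas or NumPy. `one_hot` is a hypothetical name, and the category values are illustrative since the xlsx contents are not shown here.

```python
from typing import List, Tuple


def one_hot(values: List[str]) -> Tuple[List[str], List[List[float]]]:
    """Return (sorted categories, one row per value with a single 1.0)."""
    categories = sorted(set(values))
    rows = [[1.0 if v == c else 0.0 for c in categories] for v in values]
    return categories, rows


categories, rows = one_hot(["Male", "Female", "Male"])
# categories == ["Female", "Male"]
# rows == [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
```

Sorting the categories makes the column order deterministic for a given file, but the order can change if a new category value appears in the data.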
|
||||
|
||||
|
||||
def create_embedding_dirs(embeddings_root: str, metadata_emb_type: str, sig_emb_type: str = "placeholder") -> None:
|
||||
"""Create metadata_embs/<emb_type>/ and sig_embs/<emb_type>/."""
|
||||
os.makedirs(embeddings_root, exist_ok=True)
|
||||
meta_dir = os.path.join(embeddings_root, "metadata_embs", metadata_emb_type)
|
||||
sig_dir = os.path.join(embeddings_root, "sig_embs", sig_emb_type)
|
||||
os.makedirs(meta_dir, exist_ok=True)
|
||||
os.makedirs(sig_dir, exist_ok=True)
|
||||
|
||||
|
||||
def build_metadata_embeddings_onehot(
|
||||
data_dir: str,
|
||||
embeddings_root: str,
|
||||
metadata_emb_type: str = "onehot",
|
||||
) -> None:
|
||||
"""
|
||||
Read xlsx from data_dir, vectorize each subject row (one-hot + numeric), save as
|
||||
metadata_embs/<metadata_emb_type>/<subject_id>.npy.
|
||||
"""
|
||||
xlsx_path = _find_xlsx(data_dir)
|
||||
df = pd.read_excel(xlsx_path, header=1)
|
||||
if SUBJECT_ID_COL not in df.columns:
|
||||
raise ValueError(f"xlsx must have column {SUBJECT_ID_COL}")
|
||||
|
||||
X, _cols, _cats = _onehot_and_numeric(df)
|
||||
subject_ids = df[SUBJECT_ID_COL].astype(int).values
|
||||
out_dir = os.path.join(embeddings_root, "metadata_embs", metadata_emb_type)
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
for i, sid in enumerate(subject_ids):
|
||||
path = os.path.join(out_dir, f"{int(sid)}.npy")
|
||||
np.save(path, X[i].astype(np.float32))
|
||||
print(f"[build_ppgbp_embeddings] Wrote {len(subject_ids)} metadata embeddings to {out_dir} (dim={X.shape[1]})")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser(description="Build PPGBP metadata (and optionally signal) embeddings under embeddings_root.")
|
||||
ap.add_argument("--embeddings_root", required=True, help="Root for metadata_embs/ and sig_embs/ (e.g. PPGBP_EMBEDDINGS_ROOT)")
|
||||
ap.add_argument("--data_dir", required=True, help="PPGBP data dir containing the xlsx file")
|
||||
ap.add_argument("--metadata_emb_type", default="onehot", help="Subdir name under metadata_embs/ (e.g. onehot)")
|
||||
ap.add_argument("--sig_emb_type", default="placeholder", help="Subdir under sig_embs/ (created empty for now)")
|
||||
args = ap.parse_args()
|
||||
create_embedding_dirs(args.embeddings_root, args.metadata_emb_type, args.sig_emb_type)
|
||||
build_metadata_embeddings_onehot(args.data_dir, args.embeddings_root, args.metadata_emb_type)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
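The one-hot step in `_onehot_and_numeric` relies on NumPy broadcasting: comparing an `(N, 1)` column of values against a `(K,)` array of sorted categories gives an `(N, K)` boolean matrix with exactly one `True` per row. A minimal sketch on toy data follows. The helper name `onehot_column` and the columns `sex` and `age` are assumptions made for illustration and are not part of the repository; a plain list stands in for the pandas column.

```python
import numpy as np


def onehot_column(values, name):
    """One-hot encode a list of raw values; returns (matrix, column_names)."""
    # Mirror the script: missing values become "unknown", everything is a stripped string.
    cleaned = np.array(
        ["unknown" if v is None else str(v).strip() for v in values], dtype=object
    )
    categories = sorted(set(cleaned.tolist()))
    # (N, 1) == (K,) broadcasts to an (N, K) boolean matrix with one True per row.
    onehot = (cleaned.reshape(-1, 1) == np.array(categories)).astype(np.float32)
    return onehot, [f"{name}__{c}" for c in categories]


# Hypothetical toy metadata: one categorical column and one numeric column.
sex_onehot, sex_cols = onehot_column(["F", "M ", None, "F"], "sex")
age = np.array([45, 60, 38, 52], dtype=np.float32).reshape(-1, 1)

# Categorical block first, numeric block after, as in the script.
X = np.hstack([sex_onehot, age])
print(sex_cols)
print(X.shape)
```

Because the categories are sorted before encoding, the column order is reproducible for a given spreadsheet, but it changes if a new category value appears, so vectors built from different spreadsheets are not guaranteed to be comparable.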
76
sc/preprocess/package_ppgbp_embeddings.py
Normal file
@@ -0,0 +1,76 @@
"""
One-shot: build PPGBP one-hot metadata embeddings and then package them into
the feature index at DATA_INDEX_ROOT for downstream use (RecruiterAgent, etc.).

Step 1: Create metadata_embs/onehot/ and sig_embs/ under PPGBP_EMBEDDINGS_ROOT,
        write one-hot metadata vectors per subject.
Step 2: Run build_feature_index to produce index.json + embeddings.npy at
        DATA_INDEX_ROOT (compatible with FeatureIndex and RecruiterAgent).

Usage:
    export PPGBP_EMBEDDINGS_ROOT=/scratch/.../embeddings_full/ppgbp
    export PPGBP_DATA_DIR=/path/to/ppgbp_data  # dir with xlsx and 0_subject/
    # or: export PPGBP_DATA_ROOT=/path/to/data/tsllm (same as data dir if xlsx lives there)
    export DATA_INDEX_ROOT=/path/to/feature_index_output

    python -m sc.preprocess.package_ppgbp_embeddings

Or with explicit args:
    python -m sc.preprocess.package_ppgbp_embeddings \\
        --embeddings_root $PPGBP_EMBEDDINGS_ROOT \\
        --data_dir $PPGBP_DATA_DIR \\
        --index_output_dir $DATA_INDEX_ROOT \\
        --metadata_emb_type onehot \\
        --embedding_type metadata_onehot
"""

from __future__ import annotations

import argparse
import os
import sys


def main() -> None:
    ap = argparse.ArgumentParser(description="Build PPGBP metadata embeddings and package feature index.")
    ap.add_argument("--embeddings_root", default=None, help="Default: PPGBP_EMBEDDINGS_ROOT env")
    ap.add_argument("--data_dir", default=None, help="PPGBP data dir (xlsx + 0_subject/); default: PPGBP_DATA_DIR or PPGBP_DATA_ROOT env")
    ap.add_argument("--index_output_dir", default=None, help="Where to write index.json + embeddings.npy; default: DATA_INDEX_ROOT env")
    ap.add_argument("--metadata_emb_type", default="onehot")
    ap.add_argument("--embedding_type", default="metadata_onehot", help="embedding_type string in index.json")
    args = ap.parse_args()

    embeddings_root = args.embeddings_root or os.environ.get("PPGBP_EMBEDDINGS_ROOT")
    data_dir = args.data_dir or os.environ.get("PPGBP_DATA_DIR") or os.environ.get("PPGBP_DATA_ROOT") or embeddings_root
    index_output_dir = args.index_output_dir or os.environ.get("DATA_INDEX_ROOT")

    if not embeddings_root:
        print("Set PPGBP_EMBEDDINGS_ROOT or pass --embeddings_root", file=sys.stderr)
        sys.exit(1)
    if not os.path.isdir(data_dir):
        print(f"Data dir not found: {data_dir}", file=sys.stderr)
        sys.exit(1)
    if not index_output_dir:
        print("Set DATA_INDEX_ROOT or pass --index_output_dir", file=sys.stderr)
        sys.exit(1)

    from sc.preprocess.build_ppgbp_embeddings import (
        create_embedding_dirs,
        build_metadata_embeddings_onehot,
    )
    create_embedding_dirs(embeddings_root, args.metadata_emb_type)
    build_metadata_embeddings_onehot(data_dir, embeddings_root, args.metadata_emb_type)

    embedding_root_for_index = os.path.join(embeddings_root, "metadata_embs", args.metadata_emb_type)
    from sc.preprocess.build_feature_index import build_index
    build_index(
        embedding_root=embedding_root_for_index,
        output_dir=index_output_dir,
        embedding_type=args.embedding_type,
        base_dir=data_dir,
    )
    print(f"[package_ppgbp_embeddings] Feature index written to {index_output_dir}")


if __name__ == "__main__":
    main()
|
||||
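The packaging script resolves each path from the command-line flag first and only then from environment variables, in a fixed order. The sketch below isolates that precedence rule. The helper `resolve_setting` is a hypothetical name introduced here (the script inlines the same logic as an `or` chain), and the environment is passed in as a plain mapping so the example does not depend on the real shell environment.

```python
import os
from typing import Mapping, Optional, Sequence


def resolve_setting(
    cli_value: Optional[str],
    env_names: Sequence[str],
    env: Mapping[str, str] = os.environ,
) -> Optional[str]:
    """Return the CLI value if given, else the first non-empty env var, else None."""
    if cli_value:
        return cli_value
    for name in env_names:
        value = env.get(name)
        if value:
            return value
    return None


# Hypothetical environment: only the second candidate variable is set.
fake_env = {"PPGBP_DATA_ROOT": "/data/tsllm"}

from_env = resolve_setting(None, ["PPGBP_DATA_DIR", "PPGBP_DATA_ROOT"], fake_env)
from_cli = resolve_setting("/explicit", ["PPGBP_DATA_DIR"], fake_env)
missing = resolve_setting(None, ["DATA_INDEX_ROOT"], fake_env)
print(from_env, from_cli, missing)
```

As in the script, an empty string counts as "not set", so an exported but empty variable falls through to the next candidate.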
390
sc/run_recruiter.py
Normal file
@@ -0,0 +1,390 @@
"""
Recruiter-based Self-Consistency runner for PPGBP.

Uses RecruiterAgent (MAB) to select example sets, self_certainty(logits) as reward,
and borda_vote for aggregation. Requires PPGBPLoader and optional feature index.

Usage:
    python -m sc.run_recruiter sc/config/ppgbp_recruiter.yaml
"""
from __future__ import annotations

import asyncio
import json
import os
import sys
from collections import deque
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import yaml
from fire import Fire

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from sc.core.data_loader import PPGBPLoader, get_loaders
from sc.core.example_queue import Queue
from sc.core.model import Model, load_models
from sc.core.recruiter_agent import RecruiterAgent, borda_vote, self_certainty
from sc.core.recruiter_agent import FeatureIndex


def _ppgbp_sample_to_dict(metadata: Dict, signal, subject_id: int, label_key: str = "label") -> Dict[str, Any]:
    """Turn (metadata, signal) into a sample dict with user_id, sample_id, label, features."""
    # Look up the label under label_key first, then fall back to
    # "hypertension", "Hypertension", "class", "Class" (in that order).
    label = str(metadata.get(label_key))
    label_source = label_key
    if label == "None" or label == "unknown":
        for key in ["hypertension", "Hypertension", "class", "Class"]:
            val = metadata.get(key)
            if val is not None:
                label = str(val)
                label_source = key
                break
    if label == "None":
        label = "unknown"

    # Drop the column that supplied the label as well; otherwise a label found via
    # a fallback key would leak into the prompt features.
    features = {
        k: str(v)
        for k, v in metadata.items()
        if k not in (label_key, label_source) and k.lower() != "label"
    }
    return {
        "user_id": str(subject_id),
        "sample_id": subject_id,
        "idx": subject_id,
        "label": label,
        "features": features,
    }


def _build_prompt(system_msg: str, task_info: str, classes_info: List[str], test_sample: Dict, examples: List[Dict]) -> List[Dict[str, str]]:
    """Build chat messages for the model."""
    parts = [f"**Task**: {task_info}\n\n**Classes**: {', '.join(classes_info)}\n\n"]
    parts.append("**Examples**\n")
    for ex in examples:
        parts.append(f"*Example of {ex['label']}*:\n")
        for k, v in (ex.get("features") or {}).items():
            parts.append(f" - {k}: {v}\n")
    parts.append("\n**Current sample**\n")
    for k, v in (test_sample.get("features") or {}).items():
        parts.append(f" - {k}: {v}\n")
    parts.append(f"\nRespond in JSON: {{\"REASON\": \"...\", \"CONFIDENCE\": 0.0-1.0, \"ANSWER\": \"<class>\"}}")
    content = "".join(parts)
    return [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": content},
    ]


def _parse_answer(text: str, classes: List[str]) -> Optional[str]:
    """Extract ANSWER from JSON in text."""
    try:
        start = text.find("{")
        end = text.rfind("}") + 1
        if start >= 0 and end > start:
            obj = json.loads(text[start:end])
            ans = (obj.get("ANSWER") or "").strip()
            if ans in classes:
                return ans
    except Exception:
        pass
    return None
|
||||
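`_parse_answer` above slices the reply from the first `{` to the last `}` and only accepts an `ANSWER` that is one of the known classes. The sketch below restates that behaviour with explicit failure paths so the edge cases are easy to see. `extract_answer` is a hypothetical name, and the sample replies are made up for illustration.

```python
import json
from typing import List, Optional


def extract_answer(text: str, classes: List[str]) -> Optional[str]:
    """Slice from the first '{' to the last '}' and read the ANSWER field."""
    start, end = text.find("{"), text.rfind("}") + 1
    if start < 0 or end <= start:
        return None  # no brace-delimited span at all
    try:
        obj = json.loads(text[start:end])
    except json.JSONDecodeError:
        return None  # the span is not valid JSON
    answer = obj.get("ANSWER") if isinstance(obj, dict) else None
    if isinstance(answer, str) and answer.strip() in classes:
        return answer.strip()
    return None  # missing, non-string, or unknown class


classes = ["Normal", "Hypertension"]
wrapped = 'Sure! {"REASON": "high SBP", "CONFIDENCE": 0.8, "ANSWER": "Hypertension"} Done.'
two_objects = '{"ANSWER": "Normal"} and also {"ANSWER": "Normal"}'

print(extract_answer(wrapped, classes))
print(extract_answer("no json here", classes))
print(extract_answer(two_objects, classes))
```

The last call shows a limitation of the first-brace-to-last-brace slice: a reply containing two separate JSON objects produces a span that is not valid JSON, so the answer is rejected even though each object on its own is fine.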
def make_gt_scorer(
    num_arms: int,
    best_arm_id: int = 20,
    sigma: float = 5.0,
) -> Callable[[int], float]:
    """
    Build a deterministic ground-truth scorer over arm indices.

    The reward is a Gaussian bump in the arm index that peaks at best_arm_id
    (clamped to the valid range): reward(i) = exp(-(i - best)^2 / (2 sigma^2)).
    The scorer returns -reward(i), so lower = better. Out-of-range arm ids
    score 0.0, the worst possible value. No randomness.
    """
    best_arm_id = max(0, min(best_arm_id, num_arms - 1))
    sigma = max(1e-6, float(sigma))

    def scorer(arm_id: int) -> float:
        if arm_id < 0 or arm_id >= num_arms:
            return 0.0
        reward = np.exp(-((arm_id - best_arm_id) ** 2) / (2 * sigma**2))
        return -float(reward)

    return scorer
|
||||
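To make the scale of the ground-truth scores concrete, the sketch below evaluates the same negated Gaussian with the default `best_arm_id=20` and `sigma=5.0`, using only the standard library. `gaussian_score` is a hypothetical stand-alone name; it omits the range check and clamping done by `make_gt_scorer`.

```python
import math


def gaussian_score(arm_id: int, best_arm_id: int = 20, sigma: float = 5.0) -> float:
    """Negated Gaussian reward in the arm index: lower is better."""
    reward = math.exp(-((arm_id - best_arm_id) ** 2) / (2 * sigma ** 2))
    return -reward


# The best arm scores exactly -1; one sigma away the reward drops to exp(-0.5).
for arm in (20, 15, 25, 30, 0):
    print(arm, round(gaussian_score(arm), 4))

# With 50 arms, the arm that minimises the score is the configured best arm.
best = min(range(50), key=gaussian_score)
print(best)
```

The score is symmetric around the best arm, so arms 15 and 25 are indistinguishable to the bandit, and arms far from the peak all score close to 0, which gives very little signal for ranking them against each other.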
async def run_single_user(
    config: Dict[str, Any],
    model_pool: Optional[Any],
    test_loader: Optional[PPGBPLoader],
    train_loader: PPGBPLoader,
    seed: int,
) -> List[Dict[str, Any]]:
    """Run recruiter pipeline. If model_pool is None, use ground-truth scorer (arm_id -> score) for MAB convergence."""
    queue_size = config.get("queue_size", 5)
    recruit_size = config.get("recruit_size", 2)
    n_way = config.get("n_way", 2)
    k_shot = config.get("k_shot", 1)
    task_info = config.get("task_info", "Classify the subject.")
    classes_info = config.get("classes_info", ["Normal", "Hypertension"])
    system_msg = config.get("system_message", "You are a helpful assistant. Respond in JSON with REASON, CONFIDENCE, ANSWER.")
    use_gt_only = model_pool is None

    np.random.seed(seed)
    train_loader._indices = np.arange(len(train_loader.subject_ids))
    if train_loader.shuffle:
        train_loader.rng.shuffle(train_loader._indices)

    example_dataset = []
    for idx in train_loader._indices:
        sid = train_loader.subject_ids[idx]
        meta = train_loader._get_metadata_dict(sid)
        signal = train_loader._load_signal(sid)
        example_dataset.append(_ppgbp_sample_to_dict(meta, signal, int(sid)))
    if len(example_dataset) < n_way * k_shot:
        print("[run_recruiter] Not enough train samples")
        return []

    classes = list({s["label"] for s in example_dataset})
    if not classes:
        classes = classes_info
    class_indices = {c: [i for i, s in enumerate(example_dataset) if s["label"] == c] for c in classes}
    for c in classes:
        if not class_indices[c]:
            class_indices[c] = [0]
    dataset_idx = {(str(s["user_id"]), s["sample_id"]): i for i, s in enumerate(example_dataset)}

    feature_index = None
    if config.get("feature_root") and os.path.isdir(config["feature_root"]):
        try:
            feature_index = FeatureIndex(config["feature_root"])
        except Exception as e:
            print(f"[run_recruiter] Feature index load failed: {e}")

    recruiter = RecruiterAgent(
        source_samples=example_dataset,
        classes=classes,
        feature_index=feature_index,
        target_user_id=None,
        k_shot=k_shot,
        queue_size=queue_size,
        num_arms=config.get("num_arms", 50),
        n_das=config.get("n_das", 5),
        delta=config.get("delta", 0.05),
        epsilon=config.get("epsilon", 0.11),
        seed=seed,
    )

    num_arms = len(recruiter.pool.arms)
    if use_gt_only:
        gt_steps = config.get("gt_steps", 200)
        gt_best_arm_id = config.get("gt_best_arm_id", 20)
        gt_sigma = config.get("gt_sigma", 5.0)
        gt_scorer = make_gt_scorer(num_arms, best_arm_id=gt_best_arm_id, sigma=gt_sigma)
        print(f"[run_recruiter] GT-only mode: {gt_steps} steps, best_arm_id={gt_best_arm_id}, sigma={gt_sigma}")

    initial_sets = recruiter.recruit(queue_size)
    initial_cases = []
    slot_arms = []
    for arm_idx, ex_set in initial_sets:
        case = []
        for d in ex_set:
            key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
            if key in dataset_idx:
                case.append(dataset_idx[key])
            else:
                for i, s in enumerate(example_dataset):
                    if s.get("user_id") == str(d.get("user_id")) and s.get("sample_id") == d.get("sample_id"):
                        case.append(i)
                        break
        if len(case) >= n_way:
            initial_cases.append(case[:n_way])
            slot_arms.append(arm_idx)
    while len(initial_cases) < queue_size and len(initial_sets) > len(initial_cases):
        arm_idx, ex_set = initial_sets[len(initial_cases)]
        case = []
        for d in ex_set:
            key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
            case.append(dataset_idx.get(key, 0))
        if len(case) < n_way:
            case.extend([0] * (n_way - len(case)))
        initial_cases.append(case[:n_way])
        slot_arms.append(arm_idx)

    ex_queue = Queue(class_indices, queue_size)
    ex_queue._queue = deque(initial_cases[:queue_size], maxlen=queue_size)
    for case in ex_queue._queue:
        ex_queue._register_stats(case)
    slot_arms = slot_arms[:queue_size]

    results = []
    processed = 0
    cumulative_correct = 0

    if use_gt_only:
        # Loop: query gt_scorer(arm_id) for each slot, update MAB, recruit, update queue
        for step in range(gt_steps):
            ex_queue.set_current_time(step)
            queue_cases = list(ex_queue._queue)
            run_results = []
            for slot_i in range(len(slot_arms)):
                arm_idx = slot_arms[slot_i]
                score = gt_scorer(arm_idx)
                response_dict = {"ANSWER": None, "REASON": "gt", "CONFIDENCE": 0.5}
                run_results.append((arm_idx, response_dict, score))

            recruiter.update(run_results)
            new_sets = recruiter.recruit(recruit_size)
            new_cases = []
            new_arm_ids = []
            for arm_idx, ex_set in new_sets:
                case = []
                for d in ex_set:
                    key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
                    case.append(dataset_idx.get(key, 0))
                if len(case) < n_way:
                    case.extend([0] * (n_way - len(case)))
                new_cases.append(case[:n_way])
                new_arm_ids.append(arm_idx)
            ex_queue.update_with_recruiter(run_results, new_cases, recruit_size)
            scores_with_idx = [(i, run_results[i][2]) for i in range(len(run_results))]
            scores_with_idx.sort(key=lambda x: x[1], reverse=True)
            to_evict = {scores_with_idx[j][0] for j in range(min(recruit_size, len(scores_with_idx)))}
            kept = [i for i in range(len(slot_arms)) if i not in to_evict]
            slot_arms = [slot_arms[i] for i in kept] + new_arm_ids[:recruit_size]

            results.append({
                "sample_idx": step,
                "answer": None,
                "ground_truth": None,
                "is_correct": None,
                "gt_step": step,
                "queue_arms": list(slot_arms),
                "step_rewards": [r[2] for r in run_results],
            })
        print(f"[GT Summary] Final queue arms: {slot_arms}")
        return results

    # Model mode: iterate test loader, call LLM, self_certainty, borda_vote
    for step, (metadata, signal) in enumerate(test_loader):
        test_sample = _ppgbp_sample_to_dict(metadata, signal, int(metadata.get("subject_ID", step)))
        ex_queue.set_current_time(processed)
        queue_cases = list(ex_queue._queue)
        run_results = []
        for slot_i, case in enumerate(queue_cases):
            examples = [example_dataset[i] for i in case if i < len(example_dataset)]
            if len(examples) < n_way:
                continue
            messages = _build_prompt(system_msg, task_info, classes_info, test_sample, examples)
            text, logits = await model_pool.invoke(messages)
            score = self_certainty(logits)
            parsed = _parse_answer(text, classes_info)
            response_dict = {"ANSWER": parsed, "REASON": text[:200], "CONFIDENCE": 0.5}
            arm_idx = slot_arms[slot_i] if slot_i < len(slot_arms) else slot_i
            run_results.append((arm_idx, response_dict, score))

        if not run_results:
            processed += 1
            continue
        answer = borda_vote(run_results, valid_classes=classes_info)
        gt = test_sample.get("label", "")
        is_correct = answer == gt
        cumulative_correct += 1 if is_correct else 0
        acc = cumulative_correct / (processed + 1)

        recruiter.update(run_results)
        new_sets = recruiter.recruit(recruit_size)
        new_cases = []
        new_arm_ids = []
        for arm_idx, ex_set in new_sets:
            case = []
            for d in ex_set:
                key = (str(d.get("user_id", "")), d.get("sample_id", d.get("idx", -1)))
                case.append(dataset_idx.get(key, 0))
            if len(case) < n_way:
                case.extend([0] * (n_way - len(case)))
            new_cases.append(case[:n_way])
            new_arm_ids.append(arm_idx)
        ex_queue.update_with_recruiter(run_results, new_cases, recruit_size)
        scores_with_idx = [(i, run_results[i][2]) for i in range(len(run_results))]
        scores_with_idx.sort(key=lambda x: x[1], reverse=True)
        to_evict = {scores_with_idx[j][0] for j in range(min(recruit_size, len(scores_with_idx)))}
        kept = [i for i in range(len(slot_arms)) if i not in to_evict]
        slot_arms = [slot_arms[i] for i in kept] + new_arm_ids[:recruit_size]

        results.append({
            "sample_idx": processed,
            "answer": answer,
            "ground_truth": gt,
            "is_correct": is_correct,
            "cumulative_accuracy": acc,
        })
        processed += 1

    return results
|
||||
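After every step, `run_single_user` drops the `recruit_size` slots with the highest scores and appends the arms that were just recruited, so `slot_arms` stays aligned with the example queue. The sketch below isolates that bookkeeping on plain lists. `update_slots` is a hypothetical name and the numbers are invented; it assumes, as in GT mode where the score is the negated reward, that a higher score means a worse slot.

```python
from typing import List


def update_slots(
    slot_arms: List[int],
    scores: List[float],
    new_arm_ids: List[int],
    recruit_size: int,
) -> List[int]:
    """Evict the recruit_size highest-scoring slots, then append the new arms."""
    # Rank slot positions by score, highest (worst) first; ties keep slot order.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    to_evict = set(ranked[:recruit_size])
    kept = [arm for i, arm in enumerate(slot_arms) if i not in to_evict]
    return kept + new_arm_ids[:recruit_size]


# Hypothetical queue of five slots; scores are negated rewards (closer to -1 is better).
slots = [3, 18, 40, 21, 9]
scores = [-0.003, -0.92, -0.0003, -0.98, -0.09]

updated = update_slots(slots, scores, new_arm_ids=[19, 22], recruit_size=2)
print(updated)
```

Slots at positions 2 and 0 (arms 40 and 3) have the scores closest to zero, so they are evicted, and the queue keeps a constant length as long as the recruiter returns `recruit_size` new arms.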
def _expand_env_in_config(obj: Any) -> Any:
    """Recursively expand $VAR and ${VAR} in config strings."""
    if isinstance(obj, dict):
        return {k: _expand_env_in_config(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_expand_env_in_config(v) for v in obj]
    if isinstance(obj, str):
        return os.path.expandvars(obj)
    return obj
|
||||
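A short usage sketch of the recursive expansion above, applied to a config shaped like the YAML this runner loads. The function is repeated under the hypothetical name `expand_env` so the example runs on its own, and the variable names `DEMO_DATA_ROOT` and `DEMO_UNSET_VAR` are invented for the demonstration.

```python
import os


def expand_env(obj):
    """Recursively apply os.path.expandvars to every string in a nested config."""
    if isinstance(obj, dict):
        return {k: expand_env(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [expand_env(v) for v in obj]
    if isinstance(obj, str):
        return os.path.expandvars(obj)
    return obj


# Hypothetical variables for this example only.
os.environ["DEMO_DATA_ROOT"] = "/scratch/demo"
os.environ.pop("DEMO_UNSET_VAR", None)

config = {
    "data_path": "$DEMO_DATA_ROOT/ppgbp",
    "feature_root": "${DEMO_DATA_ROOT}/index",
    "models": ["${DEMO_UNSET_VAR}/model"],
    "queue_size": 5,
}
expanded = expand_env(config)
print(expanded)
```

`os.path.expandvars` leaves references to unset variables unchanged, so a typo in a variable name yields a path that still contains the literal `${...}` text instead of an empty string. A path check such as `os.path.isdir` then fails on it, which is how the runner ends up skipping a misconfigured `feature_root`.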
def main(config_path: str) -> None:
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    config = _expand_env_in_config(config)
    if not config.get("model_path") and config.get("models"):
        config["model_path"] = config["models"][0] if isinstance(config["models"], list) else config["models"]
    model_path = config.get("model_path") or ""
    if isinstance(model_path, str):
        model_path = model_path.strip()
    use_gt_only = not model_path
    if use_gt_only:
        config["use_gt_only"] = True
        print("[run_recruiter] model_path is null: using ground-truth scorer (arm_id -> score) for MAB convergence")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if "log_path" in config:
        config["log_path"] = f"{config['log_path']}_{timestamp}"

    pool = None
    if not use_gt_only:
        pool = load_models(
            [model_path],
            temperature=config.get("temperature", 0.7),
            max_new_tokens=config.get("max_new_tokens", 256),
        )

    data_path = config["data_path"]
    train_loader, test_loader = get_loaders(data_path, seed=config.get("seed", 42))
    if use_gt_only:
        test_loader = None  # not used in GT loop
    seeds = range(config.get("num_seeds", 1))
    all_results = []
    for seed in seeds:
        res = asyncio.run(run_single_user(config, pool, test_loader, train_loader, seed))
        all_results.extend(res)

    if use_gt_only:
        print(f"GT mode: completed {len(all_results)} steps")
    else:
        correct = sum(1 for r in all_results if r.get("is_correct"))
        total = len(all_results)
        acc = correct / total if total else 0
        print(f"Accuracy: {correct}/{total} = {acc:.4f}")
    log_path = config.get("log_path", "./logs/ppgbp_recruiter")
    os.makedirs(log_path, exist_ok=True)
    with open(os.path.join(log_path, "results.json"), "w") as f:
        json.dump({
            "use_gt_only": use_gt_only,
            "accuracy": (sum(1 for r in all_results if r.get("is_correct")) / len(all_results)) if all_results and not use_gt_only else None,
            "correct": sum(1 for r in all_results if r.get("is_correct")) if not use_gt_only else None,
            "total": len(all_results),
            "results": all_results,
        }, f, indent=2)


if __name__ == "__main__":
    Fire(main)