Compare commits
8 Commits
main
...
self-certa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c4bbbf702d | ||
|
|
ccb88d4eef | ||
|
|
95ddd935f7 | ||
|
|
d10b70dc20 | ||
|
|
3f3db31d25 | ||
|
|
31f4af7106 | ||
|
|
8a26346b5e | ||
|
|
a7c8e43f89 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -3,6 +3,7 @@
|
||||
*.csv
|
||||
*.arrow
|
||||
*.json
|
||||
temp*/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
||||
80
README.md
Normal file
80
README.md
Normal file
@@ -0,0 +1,80 @@
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
conda create -n tsllmpers python=3.12
|
||||
conda activate tsllmpers
|
||||
python -m pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Preprocessing
|
||||
|
||||
Raw data must be preprocessed into HuggingFace `datasets` format before running.
|
||||
A preprocessing script is provided for the SleepEDF dataset:
|
||||
|
||||
```bash
|
||||
python preprocess/preprocess_SleepEDF.py \
|
||||
--path /path/to/SleepEDF/raw/sleep-cassette/ \
|
||||
--out_dir /path/to/output/processed_SleepEDF \
|
||||
--num_workers 32
|
||||
```
|
||||
|
||||
The output directory will have the following structure:
|
||||
|
||||
```
|
||||
processed_SleepEDF/
|
||||
task_metadata.json # task description, class definitions, data/feature info
|
||||
user_metadata.json # per-user metadata (age, sex)
|
||||
00/ # user folder (HuggingFace Dataset saved with save_to_disk)
|
||||
01/
|
||||
...
|
||||
```
|
||||
|
||||
Each user dataset contains the columns: `user_id` (str), `label` (str), `features` (dict of floats), and `data` (dict of raw signal arrays).
|
||||
|
||||
To preprocess a different dataset, write a similar script that produces the same output structure.
|
||||
|
||||
## Running
|
||||
|
||||
```bash
|
||||
python run.py --config_path config/test.yaml
|
||||
```
|
||||
|
||||
## Config
|
||||
|
||||
All configuration is in a single YAML file. Example (`config/test.yaml`):
|
||||
|
||||
```yaml
|
||||
log_path: ./temp
|
||||
data_path: /path/to/data/processed_SleepEDF
|
||||
target_user: "00"
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
model_paths:
|
||||
- ollama:url:hostname:11437/gpt-oss:20b
|
||||
- ollama:url:hostname:11438/gpt-oss:20b
|
||||
vocab_size: 200064
|
||||
```
|
||||
|
||||
| Key | Description |
|
||||
|---|---|
|
||||
| `log_path` | Directory for logs and results (timestamped subfolder created automatically) |
|
||||
| `data_path` | Path to the preprocessed dataset directory |
|
||||
| `target_user` | User ID to evaluate on; all other users become source (ICL examples) |
|
||||
| `queue_size` | Number of example sets to maintain in the queue |
|
||||
| `num_shot` | Number of examples per class in each example set |
|
||||
| `model_paths` | List of Ollama model endpoints in `ollama:url:host:port/model` format |
|
||||
| `vocab_size` | Vocabulary size of the model (used for self-certainty scoring) |
|
||||
|
||||
## Ollama
|
||||
|
||||
The model backend is [Ollama](https://ollama.com). You need one or more Ollama servers running and accessible over HTTP.
|
||||
|
||||
Example scripts for managing multiple Ollama instances on a multi-GPU machine are provided in `utils/` for reference. These are environment-specific -- adapt the ports, model paths, and GPU assignments to your own setup:
|
||||
|
||||
```bash
|
||||
# Reference only -- edit before using
|
||||
bash utils/launch_ollamas.sh # start servers in tmux sessions
|
||||
bash utils/kill_ollamas.sh # stop all servers
|
||||
```
|
||||
|
||||
The model path format in the config is `ollama:url:<host>:<port>/<model_name>`.
|
||||
File diff suppressed because one or more lines are too long
@@ -1,96 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "a0874e1b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Correct: 1038, Total: 2149, Accuracy: 0.4830153559795254\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from glob import glob\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"correct = 0\n",
|
||||
"total = 0\n",
|
||||
"summary_paths = glob(\"/mnt/sting/hjyoon/projects/tsllm_personalization_icl/logs/SleepEDF/*/*/*/summary.txt\")\n",
|
||||
"for summary_path in summary_paths:\n",
|
||||
" with open(summary_path, \"r\") as f:\n",
|
||||
" summary = f.read()\n",
|
||||
" if summary:\n",
|
||||
" answer = summary.split(\"Answer: \")[-1].split(\" (Ground truth: \")[0]\n",
|
||||
" ground_truth = summary.split(\" (Ground truth: \")[-1].split(\")\")[0]\n",
|
||||
" if answer == ground_truth:\n",
|
||||
" correct += 1\n",
|
||||
" total += 1\n",
|
||||
"print(f\"Correct: {correct}, Total: {total}, Accuracy: {correct/total}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "f78ffc6f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Correct: 932, Total: 2260, Accuracy: 0.41238938053097346\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"correct = 0\n",
|
||||
"total = 0\n",
|
||||
"summary_paths = glob(\"/mnt/sting/hjyoon/projects/tsllm_personalization_icl/logs/SleepEDF_out_random/*/*/*/summary.txt\")\n",
|
||||
"for summary_path in summary_paths:\n",
|
||||
" with open(summary_path, \"r\") as f:\n",
|
||||
" summary = f.read()\n",
|
||||
" if summary:\n",
|
||||
" answer = summary.split(\"Answer: \")[-1].split(\" (Ground truth: \")[0]\n",
|
||||
" ground_truth = summary.split(\" (Ground truth: \")[-1].split(\")\")[0]\n",
|
||||
" if answer == ground_truth:\n",
|
||||
" correct += 1\n",
|
||||
" total += 1\n",
|
||||
"print(f\"Correct: {correct}, Total: {total}, Accuracy: {correct/total}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6872381f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "tsllmpers",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,580 +0,0 @@
|
||||
"""
|
||||
Chronos-2 Time Series Embedding Extraction and Visualization Pipeline
|
||||
|
||||
This module provides functionality to:
|
||||
1. Extract embeddings from multivariate time series using Chronos-2 foundation model
|
||||
2. Visualize embeddings using dimensionality reduction (t-SNE)
|
||||
|
||||
Chronos-2 Overview:
|
||||
Chronos-2 is a time series foundation model developed by Amazon that converts
|
||||
time series forecasting into a language modeling task. It tokenizes time series
|
||||
values and generates probabilistic forecasts using a transformer architecture.
|
||||
|
||||
Key Features:
|
||||
- Zero-shot forecasting: Works on unseen time series without fine-tuning
|
||||
- Probabilistic predictions: Outputs quantile forecasts (e.g., 10%, 50%, 90%)
|
||||
- Multivariate support: Can process multiple channels simultaneously
|
||||
|
||||
Embedding Strategy:
|
||||
We use Chronos-2's internal encoder hidden states as embeddings, which is the
|
||||
recommended approach for representation learning. The encoder captures rich
|
||||
temporal patterns through self-attention mechanisms.
|
||||
|
||||
"encoder" : Uses encoder hidden states directly
|
||||
- More informative representation of input characteristics
|
||||
- Captures learned temporal patterns from pre-training
|
||||
|
||||
|
||||
Usage:
|
||||
# Extract embeddings
|
||||
python gen_plot.py extract --data_root /path/to/data --out_dir ./embeddings
|
||||
|
||||
# Visualize with t-SNE
|
||||
python gen_plot.py plot --emb_dir ./embeddings --out_dir ./plots
|
||||
|
||||
Author: Sumin Im (NMSL Undergraduate Researcher)
|
||||
Date: 2026-01-09
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from glob import glob
|
||||
from typing import Dict, Any, List, Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from datasets import load_from_disk, Dataset
|
||||
from chronos import BaseChronosPipeline, Chronos2Pipeline
|
||||
from fire import Fire
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
# Add parent directories to path for importing metadata loader
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
from preprocess.sleepedf_metadata import SleepEDFMetadata, DEFAULT_METADATA_PATH
|
||||
|
||||
# pipeline: Chronos2Pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map="cuda")
|
||||
|
||||
# =============================================================================
|
||||
# Constants
|
||||
# =============================================================================
|
||||
|
||||
# EEG channel names used in Sleep-EDF dataset
# Fpz-Cz: Frontal-Central electrode pair
# Pz-Oz: Parietal-Occipital electrode pair
# These strings are the keys used to read each channel from a sample's "data" dict.
EEG_CHANNEL_1 = "EEG Fpz-Cz"
EEG_CHANNEL_2 = "EEG Pz-Oz"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Embedding Extractor Class
|
||||
# =============================================================================
|
||||
|
||||
class Chronos_2_Embedder:
    """
    Extracts fixed-dimensional embeddings from multivariate time series using Chronos-2.

    The embedder calls ``Chronos2Pipeline.embed`` to obtain encoder hidden
    states, pools them along the token axis and fuses the per-variate vectors
    into a single embedding per sample.

    Pooling strategies:
        - mean: average every token except the first and the last one
        - cls: use the first token only

    Variate fusion:
        - concat: concatenate the pooled vector of every variate
        - mean: average the pooled vectors across variates

    NOTE(review): the original notes describe the hidden states as
    (n_variates, num_patches + 2, d_model) with d_model=512, the first token
    being [REG] and the last one the masked output patch -- confirm against
    the Chronos-2 version in use.

    Attributes:
        pipeline: Chronos-2 model pipeline used for inference
        pooling_strategy: How to pool encoder hidden states ("mean" or "cls")
        variate_fusion: How to combine variates ("concat" or "mean")
        device_map: Device string handed to ``from_pretrained``
        device: ``torch.device`` derived from ``device_map``
    """

    # Accepted option values; used for validation in the constructor.
    _POOLING_STRATEGIES = ("mean", "cls")
    _VARIATE_FUSIONS = ("concat", "mean")

    def __init__(
        self,
        model_name: str = "amazon/chronos-2",
        device_map: Optional[str] = None,
        pooling_strategy: str = "mean",
        variate_fusion: str = "concat",
    ):
        """
        Initialize the Chronos-2 embedder.

        Args:
            model_name: HuggingFace model name or local path
            device_map: Device placement ("cuda", "cpu", or None to pick
                "cuda" when available and "cpu" otherwise)
            pooling_strategy: "mean" (average patch tokens) or "cls" (first token only)
            variate_fusion: "concat" (keep every variate) or "mean" (average variates)

        Raises:
            ValueError: If ``pooling_strategy`` or ``variate_fusion`` is not
                one of the supported values.
        """
        # Reject unknown strategies up front; otherwise any typo would
        # silently fall through to mean pooling in compute_embedding().
        if pooling_strategy not in self._POOLING_STRATEGIES:
            raise ValueError(
                f"Invalid pooling strategy: {pooling_strategy}. Use 'mean' or 'cls'."
            )
        if variate_fusion not in self._VARIATE_FUSIONS:
            raise ValueError(
                f"Invalid variate_fusion: {variate_fusion}. Use 'concat' (1024) or 'mean' (512)."
            )
        self.pooling_strategy = pooling_strategy
        self.variate_fusion = variate_fusion

        if device_map is None:
            device_map = "cuda" if torch.cuda.is_available() else "cpu"
        self.device_map = device_map
        self.device = torch.device(
            "cuda" if device_map.startswith("cuda") and torch.cuda.is_available() else "cpu"
        )

        # Load pre-trained Chronos-2 model
        print(f"[INFO] Loading Chronos-2 model: {model_name}")
        print(f"[INFO] Device: {device_map}")
        print(f"[INFO] Pooling strategy: {pooling_strategy}")
        print(f"[INFO] Variate fusion: {variate_fusion}")
        self.pipeline: Chronos2Pipeline = Chronos2Pipeline.from_pretrained(
            model_name,
            device_map=device_map
        )

    @staticmethod
    def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
        """
        Discover all user/session directories under data_root.

        Expected structure: data_root/user_id/session_id/

        Args:
            data_root: Directory whose second-level sub-directories are sessions

        Returns:
            List of (user_id, session_id, session_path) tuples, sorted by path.
            Non-directory entries at the session level are ignored.
        """
        discovered_paths = []
        for session_path in sorted(glob(os.path.join(data_root, "*", "*"))):
            if not os.path.isdir(session_path):
                continue
            # The last two path components are <user_id>/<session_id>.
            user_dir, session_id = os.path.split(session_path)
            discovered_paths.append((os.path.basename(user_dir), session_id, session_path))
        return discovered_paths

    @torch.no_grad()
    def compute_embedding(self, batch: Dict[str, Any]) -> np.ndarray:
        """
        Generate embedding vectors using Chronos-2's internal encoder hidden states.

        Args:
            batch: HuggingFace dataset batch from slicing (dataset[start:end]).
                ``batch["data"]`` must be a sequence of dicts holding one
                array per EEG channel, e.g.
                {"data": [{"EEG Fpz-Cz": [...], "EEG Pz-Oz": [...]}, ...]}

        Returns:
            float32 array of shape (batch_size, embedding_dim). With
            variate_fusion='concat' the per-variate vectors are concatenated
            (2 * d_model for the two channels); with 'mean' they are averaged
            (d_model).
        """
        samples = batch["data"]

        # Build the (batch, n_variates, seq_length) input directly: each
        # channel is stacked to (B, L) and the channels are joined on axis 1.
        x_input = np.stack(
            [
                np.stack([np.asarray(sample[channel], dtype=np.float32) for sample in samples])
                for channel in (EEG_CHANNEL_1, EEG_CHANNEL_2)
            ],
            axis=1,
        )

        # embed() returns (embeddings, loc_scale); the normalization
        # statistics are not needed for representation extraction.
        embeddings_list, _ = self.pipeline.embed(x_input)

        use_cls = self.pooling_strategy == "cls"
        fuse_by_mean = self.variate_fusion == "mean"

        pooled_rows = []
        for emb in embeddings_list:
            # emb: one tensor per sample, indexed (variate, token, feature)
            if use_cls:
                # First token only (described as the [REG] token in the class notes)
                pooled = emb[:, 0, :]
            else:
                # Average the tokens between the first and the last one
                pooled = emb[:, 1:-1, :].mean(dim=1)

            if fuse_by_mean:
                fused = pooled.mean(dim=0)
            else:
                fused = pooled.reshape(-1)

            # .float() keeps this working when the model runs in bf16/fp16,
            # which numpy cannot represent (bf16) or would not keep as float32.
            pooled_rows.append(fused.float().cpu().numpy())

        return np.stack(pooled_rows, axis=0).astype(np.float32)

    @staticmethod
    def _lookup_demographics(metadata: Any, user_id: str) -> Tuple[Any, Any]:
        """Return (gender, age) for user_id, or ('Unknown', -1) when unavailable."""
        if metadata is not None:
            info = metadata.get_info(user_id)
            if info:
                return info['gender'], info['age']
        return 'Unknown', -1

    def extract_embeddings(
        self,
        data_root: str,
        batch_size: int = 32,
        metadata_path: str = DEFAULT_METADATA_PATH,
    ) -> Dataset:
        """
        Extract embeddings from all sessions under the data root directory.

        Iterates through all user/session combinations, processes time series
        in batches, and aggregates results with metadata.

        Args:
            data_root: Root directory containing user/session subfolders
            batch_size: Number of samples embedded per model call (must be > 0)
            metadata_path: Path to the subject metadata file used for gender/age

        Returns:
            HuggingFace Dataset with columns:
                - user_id, session_id, idx, label (metadata)
                - gender, age ('Unknown' / -1 when no metadata is available)
                - embedding (vector; size depends on variate_fusion)

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
        """
        if batch_size <= 0:
            # range() would raise on 0 and silently skip everything on negatives.
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        session_paths = self.discover_session_paths(data_root)
        print(f"[INFO] Discovered {len(session_paths)} sessions")

        # Load metadata for gender/age information; extraction continues without it.
        try:
            metadata = SleepEDFMetadata(metadata_path)
        except FileNotFoundError:
            print(f"[WARNING] Metadata file not found: {metadata_path}")
            print("[WARNING] Gender/age information will not be available")
            metadata = None
        has_metadata = metadata is not None

        all_embeddings = []
        all_user_ids = []
        all_session_ids = []
        all_idxs = []
        all_labels = []
        all_genders = []
        all_ages = []

        for user_id, session_id, session_path in session_paths:
            # Load HuggingFace dataset from disk; the fixed seed keeps the
            # shuffled row order reproducible between runs.
            dataset = load_from_disk(session_path).shuffle(seed=0)
            num_samples = len(dataset)
            print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")

            # Process in batches to manage memory
            for batch_start in range(0, num_samples, batch_size):
                batch_end = min(batch_start + batch_size, num_samples)
                batch = dataset[batch_start:batch_end]

                embeddings = self.compute_embedding(batch)
                count = embeddings.shape[0]

                batch_user_ids = [str(u) for u in batch["user_id"][:count]]
                all_embeddings.extend(embeddings.tolist())
                all_user_ids.extend(batch_user_ids)
                all_session_ids.extend(str(s) for s in batch["session_id"][:count])
                all_idxs.extend(int(i) for i in batch["idx"][:count])
                all_labels.extend(str(lab) for lab in batch["label"][:count])

                # Add demographic metadata (per sample, keyed by the row's own user_id)
                for uid in batch_user_ids:
                    gender, age = self._lookup_demographics(metadata, uid)
                    all_genders.append(gender)
                    all_ages.append(age)

        # Create HuggingFace Dataset
        result_dataset = Dataset.from_dict({
            "user_id": all_user_ids,
            "session_id": all_session_ids,
            "idx": all_idxs,
            "label": all_labels,
            "gender": all_genders,
            "age": all_ages,
            "embedding": all_embeddings,
        })

        print(f"[INFO] Total samples: {len(result_dataset)}")
        if has_metadata:
            gender_counts = {}
            for g in all_genders:
                gender_counts[g] = gender_counts.get(g, 0) + 1
            print(f"[INFO] Gender distribution in extracted data: {gender_counts}")
        return result_dataset

    def save_embeddings(
        self,
        dataset: Dataset,
        output_dir: str
    ) -> None:
        """
        Save the embeddings dataset to ``output_dir`` and print a short summary.

        Args:
            dataset: Dataset holding an "embedding" column
            output_dir: Target directory (created when missing)
        """
        os.makedirs(output_dir, exist_ok=True)
        dataset.save_to_disk(output_dir)

        # An empty dataset has no first row to inspect; report dimension 0
        # instead of failing after the data was already written.
        embedding_dim = len(dataset[0]['embedding']) if len(dataset) > 0 else 0

        print(f"[DONE] Saved embeddings dataset: {output_dir}")
        print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {embedding_dim}")

    @staticmethod
    def load_embeddings(embedding_dir: str) -> Dataset:
        """Load a dataset previously written with ``save_embeddings``."""
        dataset = load_from_disk(embedding_dir)
        print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
        return dataset
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Visualization Functions
|
||||
# =============================================================================
|
||||
|
||||
def reduce_to_2d_tsne(
    embeddings: np.ndarray,
    perplexity: float = 30.0
) -> np.ndarray:
    """
    Reduce high-dimensional embeddings to 2D using t-SNE.

    Args:
        embeddings: Array of shape (num_samples, embedding_dim)
        perplexity: t-SNE perplexity parameter. scikit-learn requires it to
            be smaller than the number of samples, so a larger value is
            lowered to ``num_samples - 1`` (a warning is printed).

    Returns:
        2D coordinates of shape (num_samples, 2)

    Raises:
        ValueError: If ``embeddings`` is not 2-dimensional or holds fewer
            than two samples.
    """
    embeddings = np.asarray(embeddings)
    if embeddings.ndim != 2:
        raise ValueError(
            f"embeddings must be a 2D array (num_samples, embedding_dim), got shape {embeddings.shape}"
        )

    num_samples = embeddings.shape[0]
    if num_samples < 2:
        raise ValueError(f"t-SNE needs at least 2 samples, got {num_samples}")

    # Small (e.g. heavily filtered) datasets would otherwise abort inside
    # scikit-learn because perplexity must stay below the sample count.
    if perplexity >= num_samples:
        adjusted = float(num_samples - 1)
        print(
            f"[WARNING] perplexity={perplexity} is not smaller than the number of "
            f"samples ({num_samples}); using {adjusted} instead"
        )
        perplexity = adjusted

    print(f"[INFO] Running t-SNE with perplexity={perplexity}...")

    tsne = TSNE(
        n_components=2,
        random_state=0,  # For reproducibility
        perplexity=perplexity,
        max_iter=1000,
        init='random',
        learning_rate='auto',
    )

    return tsne.fit_transform(embeddings)
|
||||
|
||||
|
||||
def create_scatter_plot(
    coordinates: np.ndarray,
    labels: np.ndarray,
    title: str,
    output_path: str
) -> None:
    """
    Create and save a 2D scatter plot with categorical coloring.

    Args:
        coordinates: 2D array of shape (num_points, 2)
        labels: Category label for each point (array-like of length num_points)
        title: Plot title
        output_path: File path to save the plot; the file is always written
            in PDF format regardless of the extension
    """
    # Accept plain lists as well: the boolean mask below needs ndarrays.
    coordinates = np.asarray(coordinates)
    labels = np.asarray(labels)

    # Get unique labels for legend
    unique_labels = sorted(set(labels.tolist()))

    # tab10 offers 10 distinct colors, tab20 offers 20; beyond 20 categories
    # the colors repeat.
    colormap = plt.cm.tab10 if len(unique_labels) <= 10 else plt.cm.tab20

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Plot each category separately so that each gets its own legend entry
        for idx, label in enumerate(unique_labels):
            mask = labels == label
            ax.scatter(
                coordinates[mask, 0],
                coordinates[mask, 1],
                c=[colormap(idx % 20)],
                s=15,
                label=label,
                alpha=0.7,
            )

        # Configure plot appearance
        ax.set_title(title)
        ax.set_xlabel("Dimension 1")
        ax.set_ylabel("Dimension 2")
        if unique_labels:
            ax.legend(loc='best', markerscale=2)

        # Save through the figure object instead of pyplot's implicit
        # "current figure" so that other open figures cannot interfere.
        fig.tight_layout()
        fig.savefig(output_path, format='pdf', bbox_inches='tight')
    finally:
        # Always release the figure, even when saving fails.
        plt.close(fig)

    print(f"[DONE] Saved plot: {output_path}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Command Line Interface
|
||||
# =============================================================================
|
||||
|
||||
def _split_csv_arg(value: Any) -> List[str]:
    """
    Normalize a comma-separated CLI argument into a list of stripped strings.

    python-fire may parse unquoted values before they reach the command
    (e.g. ``W,N1,N2`` can arrive as a tuple and ``10`` as an int), so
    sequences and scalars are accepted in addition to plain strings.
    Empty items are dropped; ``None`` yields an empty list.
    """
    if value is None:
        return []
    if isinstance(value, str):
        parts = value.split(",")
    elif isinstance(value, (list, tuple, set)):
        parts = [str(item) for item in value]
    else:
        parts = [str(value)]
    return [part.strip() for part in parts if part.strip()]


class CLI:
    """Command line entry points: ``extract`` embeddings and ``plot`` them."""

    def extract(
        self,
        data_root: str,
        out_dir: str,
        model: str = "amazon/chronos-2",
        batch_size: int = 32,
        pooling: str = "mean",
        variate_fusion: str = "concat",
    ) -> None:
        """
        Extract embeddings from time series data.

        Args:
            data_root: Root directory containing user/session data folders
            out_dir: Output directory for HuggingFace dataset
            model: Chronos-2 model name or path (default: amazon/chronos-2)
            batch_size: Batch size for inference (default: 32)
            pooling: Pooling strategy - 'mean' or 'cls' (default: mean)
            variate_fusion: How to combine channels - 'concat' (1024) or 'mean' (512)

        Raises:
            ValueError: If ``pooling`` or ``variate_fusion`` is not supported.
        """
        # Validate before the (slow) model load so that typos fail fast.
        if pooling not in ["mean", "cls"]:
            raise ValueError(f"Invalid pooling strategy: {pooling}. Use 'mean' or 'cls'.")
        if variate_fusion not in ["concat", "mean"]:
            raise ValueError(
                f"Invalid variate_fusion: {variate_fusion}. Use 'concat' (1024) or 'mean' (512)."
            )

        # Initialize embedder
        embedder = Chronos_2_Embedder(
            model_name=model,
            pooling_strategy=pooling,
            variate_fusion=variate_fusion,
        )

        # Extract and save embeddings
        dataset = embedder.extract_embeddings(data_root, batch_size)
        embedder.save_embeddings(dataset, out_dir)

    def plot(
        self,
        emb_dir: str,
        out_dir: str,
        perplexity: float = 30.0,
        users: Optional[str] = None,
        num_users: int = 0,
        labels: Optional[str] = None,
        gender: Optional[str] = None,
        metadata_path: str = DEFAULT_METADATA_PATH,
    ) -> None:
        """
        Visualize embeddings with t-SNE.

        Writes ``tsne_by_label.pdf`` and ``tsne_by_user.pdf`` into ``out_dir``.

        Args:
            emb_dir: Directory containing the HuggingFace embeddings dataset
            out_dir: Output directory for visualization plots (PDF)
            perplexity: t-SNE perplexity parameter (default: 30.0)
            users: Comma-separated user IDs to include (e.g., '00,01,02').
                NOTE(review): purely numeric values may be converted to
                numbers by python-fire and lose leading zeros; quote the
                argument if IDs are zero-padded.
            num_users: Include only first N users, 0 = all (default: 0).
                Ignored when ``users`` is given.
            labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
            gender: Filter by gender ('M', 'F', or None for all)
            metadata_path: Kept for interface compatibility; not used by
                this command at present.

        Raises:
            ValueError: If no sample is left after filtering.
        """
        os.makedirs(out_dir, exist_ok=True)

        # Load saved embeddings dataset
        dataset = Chronos_2_Embedder.load_embeddings(emb_dir)

        # Apply user filtering (explicit list wins over num_users)
        user_list = _split_csv_arg(users)
        if user_list:
            dataset = dataset.filter(lambda x: x["user_id"] in user_list)
            print(f"[INFO] Filtered to users: {user_list}")
        elif num_users > 0:
            all_users = sorted(set(dataset["user_id"]))
            selected_users = all_users[:num_users]
            dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
            print(f"[INFO] Selected first {num_users} users: {selected_users}")

        # Filter by gender
        if gender:
            dataset = dataset.filter(lambda x: x.get("gender", "Unknown") == gender)
            print(f"[INFO] Filtered to gender: {gender}")

        # Filter by sleep stage labels
        label_list = _split_csv_arg(labels)
        if label_list:
            dataset = dataset.filter(lambda x: x["label"] in label_list)
            print(f"[INFO] Filtered to labels: {label_list}")

        print(f"[INFO] Total samples: {len(dataset)}")

        # Fail with an actionable message instead of an opaque t-SNE error.
        if len(dataset) == 0:
            raise ValueError(
                "No samples left after filtering; check the users/num_users/labels/gender arguments."
            )

        # Print gender distribution
        if "gender" in dataset.column_names:
            gender_counts = {}
            for g in dataset["gender"]:
                gender_counts[g] = gender_counts.get(g, 0) + 1
            print(f"[INFO] Gender distribution: {gender_counts}")

        # Extract embeddings as numpy array for t-SNE
        embeddings = np.array(dataset["embedding"])

        # Reduce to 2D with t-SNE; both plots share the same coordinates
        coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)

        # Generate visualizations
        create_scatter_plot(
            coordinates_2d,
            np.array(dataset["label"]),
            "t-SNE Visualization (Colored by Sleep Stage)",
            os.path.join(out_dir, "tsne_by_label.pdf")
        )

        create_scatter_plot(
            coordinates_2d,
            np.array(dataset["user_id"]),
            "t-SNE Visualization (Colored by User ID)",
            os.path.join(out_dir, "tsne_by_user.pdf")
        )
|
||||
|
||||
|
||||
# Expose the CLI class through python-fire when the file is run as a script.
if __name__ == "__main__":
    Fire(CLI)
|
||||
@@ -1,507 +0,0 @@
|
||||
"""
|
||||
User Classification from Time Series Embeddings
|
||||
|
||||
This module evaluates how well time series embeddings capture user-specific patterns
|
||||
by training a simple classifier (a Random Forest) to predict user identity from embeddings.
|
||||
|
||||
Motivation:
|
||||
If embeddings contain user-distinguishing information, a classifier should be able
|
||||
to predict which user a time series belongs to. High accuracy suggests that the
|
||||
embeddings capture individual characteristics.
|
||||
|
||||
Experimental Design:
|
||||
- Task: Multi-class classification
|
||||
- Model: Random Forest with ensemble of decision trees
|
||||
- Captures non-linear relationships in embedding space
|
||||
- Provides feature importance scores for interpretability
|
||||
- Robust to overfitting through bagging and random feature selection
|
||||
- Split Strategy:
|
||||
1. Session-based: Train on session 1, test on session 2
|
||||
2. Random: Standard train/test split
|
||||
|
||||
Session-based split is more challenging and realistic because:
|
||||
- Tests whether user patterns are stable across different recording sessions
|
||||
- Avoids data leakage from same-session samples in train and test
|
||||
|
||||
Random Forest Advantages over Logistic Regression:
|
||||
- Handles non-linear decision boundaries
|
||||
- Feature importance reveals which embedding dimensions matter most
|
||||
- No assumption about data distribution
|
||||
- Naturally handles multi-class classification
|
||||
|
||||
Output:
|
||||
- Classification metrics
|
||||
- Confusion matrix visualization
|
||||
|
||||
Usage:
|
||||
python simple_user_classifier.py \\
|
||||
--embeddings ./embeddings \\
|
||||
--out_dir ./results
|
||||
|
||||
Author: Sumin Im (NMSL Undergraduate Researcher)
|
||||
Date: 2026-01-09
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from datasets import load_from_disk, Dataset
|
||||
from fire import Fire
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import (
|
||||
accuracy_score,
|
||||
f1_score,
|
||||
classification_report,
|
||||
confusion_matrix,
|
||||
silhouette_score,
|
||||
)
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data Loading
|
||||
# =============================================================================
|
||||
|
||||
def load_embeddings_with_metadata(embedding_path: str) -> Dataset:
    """
    Load a HuggingFace dataset stored in the directory ``embedding_path``.

    Args:
        embedding_path: Directory previously written with ``save_to_disk``

    Returns:
        The dataset returned by ``load_from_disk``

    Raises:
        FileNotFoundError: If ``embedding_path`` is not an existing directory.
    """
    if os.path.isdir(embedding_path):
        return load_from_disk(embedding_path)

    raise FileNotFoundError(
        f"Dataset directory not found: {embedding_path}. "
        "Ensure this is a valid HuggingFace dataset directory."
    )
|
||||
|
||||
# =============================================================================
|
||||
# Data Splitting
|
||||
# =============================================================================
|
||||
|
||||
def split_by_session(
    features: np.ndarray,
    labels: np.ndarray,
    session_ids: np.ndarray,
    train_session: str = "1",
    test_session: str = "2",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
    """
    Split data by recording session for temporal generalization evaluation.

    Samples whose session id equals ``train_session`` form the training set and
    samples whose session id equals ``test_session`` form the test set. Samples
    from any other session are dropped.

    Session ids and the two selectors are compared as strings, so integer ids
    (e.g. ``1``) and string ids (e.g. ``"1"``) select the same samples.

    Args:
        features: Feature matrix of shape (num_samples, num_features)
        labels: Label array of shape (num_samples,)
        session_ids: Session identifier for each sample (any array-like)
        train_session: Session ID to use for training (default: "1")
        test_session: Session ID to use for testing (default: "2")

    Returns:
        Tuple of (X_train, X_test, y_train, y_test, split_description).
        Either split may be empty if its session is absent from the data.
    """
    # Normalise to strings so an int-typed id column (or an int selector coming
    # from a CLI parser) does not silently produce all-False masks.
    session_keys = np.asarray(session_ids).astype(str)

    # Create boolean masks for train and test sets
    train_mask = session_keys == str(train_session)
    test_mask = session_keys == str(test_session)

    # Apply masks to create train/test splits
    X_train = features[train_mask]
    X_test = features[test_mask]
    y_train = labels[train_mask]
    y_test = labels[test_mask]

    split_description = f"session({train_session}->train, {test_session}->test)"

    return X_train, X_test, y_train, y_test, split_description
|
||||
|
||||
|
||||
def split_random(
    features: np.ndarray,
    labels: np.ndarray,
    test_size: float = 0.2,
    random_state: int = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
    """
    Randomly partition the samples into train and test sets, stratified by label.

    The labels are passed as the ``stratify`` argument of ``train_test_split``
    so that the split is stratified on them.

    Args:
        features: Feature matrix of shape (num_samples, num_features)
        labels: Label array of shape (num_samples,)
        test_size: Fraction of data to use for testing (default: 0.2)
        random_state: Random seed for reproducibility

    Returns:
        Tuple of (X_train, X_test, y_train, y_test, split_description), where
        the description is always the string "random".
    """
    split_options = {
        "test_size": test_size,
        "random_state": random_state,
        "stratify": labels,
    }
    train_x, test_x, train_y, test_y = train_test_split(
        features, labels, **split_options
    )
    return train_x, test_x, train_y, test_y, "random"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model Training and Evaluation
|
||||
# =============================================================================
|
||||
|
||||
def create_classifier_pipeline(
    n_estimators: int = 200,
    max_depth: int = None,
    min_samples_split: int = 2,
    min_samples_leaf: int = 1,
    random_state: int = 0,
) -> Pipeline:
    """
    Build the two-step scikit-learn pipeline used for user classification.

    Steps:
        1. "scaler": ``StandardScaler`` with centering and unit-variance
           scaling both enabled.
        2. "classifier": ``RandomForestClassifier`` configured from the
           arguments below, with ``n_jobs=-1``, ``class_weight="balanced"``
           and ``oob_score=True`` fixed.

    Args:
        n_estimators: Number of trees (default: 200)
        max_depth: Maximum depth of trees (default: None, no depth limit)
        min_samples_split: Min samples for splitting (default: 2)
        min_samples_leaf: Min samples at leaf (default: 1)
        random_state: Random seed passed to the forest

    Returns:
        Unfitted sklearn Pipeline ready for .fit() and .predict()
    """
    # Step 1: z-score normalisation (subtract mean, divide by std).
    scaler = StandardScaler(with_mean=True, with_std=True)

    # Step 2: the forest itself; the fixed options are kept together here.
    forest = RandomForestClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        n_jobs=-1,                # all available cores
        random_state=random_state,
        class_weight="balanced",  # reweight classes by inverse frequency
        oob_score=True,           # expose oob_score_ after fitting
    )

    return Pipeline([("scaler", scaler), ("classifier", forest)])
|
||||
|
||||
|
||||
def evaluate_classifier(
    y_true: np.ndarray,
    y_pred: np.ndarray,
) -> Tuple[float, float, str, np.ndarray, list]:
    """
    Score predictions against ground truth.

    Computes accuracy, macro-averaged F1, a textual per-class report
    (4 digits) and a confusion matrix whose rows/columns follow the sorted
    union of the labels seen in ``y_true`` and ``y_pred``.

    Args:
        y_true: Ground truth labels
        y_pred: Predicted labels

    Returns:
        Tuple of (accuracy, f1, classification_report, confusion_matrix, class_labels)
    """
    # Label order for the confusion matrix: every class that appears on either side.
    class_labels = sorted(set(y_true).union(set(y_pred)))
    matrix = confusion_matrix(y_true, y_pred, labels=class_labels)

    # Textual per-class breakdown.
    per_class_report = classification_report(y_true, y_pred, digits=4)

    # Scalar summaries.
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    overall_accuracy = accuracy_score(y_true, y_pred)

    return overall_accuracy, macro_f1, per_class_report, matrix, class_labels
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Visualization
|
||||
# =============================================================================
|
||||
|
||||
def save_confusion_matrix_plot(
    confusion_mat: np.ndarray,
    class_labels: list,
    output_path: str,
) -> None:
    """
    Create and save a confusion matrix heatmap visualization.

    The confusion matrix shows:
    - Rows: True class labels
    - Columns: Predicted class labels
    - Cell values: Count of samples with that (true, predicted) combination
    - Diagonal: Correct predictions
    - Off-diagonal: Misclassifications

    Both axes are ticked with ``class_labels`` so each row and column can be
    traced back to a user. The labels are only applied when their count
    matches the matrix size; otherwise the default numeric ticks are kept.

    Args:
        confusion_mat: Square matrix of shape (num_classes, num_classes)
        class_labels: List of class names for axis labels, in matrix order
        output_path: File path to save the plot (written in PDF format)
    """
    # Create figure
    fig, ax = plt.subplots(figsize=(10, 8))

    # Plot heatmap using imshow
    im = ax.imshow(confusion_mat, aspect="auto", cmap="Blues")

    # Add colorbar to show value scale
    fig.colorbar(im, ax=ax)

    # Label the ticks with the class names (previously the argument was ignored).
    if class_labels is not None and len(class_labels) == len(confusion_mat):
        tick_positions = np.arange(len(class_labels))
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(class_labels, rotation=90, fontsize=8)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(class_labels, fontsize=8)

    # Set labels and title
    ax.set_title("Confusion Matrix (User Classification)")
    ax.set_xlabel("Predicted User")
    ax.set_ylabel("True User")

    fig.tight_layout()
    fig.savefig(output_path, format="pdf", bbox_inches="tight")
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    print(f"[DONE] Saved confusion matrix: {output_path}")
|
||||
|
||||
|
||||
def save_metrics_report(
    output_path: str,
    split_description: str,
    train_size: int,
    test_size: int,
    accuracy: float,
    macro_f1: float,
    classification_report_str: str,
    oob_score: float = None,
    silhouette: float = None,
) -> None:
    """
    Write a plain-text metrics summary followed by the per-class report.

    Line order: split, train size, test size, silhouette (if given),
    accuracy, macro F1, OOB score (if given), a blank line, then
    ``classification_report_str`` verbatim.

    Args:
        output_path: Destination text file (overwritten, UTF-8)
        split_description: Human-readable name of the split that was used
        train_size: Number of training samples
        test_size: Number of test samples
        accuracy: Accuracy on the test set
        macro_f1: Macro-averaged F1 on the test set
        classification_report_str: Pre-rendered per-class report
        oob_score: Optional out-of-bag score; omitted from the file when None
        silhouette: Optional silhouette score; omitted from the file when None
    """
    # Header block that is always present.
    report_lines = [
        f"split_used : {split_description}\n",
        f"train_size : {train_size}\n",
        f"test_size : {test_size}\n",
    ]

    # Optional embedding-quality metric sits before the classifier metrics.
    if silhouette is not None:
        report_lines.append(f"silhouette : {silhouette:.4f}\n")

    report_lines.append(f"accuracy : {accuracy:.4f}\n")
    report_lines.append(f"macro_f1 : {macro_f1:.4f}\n")

    # Optional Random Forest specific metric.
    if oob_score is not None:
        report_lines.append(f"oob_score : {oob_score:.4f}\n")

    # Blank separator, then the detailed per-class report.
    report_lines.extend(["\n", classification_report_str])

    with open(output_path, "w", encoding="utf-8") as handle:
        handle.writelines(report_lines)

    print(f"[DONE] Saved metrics report: {output_path}")
|
||||
|
||||
|
||||
def save_feature_importance_plot(
    feature_importances: np.ndarray,
    output_path: str,
    top_k: int = 50,
) -> None:
    """
    Save a horizontal bar chart of the most important embedding dimensions.

    Dimensions are ranked by importance (descending) and at most ``top_k`` of
    them are drawn, most important at the top. Bars are labelled ``emb_<i>``
    where ``i`` is the dimension's index in ``feature_importances``.

    Args:
        feature_importances: 1-D array-like with one importance per dimension
        output_path: File path to save the plot (written in PDF format)
        top_k: Maximum number of dimensions to show (default: 50)
    """
    # Accept plain sequences as well as arrays; fancy indexing below needs an ndarray.
    feature_importances = np.asarray(feature_importances)

    # Get indices of top-k most important features
    top_indices = np.argsort(feature_importances)[::-1][:top_k]
    top_importances = feature_importances[top_indices]
    # Fewer than top_k dimensions may exist; report the number actually drawn.
    shown = len(top_indices)

    # Create figure
    fig, ax = plt.subplots(figsize=(12, 8))

    # Create horizontal bar chart
    y_positions = np.arange(shown)
    ax.barh(y_positions, top_importances, color="steelblue", alpha=0.8)

    # Set labels
    ax.set_yticks(y_positions)
    ax.set_yticklabels([f"emb_{i}" for i in top_indices], fontsize=8)
    ax.invert_yaxis()  # Highest importance at top

    ax.set_xlabel("Feature Importance (Mean Decrease in Impurity)")
    ax.set_ylabel("Embedding Dimension")
    ax.set_title(f"Top {shown} Most Important Embedding Dimensions")

    # Save as vector PDF
    fig.tight_layout()
    fig.savefig(output_path, format="pdf", bbox_inches="tight")
    plt.close(fig)

    print(f"[DONE] Saved feature importance plot: {output_path}")
|
||||
|
||||
|
||||
def save_feature_importance_csv(
    feature_importances: np.ndarray,
    output_path: str,
) -> None:
    """
    Write all feature importances to a CSV, ranked from most to least important.

    The file has the columns ``rank`` (1-based), ``feature`` (``emb_<i>`` with
    ``i`` the index in ``feature_importances``) and ``importance``.

    Args:
        feature_importances: 1-D array with one importance per dimension
        output_path: Destination CSV path (written without the index column)
    """
    dimension_names = [f"emb_{dim}" for dim in range(len(feature_importances))]

    ranked = (
        pd.DataFrame({"feature": dimension_names, "importance": feature_importances})
        .sort_values("importance", ascending=False)
        .reset_index(drop=True)
    )
    # Put the 1-based rank in front so the column order is rank, feature, importance.
    ranked.insert(0, "rank", range(1, len(ranked) + 1))
    ranked.to_csv(output_path, index=False)

    print(f"[DONE] Saved feature importance CSV: {output_path}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Command Line Interface
|
||||
# =============================================================================
|
||||
|
||||
def main(
    embeddings: str,
    out_dir: str,
    split_mode: str = "session",
    test_size: float = 0.2,
    n_estimators: int = 200,
    max_depth: int = None,
    labels: str = None,
) -> None:
    """
    Train and evaluate a Random Forest that predicts the user from an embedding.

    Loads the embeddings dataset, optionally keeps only the requested labels,
    splits into train/test, fits the classifier pipeline and writes a metrics
    report, a confusion matrix plot and feature-importance artefacts to
    ``out_dir``.

    Args:
        embeddings: Directory of the saved HuggingFace embeddings dataset. It
            must provide the columns "embedding", "user_id" and "session_id"
            (and "label" when ``labels`` is given).
        out_dir: Output directory (created if missing)
        split_mode: "session" (train on session "1", test on session "2",
            falling back to a random split if either is absent) or "random"
        test_size: Test fraction for the random split (default: 0.2)
        n_estimators: Number of trees in the forest (default: 200)
        max_depth: Maximum tree depth (default: None, unlimited)
        labels: Comma-separated labels to keep, e.g. "W,N1"; a list/tuple of
            labels is accepted as well. None keeps everything.

    Raises:
        ValueError: If ``split_mode`` is invalid or no samples remain after
            filtering.
    """
    # Validate split_mode argument
    if split_mode not in ["session", "random"]:
        raise ValueError(f"Invalid split_mode: {split_mode}. Use 'session' or 'random'.")

    # Create output directory
    os.makedirs(out_dir, exist_ok=True)
    print(f"[INFO] Loading embeddings from: {embeddings}")
    dataset = load_embeddings_with_metadata(embeddings)

    # Filter by sleep stage labels if specified
    if labels:
        # A command-line layer may already have split the value into a
        # list/tuple; accept that as well as the plain comma-separated string.
        raw_labels = labels if isinstance(labels, (list, tuple)) else str(labels).split(",")
        label_list = [str(item).strip() for item in raw_labels]
        wanted_labels = set(label_list)
        dataset = dataset.filter(lambda x: str(x["label"]) in wanted_labels)
        print(f"[INFO] Filtered to labels: {label_list}")

    # Fail early with a readable message instead of an IndexError on X.shape[1].
    if len(dataset) == 0:
        raise ValueError(
            f"No samples to classify in {embeddings} "
            f"(label filter: {labels}). Check the dataset path and --labels."
        )

    X = np.array(dataset["embedding"], dtype=np.float32)
    y = np.array([str(uid) for uid in dataset["user_id"]])
    session_ids = np.array([str(sid) for sid in dataset["session_id"]])

    num_users = len(np.unique(y))
    print(f"[INFO] Loaded {len(y)} samples with {X.shape[1]}-dimensional embeddings")
    print(f"[INFO] Number of unique users: {num_users}")

    if split_mode == "session":
        has_session_1 = (session_ids == "1").sum() > 0
        has_session_2 = (session_ids == "2").sum() > 0

        if has_session_1 and has_session_2:
            X_train, X_test, y_train, y_test, split_desc = split_by_session(
                X, y, session_ids
            )
        else:
            print("[WARN] Missing session data, falling back to random split")
            X_train, X_test, y_train, y_test, split_desc = split_random(
                X, y, test_size
            )
            split_desc = "random(fallback)"
    else:
        X_train, X_test, y_train, y_test, split_desc = split_random(
            X, y, test_size
        )

    print(f"[INFO] Split: {split_desc}")
    print(f"[INFO] Train size: {len(y_train)}, Test size: {len(y_test)}")

    # The silhouette score is only defined for 2 <= n_labels <= n_samples - 1;
    # outside that range sklearn raises, so skip the metric instead of aborting.
    if 2 <= num_users < len(y):
        silhouette_avg = silhouette_score(X, y)
        print(f"[INFO] Silhouette Score (user clusters): {silhouette_avg:.4f}")
    else:
        silhouette_avg = None
        print(
            f"[WARN] Silhouette score undefined for {num_users} user(s) "
            f"and {len(y)} sample(s); skipping"
        )

    print("[INFO] Training Random Forest classifier...")
    print(f"[INFO] Hyperparameters: n_estimators={n_estimators}, max_depth={max_depth}")

    classifier = create_classifier_pipeline(
        n_estimators=n_estimators,
        max_depth=max_depth,
    )
    classifier.fit(X_train, y_train)

    y_pred = classifier.predict(X_test)
    rf_model = classifier.named_steps["classifier"]
    oob_score = getattr(rf_model, "oob_score_", None)
    feature_importances = rf_model.feature_importances_

    print("[INFO] Evaluating classifier performance...")
    accuracy, macro_f1, report, cm, classes = evaluate_classifier(y_test, y_pred)
    metrics_path = os.path.join(out_dir, "user_cls_metrics.txt")
    save_metrics_report(
        output_path=metrics_path,
        split_description=split_desc,
        train_size=len(y_train),
        test_size=len(y_test),
        accuracy=accuracy,
        macro_f1=macro_f1,
        classification_report_str=report,
        oob_score=oob_score,
        silhouette=silhouette_avg,
    )

    confusion_path = os.path.join(out_dir, "user_cls_confusion.pdf")
    save_confusion_matrix_plot(cm, classes, confusion_path)

    importance_plot_path = os.path.join(out_dir, "feature_importance.pdf")
    save_feature_importance_plot(feature_importances, importance_plot_path)

    importance_csv_path = os.path.join(out_dir, "feature_importance.csv")
    save_feature_importance_csv(feature_importances, importance_csv_path)

    print("\n" + "=" * 50)
    print("RANDOM FOREST CLASSIFICATION RESULTS")
    print("=" * 50)
    print(f"Split Strategy : {split_desc}")
    if silhouette_avg is not None:
        print(f"Silhouette : {silhouette_avg:.4f}")
    print(f"Accuracy : {accuracy:.4f}")
    print(f"Macro F1 : {macro_f1:.4f}")
    if oob_score is not None:
        print(f"OOB Score : {oob_score:.4f}")
    print(f"Top 5 Important Features:")
    top5_idx = np.argsort(feature_importances)[::-1][:5]
    for rank, idx in enumerate(top5_idx, 1):
        print(f" {rank}. emb_{idx}: {feature_importances[idx]:.4f}")
    print("=" * 50)
|
||||
|
||||
|
||||
# Script entry point: hand `main` to Fire so it can be driven from the command line
# (see the Usage section of the module docstring for the expected flags).
if __name__ == "__main__":
    Fire(main)
|
||||
@@ -1,504 +0,0 @@
|
||||
"""
|
||||
SBERT Time Series Embedding Extraction and Visualization Pipeline
|
||||
|
||||
This module provides functionality to
|
||||
1. Extract embeddings from time series features using SBERT (Sentence-BERT)
|
||||
2. Visualize embeddings using dimensionality reduction (t-SNE)
|
||||
|
||||
SBERT Overview:
|
||||
SBERT (Sentence-BERT) is a modification of the BERT model that uses siamese
|
||||
and triplet network structures to derive semantically meaningful sentence
|
||||
embeddings. It's designed for semantic similarity tasks and produces
|
||||
fixed-size dense vector representations of text.
|
||||
|
||||
Key Features:
|
||||
- Semantic understanding: Captures meaning rather than just word presence
|
||||
- Fixed-size embeddings: Outputs consistent vector dimensions (384 for all-MiniLM-L6-v2)
|
||||
- Efficient: Optimized for sentence-level tasks
|
||||
|
||||
Embedding Strategy:
|
||||
We convert time series features into textual descriptions, then use SBERT
|
||||
to generate embeddings. This approach treats feature vectors as "sentences"
|
||||
where each feature-value pair is a "word" in the description.
|
||||
|
||||
Processing Pipeline:
|
||||
1. Feature extraction from time series (statistical features)
|
||||
2. Textualization: Convert features to natural language description
|
||||
3. SBERT encoding: Generate 384-dimensional embeddings
|
||||
4. Aggregation: Store embeddings with metadata (user_id, session_id, label)
|
||||
|
||||
Usage:
|
||||
# Extract embeddings
|
||||
python gen_plot.py extract --data_root /path/to/data --out_dir ./embeddings
|
||||
|
||||
# Visualize with t-SNE
|
||||
python gen_plot.py plot --emb_dir ./embeddings --out_dir ./plots
|
||||
|
||||
Author: Sumin Im (NMSL Undergraduate Researcher)
|
||||
Date: 2026-01-09
|
||||
"""
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
from typing import Dict, Any, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from datasets import load_from_disk, Dataset
|
||||
from fire import Fire
|
||||
from sklearn.manifold import TSNE
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# =============================================================================
|
||||
# Constants
|
||||
# =============================================================================
|
||||
|
||||
# EEG channel names used in Sleep-EDF dataset
# Fpz-Cz: Frontal-Central electrode pair
# Pz-Oz: Parietal-Occipital electrode pair
# NOTE(review): SBERT.textualize_sample spells these two channel names out
# inline in its prompt text instead of reading these constants; keep the two
# in sync if either changes.
EEG_CHANNEL_1 = "EEG Fpz-Cz"
EEG_CHANNEL_2 = "EEG Pz-Oz"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Embedding Extractor Class
|
||||
# =============================================================================
|
||||
|
||||
class SBERT:
    """
    Extracts fixed-dimensional embeddings from time series features using SBERT.

    Uses Sentence-BERT to convert textualized feature descriptions into dense
    vector representations. The textualization process converts statistical
    features into natural language, which SBERT then encodes semantically.

    Architecture:
        Time Series → Feature Extraction → Textualization → SBERT Encoder → Embedding

    Processing Pipeline:
        1. Read the pre-extracted feature dictionary of each sample
        2. Textualize: Convert features to natural language description
        3. SBERT encoding: Generate one embedding vector per description
        4. Output: Fixed-size embedding vector per sample

    Model: all-MiniLM-L6-v2
        - Lightweight BERT variant optimized for sentence embeddings
        - 384-dimensional output embeddings

    Attributes:
        model: SentenceTransformer instance for encoding textualized features
    """

    def __init__(self):
        """
        Initialize the SBERT embedder with the pre-trained
        "all-MiniLM-L6-v2" sentence-transformers model.
        """
        self.model = SentenceTransformer("all-MiniLM-L6-v2")

    @staticmethod
    def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
        """
        Discover all user/session directories under data_root.

        Expected structure: data_root/user_id/session_id/
        Non-directory entries at the session level are ignored.

        Args:
            data_root: Root directory containing user/session subfolders

        Returns:
            List of (user_id, session_id, session_path) tuples, sorted by path
        """
        discovered_paths = []

        # Two wildcard levels: <data_root>/<user_id>/<session_id>
        for session_path in sorted(glob(os.path.join(data_root, "*", "*"))):
            if not os.path.isdir(session_path):
                continue

            # Both ids are read straight from the directory names.
            session_id = os.path.basename(session_path)
            user_id = os.path.basename(os.path.dirname(session_path))

            discovered_paths.append((user_id, session_id, session_path))

        return discovered_paths

    def textualize_sample(self, sample: Dict[str, Any]) -> str:
        """
        Convert a feature dictionary into a natural language description.

        The description starts with a fixed preamble naming the sleep stages
        and the two EEG channels, followed by one " - name: value" line per
        feature. Surrounding whitespace is stripped from the result.

        Args:
            sample: Dictionary mapping feature names to their values
                Example: {"mean": 0.5, "std": 0.2, "max": 1.0, ...}

        Returns:
            Natural language string describing the features
        """
        preamble = (
            "While sleeping (one of the stages: W, N1, N2, N3, REM), "
            "I have sensor data features measured from two channels: EEG Fpz-Cz and EEG Pz-Oz.\n"
            "The features are as follows: \n"
        )
        feature_lines = [
            f" - {name}: {self.format_feature(value)}"
            for name, value in sample.items()
        ]
        # str.strip() returns a new string; the result has to be returned
        # (the earlier version called it and discarded the value).
        return (preamble + "\n".join(feature_lines)).strip()

    def format_feature(self, value: Any) -> str:
        """
        Format a feature value for textual representation.

        Python floats and NumPy floating scalars are rendered with 2 decimal
        places; every other type is converted with ``str``.

        Args:
            value: Feature value (float, int, or other type)

        Returns:
            Formatted string representation
        """
        # np.float32 and friends are not subclasses of float, so check both.
        if isinstance(value, (float, np.floating)):
            return f"{value:.2f}"
        return str(value)

    def compute_embedding(self, batch: Dict[str, Any]) -> np.ndarray:
        """
        Generate embedding vectors from feature batches using SBERT.

        Each entry of ``batch["features"]`` is textualized and the resulting
        texts are encoded in one ``self.model.encode`` call.

        Args:
            batch: HuggingFace dataset batch from slicing (dataset[start:end])
                Format: {"features": [{"mean": 0.5, "std": 0.2, ...}, ...]}

        Returns:
            Embedding array of shape (batch_size, embedding_dim)
            For all-MiniLM-L6-v2: (batch_size, 384)
        """
        text_samples = [
            self.textualize_sample(sample) for sample in batch["features"]
        ]
        return self.model.encode(text_samples)

    def extract_embeddings(
        self,
        data_root: str,
        batch_size: int = 32,
        label: str = "N1"
    ) -> Dataset:
        """
        Extract embeddings from all sessions under the data root directory.

        For every discovered user/session directory the saved dataset is
        loaded, shuffled with a fixed seed, filtered to ``label`` and encoded
        in batches. Metadata is taken from the dataset columns "user_id",
        "session_id", "idx" and "label".

        Args:
            data_root: Root directory containing user/session data folders
            batch_size: Number of samples encoded per call (default: 32)
            label: Sleep stage label to keep (e.g., "W", "N1", "N2", "N3", "REM")

        Returns:
            HuggingFace Dataset with columns:
                - user_id, session_id, idx, label (metadata)
                - embedding (vector produced by the SBERT model)
        """
        session_paths = self.discover_session_paths(data_root)
        print(f"[INFO] Discovered {len(session_paths)} sessions")

        all_embeddings = []
        all_user_ids = []
        all_session_ids = []
        all_idxs = []
        all_labels = []

        for user_id, session_id, session_path in session_paths:
            # Load HuggingFace dataset from disk
            dataset = load_from_disk(session_path)
            # Fixed seed keeps the sample order reproducible across runs
            dataset = dataset.shuffle(seed=0)
            # Filter by sleep stage label
            dataset = dataset.filter(lambda x: x["label"] == label)
            num_samples = len(dataset)
            print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")

            # Process in batches to manage memory
            for batch_start in range(0, num_samples, batch_size):
                batch_end = min(batch_start + batch_size, num_samples)
                batch = dataset[batch_start:batch_end]

                embeddings = self.compute_embedding(batch)

                # Collect embeddings and metadata column-wise; the batch
                # columns and the embedding rows share the same order.
                all_embeddings.extend(embeddings.tolist())
                all_user_ids.extend(str(value) for value in batch["user_id"])
                all_session_ids.extend(str(value) for value in batch["session_id"])
                all_idxs.extend(int(value) for value in batch["idx"])
                all_labels.extend(str(value) for value in batch["label"])

        # Create HuggingFace Dataset
        result_dataset = Dataset.from_dict({
            "user_id": all_user_ids,
            "session_id": all_session_ids,
            "idx": all_idxs,
            "label": all_labels,
            "embedding": all_embeddings,
        })

        print(f"[INFO] Total samples: {len(result_dataset)}")
        return result_dataset

    def save_embeddings(
        self,
        dataset: Dataset,
        output_dir: str
    ) -> None:
        """
        Save embeddings dataset to disk in HuggingFace format.

        Args:
            dataset: HuggingFace Dataset containing embeddings and metadata
            output_dir: Directory path to save the dataset (created if missing)
        """
        os.makedirs(output_dir, exist_ok=True)

        dataset.save_to_disk(output_dir)

        # An empty dataset has no first row to measure; report dimension 0
        # instead of failing after the data has already been written.
        embedding_dim = len(dataset[0]['embedding']) if len(dataset) > 0 else 0

        print(f"[DONE] Saved embeddings dataset: {output_dir}")
        print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {embedding_dim}")

    @staticmethod
    def load_embeddings(embedding_dir: str) -> Dataset:
        """
        Load saved embeddings dataset from disk.

        Args:
            embedding_dir: Directory path containing the saved HuggingFace dataset

        Returns:
            HuggingFace Dataset with embeddings and metadata
        """
        dataset = load_from_disk(embedding_dir)
        print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
        return dataset
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Visualization Functions
|
||||
# =============================================================================
|
||||
|
||||
def reduce_to_2d_tsne(
    embeddings: np.ndarray,
    perplexity: float = 30.0
) -> np.ndarray:
    """
    Reduce high-dimensional embeddings to 2D using t-SNE.

    t-SNE (t-distributed Stochastic Neighbor Embedding) is a nonlinear
    dimensionality reduction technique that preserves local structure.

    scikit-learn requires the perplexity to be strictly smaller than the
    number of samples. When the requested value is too large for the data
    (e.g. after filtering down to a few samples) it is lowered to
    ``num_samples - 1`` and a warning is printed.

    Args:
        embeddings: High-dimensional array of shape (num_samples, embedding_dim)
        perplexity: t-SNE perplexity parameter (typically 5-50).
            Higher values consider more neighbors.

    Returns:
        2D coordinates of shape (num_samples, 2)

    Raises:
        ValueError: If fewer than 2 samples are given.
    """
    num_samples = len(embeddings)
    if num_samples < 2:
        raise ValueError(
            f"t-SNE needs at least 2 samples, got {num_samples}."
        )

    if perplexity >= num_samples:
        adjusted = float(num_samples - 1)
        print(
            f"[WARN] perplexity={perplexity} must be below the number of "
            f"samples ({num_samples}); using {adjusted} instead"
        )
        perplexity = adjusted

    print(f"[INFO] Running t-SNE with perplexity={perplexity}...")

    tsne = TSNE(
        n_components=2,
        random_state=0,        # For reproducibility
        perplexity=perplexity,
        max_iter=1000,
        init='random',
        learning_rate='auto',  # Let sklearn choose the learning rate
    )

    return tsne.fit_transform(embeddings)
|
||||
|
||||
|
||||
def create_scatter_plot(
    coordinates: np.ndarray,
    labels: np.ndarray,
    title: str,
    output_path: str
) -> None:
    """
    Draw a 2D scatter plot coloured by category and save it as a PDF.

    Each distinct value in ``labels`` gets its own colour and legend entry.
    Up to 10 categories use the ``tab10`` colormap, more use ``tab20``;
    colour indices wrap around after 20.

    Args:
        coordinates: 2D array of shape (num_points, 2)
        labels: Category label for each point (array, compared element-wise)
        title: Plot title
        output_path: File path to save the plot (written in PDF format)
    """
    categories = sorted(set(labels))

    # Pick the palette by category count.
    if len(categories) > 10:
        palette = plt.cm.tab20
    else:
        palette = plt.cm.tab10

    figure, axes = plt.subplots(figsize=(10, 8))

    # One scatter call per category so each gets a legend entry.
    for position, category in enumerate(categories):
        members = labels == category
        xs = coordinates[members, 0]
        ys = coordinates[members, 1]
        axes.scatter(
            xs,
            ys,
            c=[palette(position % 20)],
            s=15,
            label=category,
            alpha=0.7,
        )

    axes.set_title(title)
    axes.set_xlabel("Dimension 1")
    axes.set_ylabel("Dimension 2")
    axes.legend(loc='best', markerscale=2)

    # Write the figure out as a PDF and release it.
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', bbox_inches='tight')
    plt.close()

    print(f"[DONE] Saved plot: {output_path}")
|
||||
|
||||
# =============================================================================
|
||||
# Command Line Interface
|
||||
# =============================================================================
|
||||
|
||||
def _split_cli_list(value):
    """
    Normalize a comma-separated CLI option into a list of stripped strings.

    python-fire evaluates option values as Python literals before calling the
    command, so ``--labels W,N1,N2`` can arrive as the tuple
    ``('W', 'N1', 'N2')`` and ``--users 10,11`` as ``(10, 11)`` instead of the
    raw string. Calling ``.split(",")`` on such values raises AttributeError,
    so strings, sequences and single scalars are all accepted here.

    Args:
        value: None, a comma-separated string, a list/tuple/set of items,
            or a single scalar such as an int

    Returns:
        List of stripped strings; empty when value is None or ""
    """
    if value is None or value == "":
        return []
    if isinstance(value, str):
        items = value.split(",")
    elif isinstance(value, (list, tuple, set)):
        items = value
    else:
        # A lone scalar, e.g. Fire turned "--users 40" into the int 40.
        items = [value]
    return [str(item).strip() for item in items]


class CLI:
    """
    Command-line interface for SBERT embedding extraction and visualization.

    Provides two main commands:
    - extract: Generate embeddings from time series features
    - plot: Visualize embeddings with t-SNE
    """

    def extract(
        self,
        data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF",
        out_dir: str = "./embeddings/REM",
        batch_size: int = 32,
        label: str = "REM"
    ) -> None:
        """
        Extract embeddings from time series data and save them to disk.

        Args:
            data_root: Root directory containing user/session data folders
            out_dir: Output directory for HuggingFace dataset
            batch_size: Batch size for inference (default: 32)
            label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM')
        """
        embedder = SBERT()
        dataset = embedder.extract_embeddings(data_root, batch_size, label)
        embedder.save_embeddings(dataset, out_dir)

    def plot(
        self,
        emb_dir: str = "./embeddings/W",
        out_dir: str = "./plots/W",
        perplexity: float = 10.0,
        users: str = None,
        num_users: int = 0,
        labels: str = None,
    ) -> None:
        """
        Visualize embeddings with t-SNE.

        Writes two PDFs into out_dir: "tsne_by_label.pdf" and "tsne_by_user.pdf".

        Args:
            emb_dir: Directory containing the HuggingFace embeddings dataset
            out_dir: Output directory for visualization plots (PDF)
            perplexity: t-SNE perplexity parameter (default: 10.0)
            users: Comma-separated user IDs to include (e.g., '00,01,02').
                Takes precedence over num_users. Values already split by the
                CLI parser (tuples/lists/ints) are accepted as well.
            num_users: Include only first N users, 0 = all (default: 0)
            labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')

        Raises:
            ValueError: If no samples remain after filtering
        """
        os.makedirs(out_dir, exist_ok=True)

        # Load saved embeddings dataset
        dataset = SBERT.load_embeddings(emb_dir)

        # Apply user filtering (an explicit list wins over "first N users")
        user_list = _split_cli_list(users)
        if user_list:
            dataset = dataset.filter(lambda x: x["user_id"] in user_list)
            print(f"[INFO] Filtered to users: {user_list}")
        elif num_users > 0:
            all_users = sorted(set(dataset["user_id"]))
            selected_users = all_users[:num_users]
            dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
            print(f"[INFO] Selected first {num_users} users: {selected_users}")

        # Filter by sleep stage labels
        label_list = _split_cli_list(labels)
        if label_list:
            dataset = dataset.filter(lambda x: x["label"] in label_list)
            print(f"[INFO] Filtered to labels: {label_list}")

        print(f"[INFO] Total samples: {len(dataset)}")

        # Fail early with an actionable message instead of an opaque t-SNE error.
        if len(dataset) == 0:
            raise ValueError(
                "No samples left after filtering; check --users/--labels "
                "(zero-padded IDs may need quoting, e.g. --users '\"00\"')."
            )

        # Extract embeddings as numpy array for t-SNE
        embeddings = np.array(dataset["embedding"])

        # Reduce to 2D with t-SNE (one projection shared by both plots)
        coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)

        # Generate visualizations
        create_scatter_plot(
            coordinates_2d,
            np.array(dataset["label"]),
            "t-SNE Visualization (Colored by Sleep Stage)",
            os.path.join(out_dir, "tsne_by_label.pdf")
        )

        create_scatter_plot(
            coordinates_2d,
            np.array(dataset["user_id"]),
            "t-SNE Visualization (Colored by User ID)",
            os.path.join(out_dir, "tsne_by_user.pdf")
        )
|
||||
|
||||
|
||||
# Script entry point: hand the CLI class to Fire so its methods are callable
# from the command line.
if __name__ == "__main__":
    Fire(CLI)
|
||||
@@ -1,529 +0,0 @@
|
||||
"""
|
||||
SBERT Embedding Extraction and Visualization with Metadata (Age and Sex)
|
||||
|
||||
This module provides functionality to:
|
||||
1. Extract SBERT embeddings from time series features
|
||||
2. Visualize embeddings colored by subject metadata (age and sex) instead of user IDs
|
||||
|
||||
Features:
|
||||
1. Extract embeddings from time series features using SBERT
|
||||
2. Load embeddings from HuggingFace dataset
|
||||
3. Load subject metadata from XLS file
|
||||
4. Map user IDs to age and sex information
|
||||
5. Visualize embeddings with t-SNE colored by age (continuous) and sex (categorical)
|
||||
|
||||
Usage:
|
||||
# Extract embeddings for all labels
|
||||
python gen_plot_metadata.py extract --data_root /path/to/data --out_dir ./embeddings/all_labels
|
||||
|
||||
# Extract embeddings for a single label
|
||||
python gen_plot_metadata.py extract --data_root /path/to/data --out_dir ./embeddings/W --label W
|
||||
|
||||
# Visualize by age from single label directory
|
||||
python gen_plot_metadata.py plot --emb_dir ./embeddings/W --out_dir ./plots/W --color_by age
|
||||
|
||||
# Visualize by age from all label directories (W, REM, N1, N2, N3)
|
||||
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by age
|
||||
|
||||
# Visualize by sex from all labels
|
||||
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by sex
|
||||
|
||||
# Create both age and sex plots from all labels
|
||||
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by both
|
||||
|
||||
Author: Sumin Im (NMSL Undergraduate Researcher)
|
||||
Date: 2026-01-09
|
||||
"""
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from datasets import load_from_disk, Dataset, concatenate_datasets
|
||||
from fire import Fire
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
from gen_plot import SBERT, reduce_to_2d_tsne
|
||||
|
||||
# =============================================================================
|
||||
# Constants
|
||||
# =============================================================================
|
||||
|
||||
# Default location of the SC-subjects.xls spreadsheet; used as the default
# `subject_path` argument below.
# NOTE(review): machine-specific absolute path — pass `subject_path` explicitly on other hosts.
SUBJECT_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/SC-subjects.xls"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Metadata Loading
|
||||
# =============================================================================
|
||||
|
||||
def load_subject_metadata(subject_path: str = SUBJECT_PATH) -> Dict[str, Dict[str, Any]]:
    """
    Load subject metadata from XLS file.

    Reads three columns from the spreadsheet:
    - subject: Subject identifier
    - age: Age of the subject
    - sex (F=1): Sex (1 = Female, 0 = Male)

    Keys of the result are the string form of the `subject` cell. The `plot`
    command in this file looks subjects up as ``str(int(user_id))``, so keys
    must be plain integer strings such as "0" or "40" for that lookup to hit;
    integer-valued float cells (e.g. 40.0) are therefore written as "40"
    rather than "40.0".

    If a subject appears on several rows, the last row wins.

    Args:
        subject_path: Path to the SC-subjects.xls file

    Returns:
        Dictionary mapping subject IDs to metadata dictionaries
        Format: {"40": {"age": 25, "sex": 1}, ...}; age/sex are None when the
        cell is empty
    """
    df = pd.read_excel(subject_path)
    subject_info = {}

    # Iterate column-wise: DataFrame.iterrows() builds one Series per row and
    # may upcast integer cells to float, which would turn subject 0 into "0.0".
    for subject, age, sex in zip(df["subject"], df["age"], df["sex (F=1)"]):
        if isinstance(subject, float) and subject.is_integer():
            subject = int(subject)
        subject_id = str(subject).strip()
        subject_info[subject_id] = {
            "age": int(age) if pd.notna(age) else None,
            "sex": int(sex) if pd.notna(sex) else None,
        }

    return subject_info
|
||||
|
||||
|
||||
def map_user_ids_to_metadata(
    user_ids: np.ndarray,
    subject_metadata: Dict[str, Dict[str, Any]]
) -> tuple:
    """
    Map user IDs from dataset to age and sex metadata.

    Each user ID is looked up as an exact key of `subject_metadata`; callers
    are responsible for normalizing IDs to the key format used there.
    Users without a metadata entry (or with a missing field) get None instead
    of raising KeyError, so the plotting functions can drop those points.

    Args:
        user_ids: Array of user IDs from the dataset
        subject_metadata: Dictionary of subject metadata loaded from XLS file

    Returns:
        Tuple of (ages, sexes) as numpy arrays
        Missing values are set to None
    """
    ages = []
    sexes = []
    unknown_ids = set()

    for user_id in user_ids:
        entry = subject_metadata.get(user_id)
        if entry is None:
            unknown_ids.add(user_id)
            entry = {}
        ages.append(entry.get("age"))
        sexes.append(entry.get("sex"))

    if unknown_ids:
        print(f"[WARN] No metadata found for user IDs: {sorted(unknown_ids)}")

    return np.array(ages), np.array(sexes)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Visualization Functions
|
||||
# =============================================================================
|
||||
|
||||
def create_scatter_plot_by_age(
    coordinates: np.ndarray,
    ages: np.ndarray,
    title: str,
    output_path: str
) -> None:
    """
    Draw a 2D scatter plot colored by age and save it as a PDF.

    Points whose age is None (or NaN) are dropped before plotting. The
    remaining points are colored with the continuous 'viridis' colormap and a
    colorbar labelled in years. If no point has a usable age, a warning is
    printed and nothing is written.

    Args:
        coordinates: 2D array of shape (num_points, 2)
        ages: Age values for each point (array, may contain None values)
        title: Plot title
        output_path: File path to save the plot (written in PDF format)
    """
    # None cannot live in a float array, so represent it as NaN and mask it out.
    numeric_ages = np.array([np.nan if age is None else float(age) for age in ages])
    has_age = ~np.isnan(numeric_ages)
    points = coordinates[has_age]
    point_ages = numeric_ages[has_age]

    if len(point_ages) == 0:
        print(f"[WARN] No valid age data found. Skipping plot: {output_path}")
        return

    fig, ax = plt.subplots(figsize=(10, 8))

    # Continuous colormap: one color per age value rather than per category.
    marker_style = dict(s=15, alpha=0.7, edgecolors='none')
    age_scatter = ax.scatter(
        points[:, 0],
        points[:, 1],
        c=point_ages,
        cmap='viridis',
        **marker_style,
    )

    colorbar = plt.colorbar(age_scatter, ax=ax)
    colorbar.set_label('Age (years)', rotation=270, labelpad=20)

    ax.set_title(title)
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")

    # Vector PDF output
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', bbox_inches='tight')
    plt.close()

    youngest, oldest = point_ages.min(), point_ages.max()
    print(f"[DONE] Saved plot: {output_path}")
    print(f"[INFO] Age range: {youngest:.0f} - {oldest:.0f} years")
    print(f"[INFO] Points with valid age: {len(point_ages)}/{len(ages)}")
|
||||
|
||||
|
||||
def create_scatter_plot_by_sex(
    coordinates: np.ndarray,
    sexes: np.ndarray,
    title: str,
    output_path: str
) -> None:
    """
    Create and save a 2D scatter plot colored by sex (categorical).

    Uses discrete colors for different sex categories.
    Sex encoding: 1 = Female, 0 = Male. Any other code is plotted with a
    color picked modulo the palette and labelled "Sex <code>".

    Points whose sex is None (or NaN) are dropped. If nothing remains, a
    warning is printed and no file is written.

    Args:
        coordinates: 2D array of shape (num_points, 2)
        sexes: Sex values for each point (array, may contain None values)
        title: Plot title
        output_path: File path to save the plot (written in PDF format)
    """
    # Filter out points with missing sex data
    # Convert None values to NaN for proper numpy handling
    sexes_float = np.array([float(s) if s is not None else np.nan for s in sexes])
    valid_mask = ~np.isnan(sexes_float)
    valid_coords = coordinates[valid_mask]
    valid_sexes = sexes_float[valid_mask].astype(int)

    if len(valid_sexes) == 0:
        print(f"[WARN] No valid sex data found. Skipping plot: {output_path}")
        return

    # Map sex codes to labels
    sex_labels = {0: "Male", 1: "Female"}
    # Sorted plain Python ints, so labels and logged counts print without
    # NumPy scalar wrappers.
    unique_sexes = np.unique(valid_sexes).tolist()

    # Create figure
    fig, ax = plt.subplots(figsize=(10, 8))

    # Plot each category separately for proper legend
    colors = ['steelblue', 'coral']  # Blue for Male, Coral for Female
    for sex_code in unique_sexes:
        mask = valid_sexes == sex_code
        ax.scatter(
            valid_coords[mask, 0],
            valid_coords[mask, 1],
            c=colors[sex_code % len(colors)],
            s=15,
            label=sex_labels.get(sex_code, f"Sex {sex_code}"),
            alpha=0.7,
        )

    # Configure plot appearance
    ax.set_title(title)
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")
    ax.legend(loc='best', markerscale=2)

    # Save figure as vector PDF
    fig.tight_layout()
    fig.savefig(output_path, format='pdf', bbox_inches='tight')
    plt.close(fig)  # close exactly the figure created above

    print(f"[DONE] Saved plot: {output_path}")
    sex_counts = {
        sex_labels.get(s, f"Sex {s}"): int(np.count_nonzero(valid_sexes == s))
        for s in unique_sexes
    }
    print(f"[INFO] Sex distribution: {sex_counts}")
    print(f"[INFO] Points with valid sex: {len(valid_sexes)}/{len(sexes)}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Embedding Extraction Utilities
|
||||
# =============================================================================
|
||||
|
||||
def extract_embeddings_for_all_labels(
    data_root: str,
    out_dir: str,
    batch_size: int = 32
) -> Dataset:
    """
    Extract embeddings for all sleep stage labels and combine them.

    Extracts embeddings for each label (W, REM, N1, N2, N3) separately,
    concatenates them into a single dataset in that order, prints the
    per-label sample counts, and saves the result to out_dir.

    Args:
        data_root: Root directory containing user/session data folders
        out_dir: Output directory for the combined HuggingFace dataset
        batch_size: Batch size for inference (default: 32)

    Returns:
        Combined HuggingFace Dataset with all labels
    """
    from collections import Counter

    embedder = SBERT()
    all_labels = ["W", "REM", "N1", "N2", "N3"]

    datasets = []
    for label in all_labels:
        print(f"\n[INFO] Extracting embeddings for label: {label}")
        dataset = embedder.extract_embeddings(data_root, batch_size, label)
        print(f"[INFO] Extracted {len(dataset)} samples for label {label}")
        datasets.append(dataset)

    # Five datasets are always collected above, so concatenate unconditionally.
    combined_dataset = concatenate_datasets(datasets)

    print(f"\n[INFO] Combined dataset: {len(combined_dataset)} total samples")

    # Print label distribution (first-seen order, as a plain dict)
    if "label" in combined_dataset.column_names:
        label_counts = dict(Counter(combined_dataset["label"]))
        print(f"[INFO] Label distribution: {label_counts}")

    # Save combined dataset
    embedder.save_embeddings(combined_dataset, out_dir)

    return combined_dataset
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data Loading Utilities
|
||||
# =============================================================================
|
||||
|
||||
def load_embeddings_from_all_labels(embeddings_root: str) -> Dataset:
    """
    Load every saved dataset found directly under a root folder and merge them.

    A subdirectory counts as a dataset when it contains a "dataset_info.json"
    file. Matching directories are loaded in sorted name order and
    concatenated; a single match is returned as-is.

    Args:
        embeddings_root: Root directory containing label subdirectories
            (e.g., "./embeddings" containing W/, REM/, N1/, etc.)

    Returns:
        Concatenated HuggingFace Dataset with all labels combined

    Raises:
        ValueError: If no subdirectory contains a "dataset_info.json" file
    """
    # Collect (name, path) pairs for every subdirectory that looks like a saved dataset.
    found = []
    for entry in os.listdir(embeddings_root):
        entry_path = os.path.join(embeddings_root, entry)
        if not os.path.isdir(entry_path):
            continue
        if not os.path.exists(os.path.join(entry_path, "dataset_info.json")):
            continue
        found.append((entry, entry_path))

    if not found:
        raise ValueError(
            f"No valid HuggingFace dataset directories found in: {embeddings_root}"
        )

    found.sort()  # deterministic load order
    print(f"[INFO] Discovered {len(found)} label directories: {[name for name, _ in found]}")

    loaded = []
    for name, path in found:
        print(f"[INFO] Loading embeddings from: {path}")
        part = load_from_disk(path)
        print(f"[INFO] Label: {name}, Samples: {len(part)}")
        loaded.append(part)

    # A single dataset needs no concatenation.
    combined = loaded[0] if len(loaded) == 1 else concatenate_datasets(loaded)

    print(f"[INFO] Combined dataset: {len(combined)} total samples")

    # Report how many samples each label contributes.
    if "label" in combined.column_names:
        tally = {}
        for value in combined["label"]:
            tally.setdefault(value, 0)
            tally[value] += 1
        print(f"[INFO] Label distribution: {tally}")

    return combined
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Command Line Interface
|
||||
# =============================================================================
|
||||
|
||||
def _split_cli_list(value):
    """
    Normalize a comma-separated CLI option into a list of stripped strings.

    python-fire evaluates option values as Python literals before calling the
    command, so ``--labels W,N1,N2`` can arrive as the tuple
    ``('W', 'N1', 'N2')`` and ``--users 10,11`` as ``(10, 11)`` instead of the
    raw string. Calling ``.split(",")`` on such values raises AttributeError,
    so strings, sequences and single scalars are all accepted here.

    Args:
        value: None, a comma-separated string, a list/tuple/set of items,
            or a single scalar such as an int

    Returns:
        List of stripped strings; empty when value is None or ""
    """
    if value is None or value == "":
        return []
    if isinstance(value, str):
        items = value.split(",")
    elif isinstance(value, (list, tuple, set)):
        items = value
    else:
        # A lone scalar, e.g. Fire turned "--users 40" into the int 40.
        items = [value]
    return [str(item).strip() for item in items]


class CLI:
    """
    Command-line interface for SBERT embedding extraction and visualization with metadata.

    Provides:
    - extract: Generate embeddings from time series features
    - plot: Visualize embeddings colored by age or sex instead of user ID
    """

    def extract(
        self,
        data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_full",
        out_dir: str = "./embeddings/all_labels",
        batch_size: int = 32,
        label: Optional[str] = None
    ) -> None:
        """
        Extract embeddings from time series features using SBERT.

        Can extract for a single label or all labels at once.

        Args:
            data_root: Root directory containing user/session data folders
            out_dir: Output directory for HuggingFace dataset
            batch_size: Batch size for inference (default: 32)
            label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM').
                If None or 'all', extracts for all labels (default: None for all labels)
        """
        if label is None or label == "all":
            # The all-labels helper builds its own SBERT, so none is created here.
            print("[INFO] Extracting embeddings for all labels")
            extract_embeddings_for_all_labels(data_root, out_dir, batch_size)
            return

        # Extract for single label
        print(f"[INFO] Extracting embeddings for label: {label}")
        embedder = SBERT()
        dataset = embedder.extract_embeddings(data_root, batch_size, label)
        embedder.save_embeddings(dataset, out_dir)

    def plot(
        self,
        emb_dir: str = None,
        embeddings_root: str = "./embeddings",
        out_dir: str = "./plots/all_labels",
        subject_path: str = SUBJECT_PATH,
        perplexity: float = 10.0,
        color_by: str = "age",
        users: str = None,
        num_users: int = 0,
        labels: str = None,
    ) -> None:
        """
        Visualize embeddings with t-SNE, colored by age or sex.

        Can load from either a single label directory or all label directories.
        Writes "tsne_by_age.pdf" and/or "tsne_by_sex.pdf" into out_dir.

        Args:
            emb_dir: Single directory containing the HuggingFace embeddings dataset
                (e.g., "./embeddings/W"). If provided, only this directory is used.
            embeddings_root: Root directory containing label subdirectories
                (e.g., "./embeddings" containing W/, REM/, N1/, etc.).
                Used only if emb_dir is not provided.
            out_dir: Output directory for visualization plots (PDF)
            subject_path: Path to SC-subjects.xls file with metadata
            perplexity: t-SNE perplexity parameter (default: 10.0)
            color_by: What to color by - 'age', 'sex', or 'both' (default: 'age').
                'all' is accepted as a synonym for 'both'.
            users: Comma-separated user IDs to include (e.g., '00,01,02').
                Takes precedence over num_users. Values already split by the
                CLI parser (tuples/lists/ints) are accepted as well.
            num_users: Include only first N users, 0 = all (default: 0)
            labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
                This filters the already-loaded data, not which directories to load.

        Raises:
            ValueError: If color_by is not a supported value, or if no samples
                remain after filtering
        """
        # Validate color_by argument
        if color_by not in ["age", "sex", "both", "all"]:
            raise ValueError(
                f"Invalid color_by: {color_by}. Use 'age', 'sex', 'both', or 'all'."
            )

        os.makedirs(out_dir, exist_ok=True)

        # Load embeddings: either from single directory or all label directories
        if emb_dir is not None:
            print(f"[INFO] Loading embeddings from single directory: {emb_dir}")
            dataset = SBERT.load_embeddings(emb_dir)
        else:
            print(f"[INFO] Loading embeddings from all label directories in: {embeddings_root}")
            dataset = load_embeddings_from_all_labels(embeddings_root)

        # Apply user filtering (an explicit list wins over "first N users")
        user_list = _split_cli_list(users)
        if user_list:
            dataset = dataset.filter(lambda x: x["user_id"] in user_list)
            print(f"[INFO] Filtered to users: {user_list}")
        elif num_users > 0:
            all_users = sorted(set(dataset["user_id"]))
            selected_users = all_users[:num_users]
            dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
            print(f"[INFO] Selected first {num_users} users: {selected_users}")

        # Filter by sleep stage labels
        label_list = _split_cli_list(labels)
        if label_list:
            dataset = dataset.filter(lambda x: x["label"] in label_list)
            print(f"[INFO] Filtered to labels: {label_list}")

        print(f"[INFO] Total samples: {len(dataset)}")

        # Fail early with an actionable message instead of an opaque t-SNE error.
        if len(dataset) == 0:
            raise ValueError(
                "No samples left after filtering; check --users/--labels "
                "(zero-padded IDs may need quoting, e.g. --users '\"00\"')."
            )

        # Extract embeddings as numpy array for t-SNE
        embeddings = np.array(dataset["embedding"])

        # Load subject metadata
        print(f"[INFO] Loading subject metadata from: {subject_path}")
        subject_metadata = load_subject_metadata(subject_path)
        print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")

        # Map user IDs to metadata. IDs are passed through int() so that
        # zero-padded IDs such as "07" become "7" before the lookup.
        # NOTE(review): presumably the metadata keys are unpadded integer strings — confirm.
        metadata_keys = [str(int(uid)) for uid in dataset["user_id"]]
        ages, sexes = map_user_ids_to_metadata(metadata_keys, subject_metadata)

        # Reduce to 2D with t-SNE (one projection shared by all plots)
        coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)

        # Generate visualizations based on color_by parameter
        if color_by in ("age", "both", "all"):
            create_scatter_plot_by_age(
                coordinates_2d,
                ages,
                "t-SNE Visualization (Colored by Age)",
                os.path.join(out_dir, "tsne_by_age.pdf")
            )
        if color_by in ("sex", "both", "all"):
            create_scatter_plot_by_sex(
                coordinates_2d,
                sexes,
                "t-SNE Visualization (Colored by Sex)",
                os.path.join(out_dir, "tsne_by_sex.pdf")
            )
|
||||
|
||||
|
||||
# Script entry point: hand the CLI class to Fire so its methods are callable
# from the command line.
if __name__ == "__main__":
    Fire(CLI)
|
||||
@@ -1,527 +0,0 @@
|
||||
"""
|
||||
User Classification from SBERT Embeddings
|
||||
|
||||
This module evaluates how well SBERT embeddings capture user-specific patterns
|
||||
by training a Random Forest classifier to predict user identity from embeddings.
|
||||
|
||||
Motivation:
|
||||
If embeddings contain user-distinguishing information, a classifier should be able
|
||||
to predict which user a time series belongs to. High accuracy suggests that the
|
||||
embeddings capture individual characteristics.
|
||||
|
||||
Experimental Design:
|
||||
- Task: Multi-class classification
|
||||
- Model: Random Forest with ensemble of decision trees
|
||||
- Captures non-linear relationships in embedding space
|
||||
- Provides feature importance scores for interpretability
|
||||
- Robust to overfitting through bagging and random feature selection
|
||||
- Split Strategy:
|
||||
1. Session-based: Train on session 1, test on session 2
|
||||
2. Random: Standard train/test split
|
||||
|
||||
Session-based split is more challenging and realistic because:
|
||||
- Tests whether user patterns are stable across different recording sessions
|
||||
- Avoids data leakage from same-session samples in train and test
|
||||
|
||||
Random Forest Advantages over Logistic Regression:
|
||||
- Handles non-linear decision boundaries
|
||||
- Feature importance reveals which embedding dimensions matter most
|
||||
- No assumption about data distribution
|
||||
- Naturally handles multi-class classification
|
||||
|
||||
Output:
|
||||
- Classification metrics
|
||||
- Confusion matrix visualization
|
||||
|
||||
Usage:
|
||||
python simple_user_classifier.py \\
|
||||
--embeddings ./embeddings \\
|
||||
--out_dir ./results
|
||||
|
||||
Author: Sumin Im (NMSL Undergraduate Researcher)
|
||||
Date: 2026-01-09
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from datasets import load_from_disk, Dataset
|
||||
from fire import Fire
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import (
|
||||
accuracy_score,
|
||||
f1_score,
|
||||
classification_report,
|
||||
confusion_matrix,
|
||||
silhouette_score,
|
||||
)
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data Loading
|
||||
# =============================================================================
|
||||
|
||||
def load_embeddings_with_metadata(embedding_path: str) -> Dataset:
    """
    Load a HuggingFace dataset that was saved to disk.

    Args:
        embedding_path: Directory of the saved dataset

    Returns:
        The dataset returned by `load_from_disk`

    Raises:
        FileNotFoundError: If embedding_path is not an existing directory
    """
    if os.path.isdir(embedding_path):
        return load_from_disk(embedding_path)

    message = (
        f"Dataset directory not found: {embedding_path}. "
        "Ensure this is a valid HuggingFace dataset directory."
    )
    raise FileNotFoundError(message)
|
||||
|
||||
# =============================================================================
|
||||
# Data Splitting
|
||||
# =============================================================================
|
||||
|
||||
def split_by_session(
    features: np.ndarray,
    labels: np.ndarray,
    session_ids: np.ndarray,
    train_session: str = "1",
    test_session: str = "2",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
    """
    Split data by recording session for temporal generalization evaluation.

    Samples whose session ID equals train_session form the training set and
    those equal to test_session form the test set; all other samples are
    dropped. Session IDs are compared by their string form, so integer IDs
    (e.g. a numeric dataset column, or a CLI value parsed as an int) match
    the string defaults "1" and "2".

    Args:
        features: Feature matrix of shape (num_samples, num_features)
        labels: Label array of shape (num_samples,)
        session_ids: Session identifier for each sample
        train_session: Session ID to use for training (default: "1")
        test_session: Session ID to use for testing (default: "2")

    Returns:
        Tuple of (X_train, X_test, y_train, y_test, split_description).
        Either split is empty when no sample carries the requested session ID.
    """
    # Normalize both sides to str: comparing an int array with "1" would
    # otherwise match nothing and silently produce empty splits.
    session_keys = np.asarray(session_ids).astype(str)
    train_mask = session_keys == str(train_session)
    test_mask = session_keys == str(test_session)

    # Apply masks to create train/test splits
    X_train = features[train_mask]
    X_test = features[test_mask]
    y_train = labels[train_mask]
    y_test = labels[test_mask]

    split_description = f"session({train_session}->train, {test_session}->test)"

    return X_train, X_test, y_train, y_test, split_description
|
||||
|
||||
|
||||
def split_random(
    features: np.ndarray,
    labels: np.ndarray,
    test_size: float = 0.2,
    random_state: int = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
    """
    Make a stratified random train/test split.

    Delegates to `train_test_split`, stratifying on `labels`.

    Args:
        features: Feature matrix of shape (num_samples, num_features)
        labels: Label array of shape (num_samples,)
        test_size: Fraction of data to use for testing (default: 0.2)
        random_state: Random seed for reproducibility

    Returns:
        Tuple of (X_train, X_test, y_train, y_test, split_description),
        where split_description is always "random"
    """
    split_options = {
        "test_size": test_size,
        "random_state": random_state,
        "stratify": labels,  # stratify on the labels being split
    }
    X_train, X_test, y_train, y_test = train_test_split(features, labels, **split_options)

    return X_train, X_test, y_train, y_test, "random"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model Training and Evaluation
|
||||
# =============================================================================
|
||||
|
||||
def create_classifier_pipeline(
    n_estimators: int = 200,
    max_depth: int = None,
    min_samples_split: int = 2,
    min_samples_leaf: int = 1,
    random_state: int = 0,
) -> Pipeline:
    """
    Build the two-step scikit-learn pipeline used for user classification.

    Steps, in order:
    1. "scaler": StandardScaler with centering and unit-variance scaling
    2. "classifier": RandomForestClassifier configured with the arguments
       below, all CPU cores (n_jobs=-1), balanced class weights, and
       out-of-bag scoring enabled

    Args:
        n_estimators: Number of trees (default: 200)
        max_depth: Maximum depth of trees (default: None, i.e. no limit)
        min_samples_split: Min samples for splitting (default: 2)
        min_samples_leaf: Min samples at leaf (default: 1)
        random_state: Random seed for reproducibility

    Returns:
        Configured sklearn Pipeline ready for .fit() and .predict()
    """
    # Z-score normalization: subtract the mean, divide by the standard deviation.
    scaler = StandardScaler(with_mean=True, with_std=True)

    forest = RandomForestClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        n_jobs=-1,                # build trees on all CPU cores
        random_state=random_state,
        class_weight="balanced",  # reweight classes by inverse frequency
        oob_score=True,           # keep an out-of-bag accuracy estimate
    )

    return Pipeline([
        ("scaler", scaler),
        ("classifier", forest),
    ])
|
||||
|
||||
|
||||
def evaluate_classifier(
    y_true: np.ndarray,
    y_pred: np.ndarray,
) -> Tuple[float, float, str, np.ndarray, list]:
    """
    Score a set of predictions against the ground truth.

    Produces two scalar metrics (accuracy and macro-averaged F1), a textual
    per-class report formatted with four decimal digits, and a confusion
    matrix whose rows/columns follow the sorted union of every label that
    appears in either ``y_true`` or ``y_pred``.

    Args:
        y_true: Ground truth labels
        y_pred: Predicted labels

    Returns:
        Tuple of (accuracy, f1, classification_report, confusion_matrix, class_labels)
    """
    # Fix the label order once so the matrix axes are deterministic and also
    # include classes that were only ever predicted (or only ever true).
    class_labels = sorted(set(y_true).union(y_pred))
    cm = confusion_matrix(y_true, y_pred, labels=class_labels)

    # Human-readable precision / recall / F1 breakdown per class
    report = classification_report(y_true, y_pred, digits=4)

    # Scalar summaries; macro averaging weights every class equally
    f1 = f1_score(y_true, y_pred, average='macro')
    accuracy = accuracy_score(y_true, y_pred)

    return accuracy, f1, report, cm, class_labels
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Visualization
|
||||
# =============================================================================
|
||||
|
||||
def save_confusion_matrix_plot(
    confusion_mat: np.ndarray,
    class_labels: list,
    output_path: str,
) -> None:
    """
    Create and save a confusion matrix heatmap visualization.

    The confusion matrix shows:
    - Rows: True class labels
    - Columns: Predicted class labels
    - Cell values: Count of samples with that (true, predicted) combination
    - Diagonal: Correct predictions
    - Off-diagonal: Misclassifications

    Args:
        confusion_mat: Square matrix of shape (num_classes, num_classes)
        class_labels: List of class names used as tick labels on both axes.
            Labels are only applied when their count matches the number of
            matrix rows; otherwise the axes keep their numeric ticks.
        output_path: File path to save the plot. The figure is always written
            in PDF format regardless of the file extension.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Heatmap of raw counts, with a colorbar giving the value scale
        im = ax.imshow(confusion_mat, aspect="auto", cmap="Blues")
        fig.colorbar(im, ax=ax)

        # Label the axes with the actual class names so cells can be mapped
        # back to users; without this the ticks are just matrix indices.
        num_rows = np.shape(confusion_mat)[0]
        if class_labels is not None and len(class_labels) == num_rows:
            tick_positions = np.arange(num_rows)
            tick_text = [str(label) for label in class_labels]
            # Shrink the font when there are many classes to limit overlap
            font_size = 8 if num_rows <= 40 else 5
            ax.set_xticks(tick_positions)
            ax.set_xticklabels(tick_text, rotation=90, fontsize=font_size)
            ax.set_yticks(tick_positions)
            ax.set_yticklabels(tick_text, fontsize=font_size)

        ax.set_title("Confusion Matrix (User Classification)")
        ax.set_xlabel("Predicted User")
        ax.set_ylabel("True User")

        fig.tight_layout()
        fig.savefig(output_path, format="pdf", bbox_inches="tight")
    finally:
        # Always release the figure, even if rendering or saving failed
        plt.close(fig)

    print(f"[DONE] Saved confusion matrix: {output_path}")
|
||||
|
||||
|
||||
def save_metrics_report(
    output_path: str,
    split_description: str,
    train_size: int,
    test_size: int,
    accuracy: float,
    macro_f1: float,
    classification_report_str: str,
    oob_score: float = None,
    silhouette: float = None,
) -> None:
    """
    Write the classification results to a plain-text report.

    The file starts with one ``key : value`` line per summary statistic
    (split, sizes, optional silhouette, accuracy, macro F1, optional OOB
    score), followed by a blank line and the per-class report verbatim.
    Floating-point scores are written with four decimals.

    Args:
        output_path: File path to save the metrics report (overwritten)
        split_description: Description of train/test split strategy
        train_size: Number of training samples
        test_size: Number of test samples
        accuracy: Overall classification accuracy
        macro_f1: Macro-averaged F1 score
        classification_report_str: Detailed per-class report string
        oob_score: Out-of-bag score; its line is omitted when None
        silhouette: Silhouette score; its line is omitted when None
    """
    with open(output_path, "w", encoding="utf-8") as report_file:
        summary_lines = [
            f"split_used : {split_description}\n",
            f"train_size : {train_size}\n",
            f"test_size : {test_size}\n",
        ]

        # Optional embedding-quality metric goes before the classifier scores
        if silhouette is not None:
            summary_lines.append(f"silhouette : {silhouette:.4f}\n")

        summary_lines.append(f"accuracy : {accuracy:.4f}\n")
        summary_lines.append(f"macro_f1 : {macro_f1:.4f}\n")

        # Optional Random-Forest-specific score goes last in the summary
        if oob_score is not None:
            summary_lines.append(f"oob_score : {oob_score:.4f}\n")

        # Blank separator line between the summary and the detailed report
        summary_lines.append("\n")

        report_file.writelines(summary_lines)
        report_file.write(classification_report_str)

    print(f"[DONE] Saved metrics report: {output_path}")
|
||||
|
||||
|
||||
def save_feature_importance_plot(
    feature_importances: np.ndarray,
    output_path: str,
    top_k: int = 50,
) -> None:
    """
    Plot the highest-scoring features as a horizontal bar chart and save it.

    Features are ranked by descending importance and at most ``top_k`` of
    them are drawn, most important at the top. Each bar is labelled
    ``emb_<index>`` with the feature's position in ``feature_importances``.

    Args:
        feature_importances: Array of feature importance scores from Random Forest
        output_path: File path to save the plot (always written as PDF)
        top_k: Number of top features to display (default: 50)
    """
    # Rank all features from most to least important, then keep the head
    ranked_indices = np.argsort(feature_importances)[::-1]
    top_indices = ranked_indices[:top_k]
    top_importances = feature_importances[top_indices]
    bar_labels = [f"emb_{i}" for i in top_indices]

    fig, ax = plt.subplots(figsize=(12, 8))

    # One bar per selected feature
    y_positions = np.arange(len(top_indices))
    ax.barh(y_positions, top_importances, color="steelblue", alpha=0.8)

    ax.set_yticks(y_positions)
    ax.set_yticklabels(bar_labels, fontsize=8)
    ax.invert_yaxis()  # Highest importance at top

    ax.set_title(f"Top {top_k} Most Important Embedding Dimensions")
    ax.set_xlabel("Feature Importance (Mean Decrease in Impurity)")
    ax.set_ylabel("Embedding Dimension")

    # Save as vector PDF and release the figure
    fig.tight_layout()
    fig.savefig(output_path, format="pdf", bbox_inches="tight")
    plt.close(fig)

    print(f"[DONE] Saved feature importance plot: {output_path}")
|
||||
|
||||
|
||||
def save_feature_importance_csv(
    feature_importances: np.ndarray,
    output_path: str,
) -> None:
    """
    Write a ranked table of feature importances to a CSV file.

    The CSV has the columns ``rank``, ``feature`` and ``importance``. Rows
    are ordered by descending importance, ``rank`` starts at 1, and each
    feature is named ``emb_<index>`` after its position in the input array.

    Args:
        feature_importances: Array of feature importance scores from Random Forest
        output_path: File path to save the CSV file
    """
    feature_names = [f"emb_{i}" for i, _ in enumerate(feature_importances)]

    # Sort by score, then discard the old index so positions equal ranks - 1
    ranked = (
        pd.DataFrame({"feature": feature_names, "importance": feature_importances})
        .sort_values("importance", ascending=False)
        .reset_index(drop=True)
    )

    # Put the 1-based rank in the first column
    ranked.insert(0, "rank", range(1, len(ranked) + 1))
    ranked.to_csv(output_path, index=False)

    print(f"[DONE] Saved feature importance CSV: {output_path}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Command Line Interface
|
||||
# =============================================================================
|
||||
|
||||
def main(
    embeddings: str = "./embeddings/REM",
    out_dir: str = "./results/REM",
    test_size: float = 0.2,
    n_estimators: int = 200,
    max_depth: int = None,
) -> None:
    """
    Train and evaluate a Random Forest that predicts the user from an embedding.

    Loads the embeddings dataset, splits it (by session when both session
    "1" and session "2" are present, otherwise randomly), computes the
    silhouette score of the user clusters on the full data, fits the
    classifier pipeline, and writes a metrics report, a confusion matrix
    plot, and the feature importances (plot + CSV) into ``out_dir``.

    Args:
        embeddings: Path passed to load_embeddings_with_metadata
        out_dir: Directory for all output files (created if missing)
        test_size: Test fraction, only used by the random fallback split
        n_estimators: Number of trees for the Random Forest
        max_depth: Maximum tree depth (None = unlimited)
    """
    os.makedirs(out_dir, exist_ok=True)

    print(f"[INFO] Loading embeddings from: {embeddings}")
    dataset = load_embeddings_with_metadata(embeddings)

    # Features are the embeddings; targets are user ids compared as strings
    X = np.array(dataset["embedding"], dtype=np.float32)
    y = np.array([str(uid) for uid in dataset["user_id"]])
    session_ids = np.array([str(sid) for sid in dataset["session_id"]])

    print(f"[INFO] Loaded {len(y)} samples with {X.shape[1]}-dimensional embeddings")
    print(f"[INFO] Number of unique users: {len(np.unique(y))}")

    # The session-based split needs samples from both session "1" and "2"
    both_sessions_present = (
        (session_ids == "1").sum() > 0 and (session_ids == "2").sum() > 0
    )
    if not both_sessions_present:
        print("[WARN] Missing session data, falling back to random split")
        X_train, X_test, y_train, y_test, split_desc = split_random(
            X, y, test_size
        )
        # Report the fallback explicitly instead of the helper's description
        split_desc = "random(fallback)"
    else:
        X_train, X_test, y_train, y_test, split_desc = split_by_session(
            X, y, session_ids
        )

    print(f"[INFO] Split: {split_desc}")
    print(f"[INFO] Train size: {len(y_train)}, Test size: {len(y_test)}")

    # Cluster quality of the raw embeddings, measured on all samples
    silhouette_avg = silhouette_score(X, y)
    print(f"[INFO] Silhouette Score (user clusters): {silhouette_avg:.4f}")

    print("[INFO] Training Random Forest classifier...")
    print(f"[INFO] Hyperparameters: n_estimators={n_estimators}, max_depth={max_depth}")
    classifier = create_classifier_pipeline(
        n_estimators=n_estimators,
        max_depth=max_depth,
    )
    classifier.fit(X_train, y_train)
    y_pred = classifier.predict(X_test)

    # Pull forest-specific diagnostics out of the fitted pipeline
    rf_model = classifier.named_steps["classifier"]
    oob_score = getattr(rf_model, "oob_score_", None)
    feature_importances = rf_model.feature_importances_

    print("[INFO] Evaluating classifier performance...")
    accuracy, macro_f1, report, cm, classes = evaluate_classifier(y_test, y_pred)

    # Persist all artifacts under out_dir
    save_metrics_report(
        output_path=os.path.join(out_dir, "user_cls_metrics.txt"),
        split_description=split_desc,
        train_size=len(y_train),
        test_size=len(y_test),
        accuracy=accuracy,
        macro_f1=macro_f1,
        classification_report_str=report,
        oob_score=oob_score,
        silhouette=silhouette_avg,
    )
    save_confusion_matrix_plot(
        cm, classes, os.path.join(out_dir, "user_cls_confusion.pdf")
    )
    save_feature_importance_plot(
        feature_importances, os.path.join(out_dir, "feature_importance.pdf")
    )
    save_feature_importance_csv(
        feature_importances, os.path.join(out_dir, "feature_importance.csv")
    )

    # Console summary
    banner = "=" * 50
    print("\n" + banner)
    print("RANDOM FOREST CLASSIFICATION RESULTS")
    print(banner)
    print(f"Split Strategy : {split_desc}")
    print(f"Silhouette : {silhouette_avg:.4f}")
    print(f"Accuracy : {accuracy:.4f}")
    print(f"Macro F1 : {macro_f1:.4f}")
    if oob_score is not None:
        print(f"OOB Score : {oob_score:.4f}")
    print("Top 5 Important Features:")
    top5_idx = np.argsort(feature_importances)[::-1][:5]
    for rank, idx in enumerate(top5_idx, 1):
        print(f" {rank}. emb_{idx}: {feature_importances[idx]:.4f}")
    print(banner)
|
||||
|
||||
|
||||
# Script entry point: when run directly, hand main() to Fire so its keyword
# parameters can be supplied from the command line.
if __name__ == "__main__":
    Fire(main)
|
||||
|
||||
@@ -1,756 +0,0 @@
|
||||
"""
|
||||
SBERT Metadata-Based Embedding Extraction and Visualization Pipeline
|
||||
|
||||
This module provides functionality to:
|
||||
1. Extract embeddings from user metadata (age and sex) using SBERT (Sentence-BERT)
|
||||
2. Visualize embeddings using dimensionality reduction (t-SNE) colored by age
|
||||
|
||||
SBERT Overview:
|
||||
SBERT (Sentence-BERT) is a modification of the BERT model that uses siamese
|
||||
and triplet network structures to derive semantically meaningful sentence
|
||||
embeddings. It's designed for semantic similarity tasks and produces
|
||||
fixed-size dense vector representations of text.
|
||||
|
||||
Key Features:
|
||||
- Semantic understanding: Captures meaning rather than just word presence
|
||||
- Fixed-size embeddings: Outputs consistent vector dimensions (384 for all-MiniLM-L6-v2)
|
||||
- Efficient: Optimized for sentence-level tasks
|
||||
|
||||
Embedding Strategy:
|
||||
Instead of using time series features, we create textual descriptions based on
|
||||
user metadata (age and sex). This approach allows us to capture user-level
|
||||
characteristics in the embedding space.
|
||||
|
||||
Processing Pipeline:
|
||||
1. Load user metadata from XLS file (age, sex)
|
||||
2. Textualization: Convert metadata to natural language description
|
||||
3. SBERT encoding: Generate 384-dimensional embeddings
|
||||
4. Visualization: t-SNE with continuous age coloring
|
||||
|
||||
Usage:
|
||||
# Extract embeddings from metadata for all labels
|
||||
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/all_labels
|
||||
|
||||
# Extract embeddings for a single label
|
||||
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/REM --label REM
|
||||
|
||||
# Visualize with t-SNE from all label directories (colored by age)
|
||||
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by age
|
||||
|
||||
# Visualize by sex from all labels
|
||||
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by sex
|
||||
|
||||
# Create both age and sex plots from all labels
|
||||
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by both
|
||||
|
||||
# Visualize from a single label directory
|
||||
python gen_plot.py plot --emb_dir ./embeddings/REM --out_dir ./plots/REM --color_by age
|
||||
|
||||
Author: Sumin Im (NMSL Undergraduate Researcher)
|
||||
Date: 2026-01-09
|
||||
"""
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
from typing import Dict, Any, List, Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from datasets import load_from_disk, Dataset, concatenate_datasets
|
||||
from fire import Fire
|
||||
from sklearn.manifold import TSNE
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# =============================================================================
|
||||
# Constants
|
||||
# =============================================================================
|
||||
|
||||
SUBJECT_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/SC-subjects.xls"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Metadata Loading
|
||||
# =============================================================================
|
||||
|
||||
def load_subject_metadata(subject_path: str = SUBJECT_PATH) -> Dict[str, Dict[str, Any]]:
    """
    Load subject metadata from XLS file.

    Reads the spreadsheet with pandas and uses three of its columns:
    - "subject": subject identifier, used (as a string) for the dict key
    - "age": age of the subject
    - "sex (F=1)": sex code (1 = Female, 0 = Male per the column name)

    If a subject appears in several rows, the last row wins.

    Args:
        subject_path: Path to the SC-subjects.xls file

    Returns:
        Dictionary mapping the stringified subject identifier to
        ``{"age": int | None, "sex": int | None}``; a value is None when the
        corresponding cell is empty. Rows without a subject identifier are
        skipped.
    """
    df = pd.read_excel(subject_path)
    subject_info = {}

    # Iterate column-wise rather than with iterrows(): iterrows() builds one
    # Series per row and may widen integer cells to float, which would turn
    # the subject id 12 into the key "12.0".
    for subject, age, sex in zip(df["subject"], df["age"], df["sex (F=1)"]):
        if pd.isna(subject):
            # No identifier to key on; previously this produced a "nan" key
            continue

        # A float-typed id column (e.g. caused by blank rows) still has to
        # yield integer-looking keys such as "12".
        if isinstance(subject, float) and subject.is_integer():
            subject_id = str(int(subject))
        else:
            subject_id = str(subject).strip()

        subject_info[subject_id] = {
            "age": int(age) if pd.notna(age) else None,
            "sex": int(sex) if pd.notna(sex) else None,
        }

    return subject_info
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Embedding Extractor Class
|
||||
# =============================================================================
|
||||
|
||||
class SBERT_Metadata:
    """
    Extracts fixed-dimensional embeddings from user metadata using SBERT.

    Each sample's user metadata (age, sex) is rendered as one English
    sentence and encoded with a SentenceTransformer model.

    Architecture:
        User Metadata → Textualization → SBERT Encoder → Embedding

    Model: all-MiniLM-L6-v2 (described by this module as producing
    384-dimensional embeddings).

    Attributes:
        model: SentenceTransformer instance for encoding textualized metadata
    """

    def __init__(self):
        """Load the pre-trained "all-MiniLM-L6-v2" SentenceTransformer model."""
        self.model = SentenceTransformer("all-MiniLM-L6-v2")

    @staticmethod
    def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
        """
        Discover all user/session directories under data_root.

        Expected structure: data_root/user_id/session_id/

        Args:
            data_root: Root directory containing user/session subfolders

        Returns:
            List of (user_id, session_id, session_path) tuples, sorted by path.
            Non-directory entries at the session level are ignored.
        """
        # Two "*" components match exactly the user/session nesting depth
        return [
            (
                os.path.basename(os.path.dirname(session_path)),
                os.path.basename(session_path),
                session_path,
            )
            for session_path in sorted(glob(os.path.join(data_root, "*", "*")))
            if os.path.isdir(session_path)
        ]

    def textualize_metadata(self, age: Optional[int], sex: Optional[int]) -> str:
        """
        Convert user metadata (age and sex) into a natural language description.

        Args:
            age: Age of the user (None is rendered as "unknown")
            sex: Sex code; 1 is rendered as "Female", None as "Unknown", and
                any other value as "Male"

        Returns:
            Sentence describing the user metadata
        """
        if sex is None:
            sex_text = "Unknown"
        else:
            sex_text = "Female" if sex == 1 else "Male"

        age_text = "unknown" if age is None else age
        return f"This is the information of the user, age: {age_text}, sex: {sex_text}."

    def compute_embedding_from_metadata(
        self,
        ages: List[Optional[int]],
        sexes: List[Optional[int]]
    ) -> np.ndarray:
        """
        Generate embedding vectors from metadata using SBERT.

        Args:
            ages: List of age values (may contain None)
            sexes: List of sex values (0 = Male, 1 = Female, may contain None)

        Returns:
            The result of ``self.model.encode`` on one sentence per
            (age, sex) pair; expected to be an array of shape
            (batch_size, embedding_dim).
        """
        text_samples = [
            self.textualize_metadata(age, sex) for age, sex in zip(ages, sexes)
        ]
        return self.model.encode(text_samples)

    def extract_embeddings(
        self,
        data_root: str,
        subject_path: str = SUBJECT_PATH,
        batch_size: int = 32,
        label: Optional[str] = None
    ) -> Dataset:
        """
        Extract embeddings from user metadata for all sessions.

        For every user/session dataset under ``data_root`` the samples are
        shuffled (seed 0), optionally filtered to one label, and paired with
        the user's age and sex from the subject spreadsheet. One embedding is
        then computed per sample from that metadata.

        Args:
            data_root: Root directory containing user/session data folders
            subject_path: Path to SC-subjects.xls file with metadata
            batch_size: Number of samples encoded per call to the model
            label: Sleep stage label to filter (e.g., "W", "N1", "N2", "N3", "REM").
                If None or "all", processes all labels.

        Returns:
            HuggingFace Dataset with columns:
                - user_id, session_id, idx, label (copied from the source data)
                - age, sex (from the subject metadata; None when unavailable)
                - embedding (list of floats produced by the SBERT model)
        """
        print(f"[INFO] Loading subject metadata from: {subject_path}")
        subject_metadata = load_subject_metadata(subject_path)
        print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")

        session_paths = self.discover_session_paths(data_root)
        print(f"[INFO] Discovered {len(session_paths)} sessions")

        all_embeddings = []
        all_user_ids = []
        all_session_ids = []
        all_idxs = []
        all_labels = []
        all_ages = []
        all_sexes = []

        # Collect metadata for all samples first
        for user_id, session_id, session_path in session_paths:
            dataset = load_from_disk(session_path)
            dataset = dataset.shuffle(seed=0)

            if label is not None and label != "all":
                dataset = dataset.filter(lambda x: x["label"] == label)

            num_samples = len(dataset)
            if num_samples == 0:
                continue

            print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")

            # Metadata is keyed by the un-padded id ("07" -> "7"); fall back
            # to the raw folder name when it is not numeric instead of crashing.
            try:
                user_id_str = str(int(user_id))
            except ValueError:
                user_id_str = user_id

            try:
                age = subject_metadata[user_id_str]["age"]
                sex = subject_metadata[user_id_str]["sex"]
            except KeyError:
                age = None
                sex = None
                print(f"[WARN] No metadata found for user_id: {user_id_str}")

            # Read each column once per session. Indexing dataset["col"][i]
            # inside a per-sample loop re-reads the whole column every time.
            all_user_ids.extend(str(value) for value in dataset["user_id"])
            all_session_ids.extend(str(value) for value in dataset["session_id"])
            all_idxs.extend(int(value) for value in dataset["idx"])
            all_labels.extend(str(value) for value in dataset["label"])
            all_ages.extend([age] * num_samples)
            all_sexes.extend([sex] * num_samples)

        # Generate embeddings from metadata in batches
        print(f"[INFO] Generating embeddings from metadata for {len(all_ages)} samples...")
        for batch_start in range(0, len(all_ages), batch_size):
            batch_end = batch_start + batch_size  # slicing clamps at the end
            embeddings = self.compute_embedding_from_metadata(
                all_ages[batch_start:batch_end],
                all_sexes[batch_start:batch_end],
            )
            # Store rows as plain lists so the Dataset can serialize them
            all_embeddings.extend(embeddings[i].tolist() for i in range(embeddings.shape[0]))

        result_dataset = Dataset.from_dict({
            "user_id": all_user_ids,
            "session_id": all_session_ids,
            "idx": all_idxs,
            "label": all_labels,
            "age": all_ages,
            "sex": all_sexes,
            "embedding": all_embeddings,
        })

        print(f"[INFO] Total samples: {len(result_dataset)}")

        # Label distribution, in order of first appearance
        label_counts = {}
        for lbl in all_labels:
            label_counts[lbl] = label_counts.get(lbl, 0) + 1
        print(f"[INFO] Label distribution: {label_counts}")

        return result_dataset

    def save_embeddings(
        self,
        dataset: Dataset,
        output_dir: str
    ) -> None:
        """
        Save embeddings dataset to disk in HuggingFace format.

        Args:
            dataset: HuggingFace Dataset containing embeddings and metadata
            output_dir: Directory path to save the dataset (created if missing)
        """
        os.makedirs(output_dir, exist_ok=True)
        dataset.save_to_disk(output_dir)

        # An empty dataset has no first row to measure; report dimension 0
        embedding_dim = len(dataset[0]['embedding']) if len(dataset) > 0 else 0

        print(f"[DONE] Saved embeddings dataset: {output_dir}")
        print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {embedding_dim}")

    @staticmethod
    def load_embeddings(embedding_dir: str) -> Dataset:
        """
        Load saved embeddings dataset from disk.

        Args:
            embedding_dir: Directory path containing the saved HuggingFace dataset

        Returns:
            HuggingFace Dataset with embeddings and metadata
        """
        dataset = load_from_disk(embedding_dir)
        print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
        return dataset
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data Loading Utilities
|
||||
# =============================================================================
|
||||
|
||||
def load_embeddings_from_all_labels(embeddings_root: str) -> Dataset:
    """
    Load the embeddings of every label directory and merge them.

    A subdirectory of ``embeddings_root`` counts as a label directory when it
    contains a ``dataset_info.json`` file. Directories are processed in
    sorted name order and their datasets are concatenated in that order.

    Args:
        embeddings_root: Root directory containing label subdirectories
            (e.g., "./embeddings" containing W/, REM/, N1/, etc.)

    Returns:
        Concatenated HuggingFace Dataset with all labels combined

    Raises:
        ValueError: If no subdirectory qualifies as a saved dataset.
    """
    # Pair every entry with its full path, then keep saved-dataset folders only
    candidates = (
        (entry, os.path.join(embeddings_root, entry))
        for entry in os.listdir(embeddings_root)
    )
    label_dirs = [
        (entry, entry_path)
        for entry, entry_path in candidates
        if os.path.isdir(entry_path)
        and os.path.exists(os.path.join(entry_path, "dataset_info.json"))
    ]

    if not label_dirs:
        raise ValueError(
            f"No valid HuggingFace dataset directories found in: {embeddings_root}"
        )

    label_dirs.sort()  # Sort for consistent ordering
    print(f"[INFO] Discovered {len(label_dirs)} label directories: {[ld[0] for ld in label_dirs]}")

    # Load one dataset per label directory
    datasets = []
    for label_name, label_path in label_dirs:
        print(f"[INFO] Loading embeddings from: {label_path}")
        loaded = load_from_disk(label_path)
        print(f"[INFO] Label: {label_name}, Samples: {len(loaded)}")
        datasets.append(loaded)

    # A single dataset is returned as-is; several are concatenated
    combined_dataset = (
        datasets[0] if len(datasets) == 1 else concatenate_datasets(datasets)
    )
    print(f"[INFO] Combined dataset: {len(combined_dataset)} total samples")

    # Report how many samples each label contributes
    if "label" in combined_dataset.column_names:
        label_counts = {}
        for value in combined_dataset["label"]:
            if value in label_counts:
                label_counts[value] += 1
            else:
                label_counts[value] = 1
        print(f"[INFO] Label distribution: {label_counts}")

    return combined_dataset
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Visualization Functions
|
||||
# =============================================================================
|
||||
|
||||
def reduce_to_2d_tsne(
    embeddings: np.ndarray,
    perplexity: float = 30.0
) -> np.ndarray:
    """
    Reduce high-dimensional embeddings to 2D using t-SNE.

    Runs sklearn's TSNE with a fixed random state (0), random initialization,
    automatic learning rate and 1000 iterations.

    Args:
        embeddings: High-dimensional array of shape (num_samples, embedding_dim)
        perplexity: t-SNE perplexity parameter (typically 5-50).
            sklearn requires it to be smaller than the number of samples, so
            for small inputs it is lowered to ``num_samples - 1`` and a
            warning is printed.

    Returns:
        2D coordinates of shape (num_samples, 2)
    """
    num_samples = len(embeddings)

    # Keep small datasets plottable: the default of 30 would otherwise make
    # TSNE reject any input with 30 or fewer samples.
    effective_perplexity = perplexity
    if num_samples > 1 and perplexity >= num_samples:
        effective_perplexity = float(num_samples - 1)
        print(
            f"[WARN] perplexity={perplexity} must be below num_samples={num_samples}; "
            f"using {effective_perplexity} instead"
        )

    print(f"[INFO] Running t-SNE with perplexity={effective_perplexity}...")

    tsne = TSNE(
        n_components=2,
        random_state=0,         # For reproducibility
        perplexity=effective_perplexity,
        max_iter=1000,
        init='random',
        learning_rate='auto',   # Let sklearn choose the learning rate
    )

    return tsne.fit_transform(embeddings)
|
||||
|
||||
|
||||
def create_scatter_plot_by_age(
    coordinates: np.ndarray,
    ages: np.ndarray,
    title: str,
    output_path: str
) -> None:
    """
    Draw a 2D scatter plot whose point colors encode age, and save it as PDF.

    Points whose age is None (or NaN) are left out. If no point has a usable
    age, a warning is printed and nothing is written.

    Args:
        coordinates: 2D array of shape (num_points, 2)
        ages: Age values for each point (array, may contain None values)
        title: Plot title
        output_path: File path to save the plot (always written as PDF)
    """
    # None cannot live in a float array, so map it to NaN and mask on that
    ages_float = np.array([np.nan if a is None else float(a) for a in ages])
    valid_mask = np.logical_not(np.isnan(ages_float))
    valid_ages = ages_float[valid_mask]
    valid_coords = coordinates[valid_mask]

    if len(valid_ages) == 0:
        print(f"[WARN] No valid age data found. Skipping plot: {output_path}")
        return

    fig, ax = plt.subplots(figsize=(10, 8))

    # Continuous colormap: color varies smoothly with age
    x_values = valid_coords[:, 0]
    y_values = valid_coords[:, 1]
    scatter = ax.scatter(
        x_values,
        y_values,
        c=valid_ages,
        cmap='viridis',
        s=15,
        alpha=0.7,
        edgecolors='none',
    )

    # Colorbar acts as the legend for the age scale
    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label('Age (years)', rotation=270, labelpad=20)

    ax.set_title(title)
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")

    # Save figure as vector PDF and release it
    fig.tight_layout()
    fig.savefig(output_path, format='pdf', bbox_inches='tight')
    plt.close(fig)

    print(f"[DONE] Saved plot: {output_path}")
    print(f"[INFO] Age range: {valid_ages.min():.0f} - {valid_ages.max():.0f} years")
    print(f"[INFO] Points with valid age: {len(valid_ages)}/{len(ages)}")
|
||||
|
||||
|
||||
def create_scatter_plot_by_sex(
    coordinates: np.ndarray,
    sexes: np.ndarray,
    title: str,
    output_path: str
) -> None:
    """
    Create and save a 2D scatter plot colored by sex (categorical).

    Each sex code present in the data is drawn as its own scatter series so
    that it gets a legend entry. Points whose sex is None or NaN are dropped.
    Sex encoding: 1 = Female, 0 = Male; any other code is labelled "Sex <code>".

    Args:
        coordinates: 2D array of shape (num_points, 2)
        sexes: Sex values for each point (array, may contain None values)
        title: Plot title
        output_path: File path to save the plot (written in PDF format)

    Returns:
        None. When no point has a valid sex value a warning is printed and
        no file is written.
    """
    # Convert None values to NaN so that missing entries can be masked out
    sexes_float = np.array(
        [np.nan if s is None else float(s) for s in sexes], dtype=float
    )
    valid_mask = ~np.isnan(sexes_float)
    valid_coords = np.asarray(coordinates)[valid_mask]
    valid_sexes = sexes_float[valid_mask].astype(int)

    if valid_sexes.size == 0:
        print(f"[WARN] No valid sex data found. Skipping plot: {output_path}")
        return

    # Map sex codes to labels
    sex_labels = {0: "Male", 1: "Female"}
    # Plain ints keep the legend labels and the logged counts free of numpy reprs
    unique_sexes = [int(code) for code in np.unique(valid_sexes)]
    colors = ['steelblue', 'coral']  # Blue for Male, Coral for Female

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Plot each category separately for proper legend
        for sex_code in unique_sexes:
            mask = valid_sexes == sex_code
            ax.scatter(
                valid_coords[mask, 0],
                valid_coords[mask, 1],
                c=colors[sex_code % len(colors)],
                s=15,
                label=sex_labels.get(sex_code, f"Sex {sex_code}"),
                alpha=0.7,
            )

        # Configure plot appearance
        ax.set_title(title)
        ax.set_xlabel("Dimension 1")
        ax.set_ylabel("Dimension 2")
        ax.legend(loc='best', markerscale=2)

        # Save this specific figure (not pyplot's "current" one) as vector PDF
        fig.tight_layout()
        fig.savefig(output_path, format='pdf', bbox_inches='tight')
    finally:
        # Always release the figure, even when saving fails
        plt.close(fig)

    print(f"[DONE] Saved plot: {output_path}")
    sex_counts = {
        sex_labels.get(code, f"Sex {code}"): int((valid_sexes == code).sum())
        for code in unique_sexes
    }
    print(f"[INFO] Sex distribution: {sex_counts}")
    print(f"[INFO] Points with valid sex: {len(valid_sexes)}/{len(sexes)}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Command Line Interface
|
||||
# =============================================================================
|
||||
|
||||
class CLI:
    """
    Command-line interface for SBERT metadata-based embedding extraction and visualization.

    Provides two main commands:
        - extract: Generate embeddings from user metadata (age, sex)
        - plot: Visualize embeddings with t-SNE colored by age or sex
    """

    @staticmethod
    def _split_csv_arg(value) -> list:
        """
        Normalize a comma-separated CLI argument into a list of strings.

        Python Fire parses argument values as Python literals where it can,
        so a value such as ``W,N1`` or ``10,11`` may arrive here as a tuple
        and ``5`` as an int rather than as the raw string. All of these
        shapes are accepted.

        Args:
            value: None, a string, a number, or a list/tuple/set of those

        Returns:
            List of stripped, non-empty strings (empty list for None)
        """
        if value is None:
            return []
        if isinstance(value, (list, tuple, set)):
            items = value
        else:
            items = str(value).split(",")
        cleaned = (str(item).strip() for item in items)
        return [item for item in cleaned if item]

    def extract(
        self,
        data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_full",
        subject_path: str = SUBJECT_PATH,
        out_dir: str = "./embeddings/all_labels",
        batch_size: int = 32,
        label: str = None
    ) -> None:
        """
        Extract embeddings from user metadata.

        Args:
            data_root: Root directory containing user/session data folders
            subject_path: Path to SC-subjects.xls file with metadata
            out_dir: Output directory for HuggingFace dataset
            batch_size: Batch size for inference (default: 32)
            label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM').
                If None or 'all', processes all labels (default: None for all labels)
        """
        embedder = SBERT_Metadata()
        dataset = embedder.extract_embeddings(data_root, subject_path, batch_size, label)
        embedder.save_embeddings(dataset, out_dir)

    def plot(
        self,
        emb_dir: str = None,
        embeddings_root: str = "./embeddings",
        out_dir: str = "./plots/all_labels",
        perplexity: float = 10.0,
        color_by: str = "age",
        users: str = None,
        num_users: int = 0,
        labels: str = None,
    ) -> None:
        """
        Visualize embeddings with t-SNE, colored by age or sex.

        Can load from either a single label directory or all label directories.

        Args:
            emb_dir: Single directory containing the HuggingFace embeddings dataset
                (e.g., "./embeddings/REM"). If provided, only this directory is used.
            embeddings_root: Root directory containing label subdirectories
                (e.g., "./embeddings" containing W/, REM/, N1/, etc.).
                Used only if emb_dir is not provided.
            out_dir: Output directory for visualization plots (PDF)
            perplexity: t-SNE perplexity parameter (default: 10.0)
            color_by: What to color by - 'age', 'sex', or 'both'
                ('all' is accepted as an alias of 'both'; default: 'age')
            users: Comma-separated user IDs to include (e.g., '00,01,02')
            num_users: Include only first N users, 0 = all (default: 0).
                Ignored when `users` is given.
            labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
                This filters the already-loaded data, not which directories to load.

        Raises:
            ValueError: If color_by is not a supported value, or if no sample
                is left after filtering.
        """
        # Validate color_by argument
        if color_by not in ("age", "sex", "both", "all"):
            raise ValueError(
                f"Invalid color_by: {color_by}. Use 'age', 'sex', 'both', or 'all'."
            )

        os.makedirs(out_dir, exist_ok=True)

        # Load embeddings: either from single directory or all label directories
        if emb_dir is not None:
            print(f"[INFO] Loading embeddings from single directory: {emb_dir}")
            dataset = SBERT_Metadata.load_embeddings(emb_dir)
        else:
            print(f"[INFO] Loading embeddings from all label directories in: {embeddings_root}")
            dataset = load_embeddings_from_all_labels(embeddings_root)

        # Apply user filtering (an explicit list takes precedence over num_users)
        user_list = CLI._split_csv_arg(users)
        if user_list:
            dataset = dataset.filter(lambda x: x["user_id"] in user_list)
            print(f"[INFO] Filtered to users: {user_list}")
        elif num_users > 0:
            all_users = sorted(set(dataset["user_id"]))
            selected_users = all_users[:num_users]
            dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
            print(f"[INFO] Selected first {num_users} users: {selected_users}")

        # Filter by sleep stage labels
        label_list = CLI._split_csv_arg(labels)
        if label_list:
            dataset = dataset.filter(lambda x: x["label"] in label_list)
            print(f"[INFO] Filtered to labels: {label_list}")

        print(f"[INFO] Total samples: {len(dataset)}")
        if len(dataset) == 0:
            # Fail with a clear message instead of an opaque error inside t-SNE
            raise ValueError(
                "No samples left to plot after filtering; "
                "check the users/num_users/labels arguments."
            )

        want_age = color_by != "sex"
        want_sex = color_by != "age"

        # Read only the columns that will be plotted, and do so before the
        # expensive t-SNE step so a missing column is reported early
        ages = np.array(dataset["age"]) if want_age else None
        sexes = np.array(dataset["sex"]) if want_sex else None

        # Extract embeddings as numpy array and reduce to 2D with t-SNE
        embeddings = np.array(dataset["embedding"])
        coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)

        if want_age:
            create_scatter_plot_by_age(
                coordinates_2d,
                ages,
                "t-SNE Visualization (Colored by Age)",
                os.path.join(out_dir, "tsne_by_age.pdf")
            )
        if want_sex:
            create_scatter_plot_by_sex(
                coordinates_2d,
                sexes,
                "t-SNE Visualization (Colored by Sex)",
                os.path.join(out_dir, "tsne_by_sex.pdf")
            )


if __name__ == "__main__":
    Fire(CLI)
|
||||
@@ -1,748 +0,0 @@
|
||||
"""
|
||||
SBERT Metadata-Based Embedding Extraction and Visualization Pipeline
|
||||
|
||||
This module provides functionality to:
|
||||
1. Extract embeddings from user metadata (sex, age, height, weight, blood pressure, heart rate, BMI, hypertension) using SBERT (Sentence-BERT)
|
||||
2. Visualize embeddings using dimensionality reduction (t-SNE) colored by age
|
||||
|
||||
SBERT Overview:
|
||||
SBERT (Sentence-BERT) is a modification of the BERT model that uses siamese
|
||||
and triplet network structures to derive semantically meaningful sentence
|
||||
embeddings. It's designed for semantic similarity tasks and produces
|
||||
fixed-size dense vector representations of text.
|
||||
|
||||
Key Features:
|
||||
- Semantic understanding: Captures meaning rather than just word presence
|
||||
- Fixed-size embeddings: Outputs consistent vector dimensions (384 for all-MiniLM-L6-v2)
|
||||
- Efficient: Optimized for sentence-level tasks
|
||||
|
||||
Embedding Strategy:
|
||||
Instead of using time series features, we create textual descriptions based on
|
||||
user metadata (age and sex). This approach allows us to capture user-level
|
||||
characteristics in the embedding space.
|
||||
|
||||
Processing Pipeline:
|
||||
1. Load user metadata from the spreadsheet file (sex, age, height, weight, blood pressure, heart rate, BMI, hypertension)
|
||||
2. Textualization: Convert metadata to natural language description
|
||||
3. SBERT encoding: Generate 384-dimensional embeddings
|
||||
4. Visualization: t-SNE with continuous age coloring
|
||||
|
||||
Usage:
|
||||
# Extract embeddings from metadata for all labels
|
||||
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/all_labels
|
||||
|
||||
# Extract embeddings for a single label
|
||||
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/REM --label REM
|
||||
|
||||
# Visualize with t-SNE from all label directories (colored by age)
|
||||
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels
|
||||
|
||||
# Visualize from a single label directory
|
||||
python gen_plot.py plot --emb_dir ./embeddings/REM --out_dir ./plots/REM
|
||||
|
||||
Author: Sumin Im (NMSL Undergraduate Researcher)
|
||||
Date: 2026-01-09
|
||||
"""
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
from typing import Dict, Any, List, Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from datasets import load_from_disk, Dataset, concatenate_datasets
|
||||
from fire import Fire
|
||||
from sklearn.manifold import TSNE
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# =============================================================================
|
||||
# Constants
|
||||
# =============================================================================
|
||||
|
||||
# Default path to the subject metadata spreadsheet; used as the default
# `subject_path` argument by the loaders and CLI commands in this file.
# It is a relative path, so it only resolves from a matching working directory.
SUBJECT_PATH = "LLM_Health/tsllm_personalization_icl/analysis/user_similarity/sbert_metadata_ppgbp/PPGBP_metadata.xlsx"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Metadata Loading
|
||||
# =============================================================================
|
||||
|
||||
def load_subject_metadata(subject_path: str = SUBJECT_PATH) -> Dict[str, Dict[str, Any]]:
    """
    Read the subject spreadsheet and index its rows by subject ID.

    The sheet is read with its second row as the header (``header=1``).
    For every row the following columns are converted; an empty cell
    becomes None:
        - "Sex(M/F)"                        -> "sex" (stripped string)
        - "Age(year)"                       -> "age" (int)
        - "Height(cm)"                      -> "height" (int)
        - "Weight(kg)"                      -> "weight" (int)
        - "Systolic Blood Pressure(mmHg)"   -> "sbp" (int)
        - "Diastolic Blood Pressure(mmHg)"  -> "dbp" (int)
        - "Heart Rate(b/m)"                 -> "hr" (int)
        - "BMI(kg/m^2)"                     -> "bmi" (float)
        - "Hypertension"                    -> "hypertension" (string)

    Args:
        subject_path: Path to the subject metadata spreadsheet

    Returns:
        Dictionary mapping the stripped string form of "subject_ID" to the
        per-subject metadata dictionary described above. If an ID occurs
        more than once, the last row wins.
    """
    # (output key, spreadsheet column, converter); the order fixes the key order
    column_spec = (
        ("sex", "Sex(M/F)", lambda cell: str(cell).strip()),
        ("age", "Age(year)", int),
        ("height", "Height(cm)", int),
        ("weight", "Weight(kg)", int),
        ("sbp", "Systolic Blood Pressure(mmHg)", int),
        ("dbp", "Diastolic Blood Pressure(mmHg)", int),
        ("hr", "Heart Rate(b/m)", int),
        ("bmi", "BMI(kg/m^2)", float),
        ("hypertension", "Hypertension", str),
    )

    table = pd.read_excel(subject_path, header=1)

    metadata_by_subject = {}
    for _, record in table.iterrows():
        key = str(record["subject_ID"]).strip()
        metadata_by_subject[key] = {
            name: (convert(record[column]) if pd.notna(record[column]) else None)
            for name, column, convert in column_spec
        }

    return metadata_by_subject
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Embedding Extractor Class
|
||||
# =============================================================================
|
||||
|
||||
class SBERT_Metadata:
    """
    Extracts fixed-dimensional embeddings from user metadata using SBERT.

    Uses Sentence-BERT to convert textualized metadata descriptions into dense
    vector representations. The textualization process converts sex, age, height,
    weight, sbp, dbp, heart rate, bmi, hypertension information into natural language,
    which SBERT then encodes semantically.

    Architecture:
        User Metadata -> Textualization -> SBERT Encoder -> Embedding

    Processing Pipeline:
        1. Load user metadata (sex, age, height, weight, sbp, dbp, heart rate, bmi,
           hypertension) from the subject spreadsheet
        2. Textualize: Convert metadata to natural language description
        3. SBERT encoding: Generate 384-dimensional semantic embeddings
        4. Output: Fixed-size embedding vector per sample

    Model: all-MiniLM-L6-v2
        - Lightweight BERT variant optimized for sentence embeddings
        - 384-dimensional output embeddings

    Attributes:
        model: SentenceTransformer instance for encoding textualized metadata
    """

    def __init__(self):
        """
        Initialize the SBERT embedder with the pre-trained "all-MiniLM-L6-v2" model.
        """
        self.model = SentenceTransformer("all-MiniLM-L6-v2")

    @staticmethod
    def _sex_to_text(sex) -> str:
        """Map a sex value (0/1 code or "M"/"F"/"Male"/"Female" text) to a label."""
        if sex is None:
            return "Unknown"
        if isinstance(sex, str):
            # The spreadsheet loader returns the sex column as text, so a
            # numeric comparison alone would label every subject "Male"
            token = sex.strip().lower()
            if token in ("f", "female", "1"):
                return "Female"
            if token in ("m", "male", "0"):
                return "Male"
            return "Unknown"
        return "Female" if sex == 1 else "Male"

    @staticmethod
    def _value_with_unit(value, unit: str = "") -> str:
        """Render a value followed by its unit, or "unknown" when it is None."""
        if value is None:
            return "unknown"
        return f"{value} {unit}" if unit else f"{value}"

    def textualize_metadata_ppg_bp(self,
                                   sex: Optional[str],
                                   age: Optional[int],
                                   height: Optional[int],
                                   weight: Optional[int],
                                   sbp: Optional[int],
                                   dbp: Optional[int],
                                   heart_rate: Optional[int],
                                   bmi: Optional[float],
                                   hypertension: Optional[str]) -> str:
        """
        Convert user metadata into a natural language description.

        This textualization step is needed because SBERT expects text input.
        Every missing (None) field is rendered as "unknown" ("Unknown" for sex).

        Args:
            sex: Sex of the user, either a 0/1 code (0 = Male, 1 = Female) or
                text such as "M"/"F"/"Male"/"Female" (may be None)
            age: Age of the user (may be None)
            height: Height in centimeters (may be None)
            weight: Weight in kilograms (may be None)
            sbp: Systolic blood pressure in mmHg (may be None)
            dbp: Diastolic blood pressure in mmHg (may be None)
            heart_rate: Heart rate in beats per minute (may be None)
            bmi: Body mass index in kg/m^2 (may be None)
            hypertension: Hypertension status, inserted verbatim (may be None)

        Returns:
            Natural language string describing the user metadata
        """
        fmt = self._value_with_unit
        sex_text = self._sex_to_text(sex)
        age_text = fmt(age)
        height_text = fmt(height, "cm")
        weight_text = fmt(weight, "kg")
        sbp_text = fmt(sbp, "mmHg")
        dbp_text = fmt(dbp, "mmHg")
        # The unit is added exactly once here (no "bpm bpm" / "unknown bpm")
        heart_rate_text = fmt(heart_rate, "bpm")
        bmi_text = fmt(bmi, "kg/m^2")
        hypertension_text = fmt(hypertension)

        return (
            f"This is the information of the user, sex: {sex_text}, age: {age_text}, "
            f"height: {height_text}, weight: {weight_text}, sbp: {sbp_text}, "
            f"dbp: {dbp_text}, heart rate: {heart_rate_text}, bmi: {bmi_text}, "
            f"hypertension: {hypertension_text}."
        )

    def compute_embedding_from_metadata(self,
                                        sexes: List[Optional[int]],
                                        ages: List[Optional[int]],
                                        heights: List[Optional[float]],
                                        weights: List[Optional[float]],
                                        systolics: List[Optional[float]],
                                        diastolics: List[Optional[float]],
                                        heart_rates: List[Optional[float]],
                                        bmis: List[Optional[float]],
                                        hypertensions: List[Optional[int]]) -> np.ndarray:
        """
        Generate embedding vectors from metadata using SBERT.

        The nine lists are consumed in lockstep: element i of each list
        describes sample i. Each sample is textualized and the resulting
        sentences are encoded together.

        Args:
            sexes: List of sex values (may contain None)
            ages: List of age values (may contain None)
            heights: List of height values (may contain None)
            weights: List of weight values (may contain None)
            systolics: List of systolic blood pressure values (may contain None)
            diastolics: List of diastolic blood pressure values (may contain None)
            heart_rates: List of heart rate values (may contain None)
            bmis: List of body mass index values (may contain None)
            hypertensions: List of hypertension values (may contain None)

        Returns:
            Embedding array of shape (batch_size, embedding_dim)
            For all-MiniLM-L6-v2: (batch_size, 384)
        """
        text_samples = [
            self.textualize_metadata_ppg_bp(*sample)
            for sample in zip(sexes, ages, heights, weights, systolics,
                              diastolics, heart_rates, bmis, hypertensions)
        ]
        return self.model.encode(text_samples)

    @staticmethod
    def _find_user_datasets(data_root: str) -> List[Tuple[str, str]]:
        """Return sorted (user_id, path) pairs for the datasets under data_root."""
        # NOTE(review): the original loop iterated an undefined `session_paths`.
        # This assumes every immediate sub-directory of data_root that holds a
        # "dataset_info.json" is one user's saved dataset and that the folder
        # name is the user ID - confirm against the preprocessing output.
        found = []
        for path in sorted(glob(os.path.join(data_root, "*"))):
            if os.path.isdir(path) and os.path.exists(os.path.join(path, "dataset_info.json")):
                found.append((os.path.basename(path), path))
        return found

    @staticmethod
    def _lookup_subject(subject_metadata: Dict[str, Dict[str, Any]],
                        user_id: Any) -> Optional[Dict[str, Any]]:
        """Find a user's metadata, trying the integer-normalized ID first, then the raw ID."""
        candidates = []
        try:
            # Folder names may be zero-padded ("007") - normalize to "7"
            candidates.append(str(int(user_id)))
        except (TypeError, ValueError):
            pass
        candidates.append(str(user_id).strip())
        for key in candidates:
            if key in subject_metadata:
                return subject_metadata[key]
        return None

    def extract_embeddings(
        self,
        data_root: str,
        subject_path: str = SUBJECT_PATH,
        batch_size: int = 32,
        label: Optional[str] = None
    ) -> Dataset:
        """
        Extract embeddings from user metadata for all samples.

        Loads every user dataset found under data_root, optionally keeps only
        the samples of one label, attaches the user's spreadsheet metadata to
        each sample, and encodes the textualized metadata in batches.

        Args:
            data_root: Root directory containing one saved dataset folder per user
            subject_path: Path to the subject metadata spreadsheet
            batch_size: Number of samples encoded together (must be >= 1)
            label: Label value to keep. If None or "all", all labels are processed.

        Returns:
            HuggingFace Dataset with columns:
                - user_id, idx, label (copied from the source datasets)
                - sex, age, height, weight, systolic, diastolic, heart_rate,
                  bmi, hypertension (subject metadata, None when unavailable)
                - embedding (vector produced from the metadata sentence)

        Raises:
            ValueError: If batch_size is smaller than 1 or no user dataset is
                found under data_root.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        # Load subject metadata
        print(f"[INFO] Loading subject metadata from: {subject_path}")
        subject_metadata = load_subject_metadata(subject_path)
        print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")

        session_paths = self._find_user_datasets(data_root)
        if not session_paths:
            raise ValueError(f"No valid HuggingFace dataset directories found in: {data_root}")

        # (key in the loaded metadata, output column). The order matches the
        # positional parameters of compute_embedding_from_metadata.
        field_to_column = (
            ("sex", "sex"),
            ("age", "age"),
            ("height", "height"),
            ("weight", "weight"),
            ("sbp", "systolic"),
            ("dbp", "diastolic"),
            ("hr", "heart_rate"),
            ("bmi", "bmi"),
            ("hypertension", "hypertension"),
        )

        all_user_ids = []
        all_idxs = []
        all_labels = []
        collected = {column: [] for _, column in field_to_column}

        # Collect metadata for all samples first
        for user_id, session_path in session_paths:
            dataset = load_from_disk(session_path)
            # Shuffle with a fixed seed so the sample order is reproducible
            dataset = dataset.shuffle(seed=0)
            if label is not None and label != "all":
                dataset = dataset.filter(lambda x: x["label"] == label)

            num_samples = len(dataset)
            if num_samples == 0:
                continue

            print(f"[INFO] Processing user={user_id}, samples={num_samples}")

            record = self._lookup_subject(subject_metadata, user_id)
            if record is None:
                print(f"[WARN] No metadata found for user_id: {user_id}")
                record = {}

            # Read each column once instead of once per sample
            all_user_ids.extend(str(value) for value in dataset["user_id"])
            all_idxs.extend(int(value) for value in dataset["idx"])
            all_labels.extend(str(value) for value in dataset["label"])

            # The same user-level metadata applies to every sample of the user
            for field, column in field_to_column:
                collected[column].extend([record.get(field)] * num_samples)

        # Generate embeddings from metadata in batches
        total = len(all_user_ids)
        print(f"[INFO] Generating embeddings from metadata for {total} samples...")
        all_embeddings = []
        for start in range(0, total, batch_size):
            window = slice(start, start + batch_size)
            batch = self.compute_embedding_from_metadata(
                *(collected[column][window] for _, column in field_to_column)
            )
            all_embeddings.extend(np.asarray(batch).tolist())

        # Create HuggingFace Dataset
        result_dataset = Dataset.from_dict({
            "user_id": all_user_ids,
            "idx": all_idxs,
            "label": all_labels,
            **collected,
            "embedding": all_embeddings,
        })

        print(f"[INFO] Total samples: {len(result_dataset)}")

        # Print label distribution
        label_counts = {}
        for lbl in all_labels:
            label_counts[lbl] = label_counts.get(lbl, 0) + 1
        print(f"[INFO] Label distribution: {label_counts}")

        return result_dataset

    def save_embeddings(
        self,
        dataset: Dataset,
        output_dir: str
    ) -> None:
        """
        Save embeddings dataset to disk in HuggingFace format.

        Args:
            dataset: HuggingFace Dataset containing embeddings and metadata
            output_dir: Directory path to save the dataset (created if missing)
        """
        os.makedirs(output_dir, exist_ok=True)
        dataset.save_to_disk(output_dir)

        print(f"[DONE] Saved embeddings dataset: {output_dir}")
        if len(dataset) > 0:
            print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
        else:
            # An empty dataset has no first row to measure the dimension from
            print("[DONE] Total samples: 0")

    @staticmethod
    def load_embeddings(embedding_dir: str) -> Dataset:
        """
        Load saved embeddings dataset from disk.

        Args:
            embedding_dir: Directory path containing the saved HuggingFace dataset

        Returns:
            HuggingFace Dataset with embeddings and metadata
        """
        dataset = load_from_disk(embedding_dir)
        print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
        return dataset
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data Loading Utilities
|
||||
# =============================================================================
|
||||
|
||||
def load_embeddings_from_all_labels(embeddings_root: str) -> Dataset:
    """
    Load the embeddings of every label sub-directory and merge them.

    A sub-directory of embeddings_root counts as a label directory when it
    contains a "dataset_info.json" file. The directories are processed in
    sorted name order and their datasets are concatenated in that order.

    Args:
        embeddings_root: Root directory containing label subdirectories
            (e.g., "./embeddings" containing W/, REM/, N1/, etc.)

    Returns:
        HuggingFace Dataset with the samples of all label directories

    Raises:
        ValueError: If no label directory is found under embeddings_root
    """
    # Sort first, then keep the entries that look like saved datasets
    entries = sorted(
        (name, os.path.join(embeddings_root, name))
        for name in os.listdir(embeddings_root)
    )
    found = [
        (name, path)
        for name, path in entries
        if os.path.isdir(path)
        and os.path.exists(os.path.join(path, "dataset_info.json"))
    ]

    if not found:
        raise ValueError(
            f"No valid HuggingFace dataset directories found in: {embeddings_root}"
        )

    print(f"[INFO] Discovered {len(found)} label directories: {[name for name, _ in found]}")

    # Load one dataset per label directory
    parts = []
    for name, path in found:
        print(f"[INFO] Loading embeddings from: {path}")
        part = load_from_disk(path)
        print(f"[INFO] Label: {name}, Samples: {len(part)}")
        parts.append(part)

    # A single dataset is returned as-is; several are concatenated
    merged = parts[0] if len(parts) == 1 else concatenate_datasets(parts)

    print(f"[INFO] Combined dataset: {len(merged)} total samples")

    # Print label distribution
    if "label" in merged.column_names:
        tally = {}
        for value in merged["label"]:
            tally.setdefault(value, 0)
            tally[value] += 1
        print(f"[INFO] Label distribution: {tally}")

    return merged
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Visualization Functions
|
||||
# =============================================================================
|
||||
|
||||
def reduce_to_2d_tsne(
    embeddings: np.ndarray,
    perplexity: float = 30.0
) -> np.ndarray:
    """
    Reduce high-dimensional embeddings to 2D using t-SNE.

    t-SNE (t-distributed Stochastic Neighbor Embedding) is a nonlinear
    dimensionality reduction technique that preserves local structure.

    Args:
        embeddings: High-dimensional array of shape (num_samples, embedding_dim)
        perplexity: t-SNE perplexity parameter (typically 5-50).
            scikit-learn requires it to be smaller than the number of
            samples; a larger value is lowered to num_samples - 1 with a
            warning instead of failing.

    Returns:
        2D coordinates of shape (num_samples, 2)
    """
    num_samples = len(embeddings)
    if num_samples > 1 and perplexity >= num_samples:
        # Small filtered selections would otherwise make TSNE raise
        adjusted = float(num_samples - 1)
        print(
            f"[WARN] perplexity={perplexity} is too large for {num_samples} samples; "
            f"using {adjusted} instead"
        )
        perplexity = adjusted

    print(f"[INFO] Running t-SNE with perplexity={perplexity}...")

    tsne = TSNE(
        n_components=2,
        random_state=0,  # Fixed seed for reproducibility
        perplexity=perplexity,
        max_iter=1000,
        init='random',
        learning_rate='auto',  # Let sklearn choose the learning rate
    )

    return tsne.fit_transform(embeddings)
|
||||
|
||||
|
||||
def create_scatter_plot_by_age(
    coordinates: np.ndarray,
    ages: np.ndarray,
    title: str,
    output_path: str
) -> None:
    """
    Create and save a 2D scatter plot colored by age (continuous colormap).

    Ages are mapped to colors of the 'viridis' colormap and a colorbar is
    added. Points whose age is None or NaN are dropped.

    Args:
        coordinates: 2D array of shape (num_points, 2)
        ages: Age values for each point (array, may contain None values)
        title: Plot title
        output_path: File path to save the plot (written in PDF format)

    Returns:
        None. When no point has a valid age a warning is printed and no
        file is written.
    """
    # Convert None values to NaN so that missing entries can be masked out
    ages_float = np.array(
        [np.nan if a is None else float(a) for a in ages], dtype=float
    )
    valid_mask = ~np.isnan(ages_float)
    valid_coords = np.asarray(coordinates)[valid_mask]
    valid_ages = ages_float[valid_mask]

    if valid_ages.size == 0:
        print(f"[WARN] No valid age data found. Skipping plot: {output_path}")
        return

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Create scatter plot with continuous colormap
        scatter = ax.scatter(
            valid_coords[:, 0],
            valid_coords[:, 1],
            c=valid_ages,
            cmap='viridis',
            s=15,
            alpha=0.7,
            edgecolors='none',
        )

        # Add colorbar
        cbar = fig.colorbar(scatter, ax=ax)
        cbar.set_label('Age (years)', rotation=270, labelpad=20)

        # Configure plot appearance
        ax.set_title(title)
        ax.set_xlabel("Dimension 1")
        ax.set_ylabel("Dimension 2")

        # Save this specific figure (not pyplot's "current" one) as vector PDF
        fig.tight_layout()
        fig.savefig(output_path, format='pdf', bbox_inches='tight')
    finally:
        # Always release the figure, even when saving fails
        plt.close(fig)

    print(f"[DONE] Saved plot: {output_path}")
    print(f"[INFO] Age range: {valid_ages.min():.0f} - {valid_ages.max():.0f} years")
    print(f"[INFO] Points with valid age: {len(valid_ages)}/{len(ages)}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Command Line Interface
|
||||
# =============================================================================
|
||||
|
||||
class CLI:
    """
    Command-line interface for SBERT metadata-based embedding extraction and visualization.

    Provides two main commands:
    - extract: Generate embeddings from user metadata (age, sex)
    - plot: Visualize embeddings with t-SNE colored by age
    """

    @staticmethod
    def _split_csv(value) -> list:
        """
        Normalize a comma-separated CLI argument into a list of stripped strings.

        Python Fire evaluates argument text as a Python literal when it can, so
        ``--labels W,N1,N2`` or ``--users 10,11`` may arrive as a tuple and
        ``--users 5`` as an int instead of the raw string. Calling ``.split``
        on those values raises AttributeError, so every accepted shape is
        converted here.

        Args:
            value: None, a comma-separated string, a list/tuple/set of items,
                or a single scalar.

        Returns:
            List of non-empty, whitespace-stripped strings (empty if value is None).
        """
        if value is None:
            return []
        if isinstance(value, str):
            parts = value.split(",")
        elif isinstance(value, (list, tuple, set, frozenset)):
            parts = value
        else:
            # A single scalar (e.g. an int produced by Fire's literal parsing).
            parts = [value]
        cleaned = (str(part).strip() for part in parts)
        return [part for part in cleaned if part]

    def extract(
        self,
        data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_full",
        subject_path: str = SUBJECT_PATH,
        out_dir: str = "./embeddings/all_labels",
        batch_size: int = 32,
        label: str = None
    ) -> None:
        """
        Extract embeddings from user metadata.

        Args:
            data_root: Root directory containing user/session data folders
            subject_path: Path to SC-subjects.xls file with metadata
            out_dir: Output directory for HuggingFace dataset
            batch_size: Batch size for inference (default: 32)
            label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM').
                   Passed unchanged to SBERT_Metadata.extract_embeddings
                   (default: None, intended to mean all labels).
        """
        embedder = SBERT_Metadata()
        dataset = embedder.extract_embeddings(data_root, subject_path, batch_size, label)
        embedder.save_embeddings(dataset, out_dir)

    def plot(
        self,
        emb_dir: str = None,
        embeddings_root: str = "./embeddings",
        out_dir: str = "./plots/all_labels",
        perplexity: float = 10.0,
        users: str = None,
        num_users: int = 0,
        labels: str = None,
    ) -> None:
        """
        Visualize embeddings with t-SNE, colored by age.

        Can load from either a single label directory or all label directories.
        If the user/label filters leave no samples, a warning is printed and
        no plot is written.

        Args:
            emb_dir: Single directory containing the HuggingFace embeddings dataset
                     (e.g., "./embeddings/REM"). If provided, only this directory is used.
            embeddings_root: Root directory containing label subdirectories
                             (e.g., "./embeddings" containing W/, REM/, N1/, etc.).
                             Used only if emb_dir is not provided.
            out_dir: Output directory for visualization plots (PDF)
            perplexity: t-SNE perplexity parameter (default: 10.0)
            users: Comma-separated user IDs to include (e.g., '00,01,02').
                   A tuple/list/int produced by Fire's literal parsing is also
                   accepted; IDs are compared as strings.
                   NOTE(review): a lone zero-padded ID such as 00 is likely
                   parsed by Fire as the int 0 before reaching this method -
                   quote it for Fire if that matters.
            num_users: Include only first N users, 0 = all (default: 0).
                       Ignored when users is given.
            labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
                    This filters the already-loaded data, not which directories to load.
        """
        os.makedirs(out_dir, exist_ok=True)

        # Load embeddings: either from single directory or all label directories
        if emb_dir is not None:
            print(f"[INFO] Loading embeddings from single directory: {emb_dir}")
            dataset = SBERT_Metadata.load_embeddings(emb_dir)
        else:
            print(f"[INFO] Loading embeddings from all label directories in: {embeddings_root}")
            dataset = load_embeddings_from_all_labels(embeddings_root)

        # Apply user filtering (an explicit user list takes precedence over num_users)
        user_list = self._split_csv(users)
        if user_list:
            wanted_users = set(user_list)
            dataset = dataset.filter(lambda x: x["user_id"] in wanted_users)
            print(f"[INFO] Filtered to users: {user_list}")
        elif num_users > 0:
            all_users = sorted(set(dataset["user_id"]))
            selected_users = all_users[:num_users]
            wanted_users = set(selected_users)
            dataset = dataset.filter(lambda x: x["user_id"] in wanted_users)
            print(f"[INFO] Selected first {num_users} users: {selected_users}")

        # Filter by sleep stage labels
        label_list = self._split_csv(labels)
        if label_list:
            wanted_labels = set(label_list)
            dataset = dataset.filter(lambda x: x["label"] in wanted_labels)
            print(f"[INFO] Filtered to labels: {label_list}")

        print(f"[INFO] Total samples: {len(dataset)}")

        # t-SNE cannot run on an empty selection; stop with a clear message
        # instead of failing inside the dimensionality reduction.
        if len(dataset) == 0:
            print(f"[WARN] No samples left after filtering. Skipping plot in: {out_dir}")
            return

        # Extract embeddings as numpy array for t-SNE
        embeddings = np.array(dataset["embedding"])

        # Extract ages for coloring
        ages = np.array(dataset["age"])

        # Reduce to 2D with t-SNE
        coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)

        # Generate visualization colored by age
        create_scatter_plot_by_age(
            coordinates_2d,
            ages,
            "t-SNE Visualization (Colored by Age)",
            os.path.join(out_dir, "tsne_by_age.pdf")
        )
|
||||
|
||||
|
||||
# Script entry point: hand the CLI class to Fire so it can be driven from the
# command line (the class docstring lists the `extract` and `plot` commands).
if __name__ == "__main__":
    Fire(CLI)
|
||||
58
baselines/common.py
Normal file
58
baselines/common.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Shared setup for baseline experiments."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import yaml
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
from core.data_loader import DataLoader # noqa: E402
|
||||
from core.model import load_models # noqa: E402
|
||||
from core.recruiter import Recruiter # noqa: E402
|
||||
from core.agent import Agent # noqa: E402
|
||||
from core.logger import Logger # noqa: E402
|
||||
from core.prompt import gen_system_message, gen_task_message # noqa: E402
|
||||
from core.json_utils import safe_parse_json # noqa: E402
|
||||
from core.scores import self_certainty # noqa: E402
|
||||
from core.vote import borda_vote # noqa: E402
|
||||
|
||||
|
||||
def load_config(config_path: str) -> dict:
    """
    Load an experiment configuration from a YAML file.

    Args:
        config_path: Path to the YAML config file (read as UTF-8).

    Returns:
        The parsed top-level mapping.

    Raises:
        ValueError: If the file is empty or its top-level value is not a
            mapping. Callers use ``config.get(...)``, so anything else would
            only fail later with a less helpful AttributeError.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load is the shorthand for yaml.load(f, Loader=yaml.SafeLoader).
        config = yaml.safe_load(f)

    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )
    return config
|
||||
|
||||
|
||||
def setup(config: dict, temperature: float = None):
    """
    Build the logger, data loader, recruiter and agent described by a config.

    Args:
        config: Parsed YAML config dict. Keys read here: log_path, data_path,
            target_user, num_shot, temperature (optional), model_paths.
        temperature: Override for the model temperature (e.g. 0.7 for the
            self-consistency baselines). When None, config["temperature"] is
            used, falling back to 0.0.

    Returns:
        (logger, dataloader, recruiter, agent)
    """
    run_logger = Logger(config.get("log_path"))
    run_logger.log_config(config)

    loader = DataLoader(config.get("data_path"), config.get("target_user"))

    example_recruiter = Recruiter(
        source_dataset=loader.get_source_dataset(),
        source_users=loader.get_source_users(),
        num_shot=config.get("num_shot"),
        classes=loader.get_classes(),
        logger=run_logger,
    )

    # An explicit argument wins over the value stored in the config.
    if temperature is None:
        effective_temperature = config.get("temperature", 0.0)
    else:
        effective_temperature = temperature
    models = load_models(config.get("model_paths"), temperature=effective_temperature)

    solver = Agent(model_pool=models, logger=run_logger)
    # The system message is generated from the task metadata and set once here.
    task_system_message = gen_system_message(metadata=loader.get_task_metadata())
    solver.set_system_message(task_system_message)

    return run_logger, loader, example_recruiter, solver
|
||||
72
baselines/random_dynamic_borda.py
Normal file
72
baselines/random_dynamic_borda.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Baseline 4: Random Examples + Borda Voting
|
||||
|
||||
For each sample, randomly recruit queue_size fresh example sets.
|
||||
Each example set produces one inference. Answers aggregated via borda voting.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from fire import Fire
|
||||
from common import (
|
||||
load_config, setup,
|
||||
gen_task_message, safe_parse_json, self_certainty, borda_vote,
|
||||
)
|
||||
|
||||
|
||||
async def run(config_path: str):
    """
    Run the "random examples + Borda voting" baseline.

    For every sample, ``queue_size`` fresh example sets are recruited, one
    inference is issued per example set (concurrently), and the answers are
    aggregated with ``borda_vote``. Samples for which the vote yields no
    winner are logged and not passed to ``logger.log_result``.

    Args:
        config_path: Path to the YAML config. Keys read here: queue_size,
            vocab_size, borda_p (default 1.0); the rest is consumed by setup().

    Raises:
        ValueError: If ``queue_size`` is missing or not a positive integer.
    """
    config = load_config(config_path)
    queue_size = config.get("queue_size")
    # Fail fast: without a positive queue size no inference is issued and
    # every sample would end up without votes (or fail later, obscurely).
    if isinstance(queue_size, bool) or not isinstance(queue_size, int) or queue_size < 1:
        raise ValueError(f"'queue_size' must be a positive integer, got {queue_size!r}")
    vocab_size = config.get("vocab_size")
    borda_p = config.get("borda_p", 1.0)

    logger, dataloader, recruiter, agent = setup(config)

    logger.log("[Baseline] Random examples + borda voting")

    async def process(sample_idx, example_idx, example_set, sample):
        """Run one inference; on any error return a None answer with score -inf."""
        try:
            task_message = gen_task_message(sample, example_set)
            response, logprobs = await agent.solve(
                task_message, sample_idx, example_idx
            )
            score = self_certainty(logprobs, vocab_size=vocab_size)
            parsed = safe_parse_json(response)
            answer = parsed.get("ANSWER") if parsed else None
            logger.log(
                f"[Main] Done {sample_idx} - {example_idx}: "
                f"answer={answer}, score={score:.4f}"
            )
            return {"answer": answer, "score": score}
        except Exception as e:
            # Best effort: one failed inference must not abort the whole run.
            logger.log(
                f"[Main] Error {sample_idx} - {example_idx}: {e}",
                filename="errors.txt",
            )
            return {"answer": None, "score": float("-inf")}

    start_time = time.time()
    for idx, sample in enumerate(dataloader):
        logger.log(f"[Main] Sample {idx} / {len(dataloader)}")
        # Fresh example sets are drawn for every sample.
        example_sets = recruiter.recruit(queue_size)
        tasks = [
            process(idx, i, es, sample) for i, es in enumerate(example_sets)
        ]
        results = await asyncio.gather(*tasks)

        winner, tally = borda_vote(results, borda_p)
        ground_truth = sample["label"]
        if winner is not None:
            tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
            logger.log(f"[Vote] votes={{ {tally_str} }}")
            logger.log_result(idx, winner, ground_truth)
        else:
            logger.log(f"[Vote] Sample {idx}: no valid answer, skipping")

    logger.report(elapsed_seconds=time.time() - start_time)
|
||||
|
||||
|
||||
def main(config_path: str):
    """
    Command-line entry point for the random-dynamic Borda baseline.

    Args:
        config_path: Path to the YAML config file, forwarded to ``run``.
    """
    baseline_coroutine = run(config_path)
    asyncio.run(baseline_coroutine)


if __name__ == "__main__":
    Fire(main)
|
||||
69
baselines/random_dynamic_sc.py
Normal file
69
baselines/random_dynamic_sc.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Baseline 6: Random Dynamic Self-Consistency
|
||||
|
||||
For each sample, randomly recruit ONE fresh example set,
|
||||
then run that same prompt queue_size times with temperature=0.7.
|
||||
Answers are aggregated via borda voting.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from fire import Fire
|
||||
from common import (
|
||||
load_config, setup,
|
||||
gen_task_message, safe_parse_json, self_certainty, borda_vote,
|
||||
)
|
||||
|
||||
|
||||
async def run(config_path: str):
    """
    Run the random-dynamic self-consistency baseline.

    For every sample ONE fresh example set is recruited and the resulting
    prompt is sent ``queue_size`` times (concurrently) with the model
    temperature fixed to 0.7. The answers are aggregated with ``borda_vote``.
    Samples for which the vote yields no winner are logged and not passed to
    ``logger.log_result``.

    Args:
        config_path: Path to the YAML config. Keys read here: queue_size,
            vocab_size, borda_p (default 1.0); the rest is consumed by setup().

    Raises:
        ValueError: If ``queue_size`` is missing or not a positive integer.
    """
    config = load_config(config_path)
    queue_size = config.get("queue_size")
    # Fail fast: a missing value would otherwise only surface as
    # ``range(None)`` raising TypeError after setup and recruiting.
    if isinstance(queue_size, bool) or not isinstance(queue_size, int) or queue_size < 1:
        raise ValueError(f"'queue_size' must be a positive integer, got {queue_size!r}")
    vocab_size = config.get("vocab_size")
    borda_p = config.get("borda_p", 1.0)

    logger, dataloader, recruiter, agent = setup(config, temperature=0.7)

    logger.log("[Baseline] Random dynamic example + self-consistency (temp=0.7)")

    async def run_once(sample_idx, run_idx, task_message):
        """Run one sampled inference; on any error return a None answer with score -inf."""
        try:
            response, logprobs = await agent.solve(task_message, sample_idx, run_idx)
            score = self_certainty(logprobs, vocab_size=vocab_size)
            parsed = safe_parse_json(response)
            answer = parsed.get("ANSWER") if parsed else None
            logger.log(
                f"[SC] {sample_idx} - run {run_idx}: "
                f"answer={answer}, score={score:.4f}"
            )
            return {"answer": answer, "score": score}
        except Exception as e:
            # Best effort: one failed run must not abort the whole evaluation.
            logger.log(
                f"[SC] Error {sample_idx} - run {run_idx}: {e}",
                filename="errors.txt",
            )
            return {"answer": None, "score": float("-inf")}

    start_time = time.time()
    for idx, sample in enumerate(dataloader):
        logger.log(f"[Main] Sample {idx} / {len(dataloader)}")
        # One fresh example set per sample; the same prompt is reused for all runs.
        example_set = recruiter.recruit(1)[0]
        task_message = gen_task_message(sample, example_set)
        tasks = [run_once(idx, i, task_message) for i in range(queue_size)]
        results = await asyncio.gather(*tasks)

        winner, tally = borda_vote(results, borda_p)
        ground_truth = sample["label"]
        if winner is not None:
            tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
            logger.log(f"[Vote] votes={{ {tally_str} }}")
            logger.log_result(idx, winner, ground_truth)
        else:
            logger.log(f"[Vote] Sample {idx}: no valid answer, skipping")

    logger.report(elapsed_seconds=time.time() - start_time)
|
||||
|
||||
|
||||
def main(config_path: str):
    """
    Command-line entry point for the random-dynamic self-consistency baseline.

    Args:
        config_path: Path to the YAML config file, forwarded to ``run``.
    """
    baseline_coroutine = run(config_path)
    asyncio.run(baseline_coroutine)


if __name__ == "__main__":
    Fire(main)
|
||||
51
baselines/random_dynamic_single.py
Normal file
51
baselines/random_dynamic_single.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Baseline 2: Random Dynamic Single Example
|
||||
|
||||
For each sample, randomly select a fresh example set.
|
||||
Single inference per sample, no voting.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from fire import Fire
|
||||
from common import (
|
||||
load_config, setup,
|
||||
gen_task_message, safe_parse_json, self_certainty,
|
||||
)
|
||||
|
||||
|
||||
async def run(config_path: str):
    """
    Run the random-dynamic single-example baseline.

    For every sample one fresh example set is recruited and a single
    inference is made; there is no voting. A sample whose inference fails or
    yields no answer is logged and not passed to ``logger.log_result``.

    Args:
        config_path: Path to the YAML config file.
    """
    config = load_config(config_path)
    logger, dataloader, recruiter, agent = setup(config)

    started_at = time.time()
    for sample_no, sample in enumerate(dataloader):
        logger.log(f"[Main] Sample {sample_no} / {len(dataloader)}")
        expected = sample["label"]

        try:
            recruited = recruiter.recruit(1)
            prompt = gen_task_message(sample, recruited[0])
            response, logprobs = await agent.solve(prompt, sample_no, 0)
            score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
            parsed = safe_parse_json(response)
            if parsed:
                answer = parsed.get("ANSWER")
            else:
                answer = None
            logger.log(f"[Main] answer={answer}, score={score:.4f}")
        except Exception as e:
            # Any failure for this sample is recorded and treated as "no answer".
            logger.log(f"[Main] Error sample {sample_no}: {e}", filename="errors.txt")
            answer = None

        if answer is None:
            logger.log(f"[Main] Sample {sample_no}: no valid answer, skipping")
            continue
        logger.log_result(sample_no, answer, expected)

    logger.report(elapsed_seconds=time.time() - started_at)
|
||||
|
||||
|
||||
def main(config_path: str):
    """
    Command-line entry point for the random-dynamic single-example baseline.

    Args:
        config_path: Path to the YAML config file, forwarded to ``run``.
    """
    baseline_coroutine = run(config_path)
    asyncio.run(baseline_coroutine)


if __name__ == "__main__":
    Fire(main)
|
||||
75
baselines/random_fixed_borda.py
Normal file
75
baselines/random_fixed_borda.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Baseline 5: Fixed Examples + Borda Voting
|
||||
|
||||
Recruit queue_size example sets once at the start.
|
||||
For each sample, run all example sets through the LLM.
|
||||
Answers aggregated via borda voting. No queue updates.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from fire import Fire
|
||||
from common import (
|
||||
load_config,
|
||||
setup,
|
||||
gen_task_message,
|
||||
safe_parse_json,
|
||||
self_certainty,
|
||||
borda_vote,
|
||||
)
|
||||
|
||||
|
||||
async def run(config_path: str):
    """
    Run the "fixed examples + Borda voting" baseline.

    ``queue_size`` example sets are recruited once, before the loop. For every
    sample each of those example sets produces one inference (concurrently)
    and the answers are aggregated with ``borda_vote``. Samples for which the
    vote yields no winner are logged and not passed to ``logger.log_result``.

    Args:
        config_path: Path to the YAML config. Keys read here: queue_size,
            vocab_size, borda_p (default 1.0); the rest is consumed by setup().

    Raises:
        ValueError: If ``queue_size`` is missing or not a positive integer.
    """
    config = load_config(config_path)
    queue_size = config.get("queue_size")
    # Fail fast: without a positive queue size no inference is issued and
    # every sample would end up without votes (or fail later, obscurely).
    if isinstance(queue_size, bool) or not isinstance(queue_size, int) or queue_size < 1:
        raise ValueError(f"'queue_size' must be a positive integer, got {queue_size!r}")
    vocab_size = config.get("vocab_size")
    borda_p = config.get("borda_p", 1.0)

    logger, dataloader, recruiter, agent = setup(config)

    # Recruited once and reused unchanged for every sample.
    example_sets = recruiter.recruit(queue_size)
    logger.log(f"[Baseline] Fixed {queue_size} example set(s) + borda voting")

    async def process(sample_idx, example_idx, example_set, sample):
        """Run one inference; on any error return a None answer with score -inf."""
        try:
            task_message = gen_task_message(sample, example_set)
            response, logprobs = await agent.solve(
                task_message, sample_idx, example_idx
            )
            score = self_certainty(logprobs, vocab_size=vocab_size)
            parsed = safe_parse_json(response)
            answer = parsed.get("ANSWER") if parsed else None
            logger.log(
                f"[Main] Done {sample_idx} - {example_idx}: "
                f"answer={answer}, score={score:.4f}"
            )
            return {"answer": answer, "score": score}
        except Exception as e:
            # Best effort: one failed inference must not abort the whole run.
            logger.log(
                f"[Main] Error {sample_idx} - {example_idx}: {e}",
                filename="errors.txt",
            )
            return {"answer": None, "score": float("-inf")}

    start_time = time.time()
    for idx, sample in enumerate(dataloader):
        logger.log(f"[Main] Sample {idx} / {len(dataloader)}")
        tasks = [process(idx, i, es, sample) for i, es in enumerate(example_sets)]
        results = await asyncio.gather(*tasks)

        winner, tally = borda_vote(results, borda_p)
        ground_truth = sample["label"]
        if winner is not None:
            tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
            logger.log(f"[Vote] votes={{ {tally_str} }}")
            logger.log_result(idx, winner, ground_truth)
        else:
            logger.log(f"[Vote] Sample {idx}: no valid answer, skipping")

    logger.report(elapsed_seconds=time.time() - start_time)
|
||||
|
||||
|
||||
def main(config_path: str):
    """
    Command-line entry point for the fixed-examples Borda baseline.

    Args:
        config_path: Path to the YAML config file, forwarded to ``run``.
    """
    baseline_coroutine = run(config_path)
    asyncio.run(baseline_coroutine)


if __name__ == "__main__":
    Fire(main)
|
||||
69
baselines/random_fixed_sc.py
Normal file
69
baselines/random_fixed_sc.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Baseline 3: Self-Consistency with Fixed Example
|
||||
|
||||
One randomly selected example set, used for ALL samples.
|
||||
For each sample, the SAME prompt is run queue_size times with temperature=0.7
|
||||
across different LLM instances. Answers are aggregated via borda voting.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from fire import Fire
|
||||
from common import (
|
||||
load_config, setup,
|
||||
gen_task_message, safe_parse_json, self_certainty, borda_vote,
|
||||
)
|
||||
|
||||
|
||||
async def run(config_path: str):
    """
    Run the fixed-example self-consistency baseline.

    One example set is recruited once and used for ALL samples. For every
    sample the same prompt is sent ``queue_size`` times (concurrently) with
    the model temperature fixed to 0.7, and the answers are aggregated with
    ``borda_vote``. Samples for which the vote yields no winner are logged and
    not passed to ``logger.log_result``.

    Args:
        config_path: Path to the YAML config. Keys read here: queue_size,
            vocab_size, borda_p (default 1.0); the rest is consumed by setup().

    Raises:
        ValueError: If ``queue_size`` is missing or not a positive integer.
    """
    config = load_config(config_path)
    queue_size = config.get("queue_size")
    # Fail fast: a missing value would otherwise only surface as
    # ``range(None)`` raising TypeError after setup and recruiting.
    if isinstance(queue_size, bool) or not isinstance(queue_size, int) or queue_size < 1:
        raise ValueError(f"'queue_size' must be a positive integer, got {queue_size!r}")
    vocab_size = config.get("vocab_size")
    borda_p = config.get("borda_p", 1.0)

    logger, dataloader, recruiter, agent = setup(config, temperature=0.7)

    # Recruited once and reused unchanged for every sample.
    example_set = recruiter.recruit(1)[0]
    logger.log("[Baseline] Fixed example set + self-consistency (temp=0.7)")

    async def run_once(sample_idx, run_idx, task_message):
        """Run one sampled inference; on any error return a None answer with score -inf."""
        try:
            response, logprobs = await agent.solve(task_message, sample_idx, run_idx)
            score = self_certainty(logprobs, vocab_size=vocab_size)
            parsed = safe_parse_json(response)
            answer = parsed.get("ANSWER") if parsed else None
            logger.log(
                f"[SC] {sample_idx} - run {run_idx}: "
                f"answer={answer}, score={score:.4f}"
            )
            return {"answer": answer, "score": score}
        except Exception as e:
            # Best effort: one failed run must not abort the whole evaluation.
            logger.log(
                f"[SC] Error {sample_idx} - run {run_idx}: {e}",
                filename="errors.txt",
            )
            return {"answer": None, "score": float("-inf")}

    start_time = time.time()
    for idx, sample in enumerate(dataloader):
        logger.log(f"[Main] Sample {idx} / {len(dataloader)}")
        task_message = gen_task_message(sample, example_set)
        tasks = [run_once(idx, i, task_message) for i in range(queue_size)]
        results = await asyncio.gather(*tasks)

        winner, tally = borda_vote(results, borda_p)
        ground_truth = sample["label"]
        if winner is not None:
            tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
            logger.log(f"[Vote] votes={{ {tally_str} }}")
            logger.log_result(idx, winner, ground_truth)
        else:
            logger.log(f"[Vote] Sample {idx}: no valid answer, skipping")

    logger.report(elapsed_seconds=time.time() - start_time)
|
||||
|
||||
|
||||
def main(config_path: str):
    """
    Command-line entry point for the fixed-example self-consistency baseline.

    Args:
        config_path: Path to the YAML config file, forwarded to ``run``.
    """
    baseline_coroutine = run(config_path)
    asyncio.run(baseline_coroutine)


if __name__ == "__main__":
    Fire(main)
|
||||
53
baselines/random_fixed_single.py
Normal file
53
baselines/random_fixed_single.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
Baseline 1: Fixed Single Example
|
||||
|
||||
One randomly selected example set, used for ALL samples.
|
||||
Single inference per sample, no voting.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from fire import Fire
|
||||
from common import (
|
||||
load_config, setup,
|
||||
gen_task_message, safe_parse_json, self_certainty,
|
||||
)
|
||||
|
||||
|
||||
async def run(config_path: str):
    """
    Run the fixed single-example baseline.

    One example set is recruited once and reused for ALL samples. Each sample
    gets a single inference; there is no voting. A sample whose inference
    fails or yields no answer is logged and not passed to
    ``logger.log_result``.

    Args:
        config_path: Path to the YAML config file.
    """
    config = load_config(config_path)
    logger, dataloader, recruiter, agent = setup(config)

    # Drawn once up front and shared by every sample below.
    recruited = recruiter.recruit(1)
    example_set = recruited[0]
    logger.log("[Baseline] Fixed single example set selected")

    started_at = time.time()
    for sample_no, sample in enumerate(dataloader):
        logger.log(f"[Main] Sample {sample_no} / {len(dataloader)}")
        expected = sample["label"]

        try:
            prompt = gen_task_message(sample, example_set)
            response, logprobs = await agent.solve(prompt, sample_no, 0)
            score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))
            parsed = safe_parse_json(response)
            if parsed:
                answer = parsed.get("ANSWER")
            else:
                answer = None
            logger.log(f"[Main] answer={answer}, score={score:.4f}")
        except Exception as e:
            # Any failure for this sample is recorded and treated as "no answer".
            logger.log(f"[Main] Error sample {sample_no}: {e}", filename="errors.txt")
            answer = None

        if answer is None:
            logger.log(f"[Main] Sample {sample_no}: no valid answer, skipping")
            continue
        logger.log_result(sample_no, answer, expected)

    logger.report(elapsed_seconds=time.time() - started_at)
|
||||
|
||||
|
||||
def main(config_path: str):
    """
    Command-line entry point for the fixed single-example baseline.

    Args:
        config_path: Path to the YAML config file, forwarded to ``run``.
    """
    baseline_coroutine = run(config_path)
    asyncio.run(baseline_coroutine)


if __name__ == "__main__":
    Fire(main)
|
||||
17
config/ours/sleepedf/user00.yaml
Normal file
17
config/ours/sleepedf/user00.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '00'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/ours/sleepedf/00
|
||||
17
config/ours/sleepedf/user01.yaml
Normal file
17
config/ours/sleepedf/user01.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '01'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/ours/sleepedf/01
|
||||
17
config/ours/sleepedf/user02.yaml
Normal file
17
config/ours/sleepedf/user02.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '02'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/ours/sleepedf/02
|
||||
17
config/random_dynamic_borda/sleepedf/user00.yaml
Normal file
17
config/random_dynamic_borda/sleepedf/user00.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '00'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_borda/sleepedf/00
|
||||
17
config/random_dynamic_borda/sleepedf/user01.yaml
Normal file
17
config/random_dynamic_borda/sleepedf/user01.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '01'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_borda/sleepedf/01
|
||||
17
config/random_dynamic_borda/sleepedf/user02.yaml
Normal file
17
config/random_dynamic_borda/sleepedf/user02.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '02'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_borda/sleepedf/02
|
||||
17
config/random_dynamic_sc/sleepedf/user00.yaml
Normal file
17
config/random_dynamic_sc/sleepedf/user00.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '00'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_sc/sleepedf/00
|
||||
17
config/random_dynamic_sc/sleepedf/user01.yaml
Normal file
17
config/random_dynamic_sc/sleepedf/user01.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '01'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_sc/sleepedf/01
|
||||
17
config/random_dynamic_sc/sleepedf/user02.yaml
Normal file
17
config/random_dynamic_sc/sleepedf/user02.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '02'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_sc/sleepedf/02
|
||||
17
config/random_dynamic_single/sleepedf/user00.yaml
Normal file
17
config/random_dynamic_single/sleepedf/user00.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '00'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_single/sleepedf/00
|
||||
17
config/random_dynamic_single/sleepedf/user01.yaml
Normal file
17
config/random_dynamic_single/sleepedf/user01.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '01'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_single/sleepedf/01
|
||||
17
config/random_dynamic_single/sleepedf/user02.yaml
Normal file
17
config/random_dynamic_single/sleepedf/user02.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '02'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_dynamic_single/sleepedf/02
|
||||
17
config/random_fixed_borda/sleepedf/user00.yaml
Normal file
17
config/random_fixed_borda/sleepedf/user00.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '00'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_borda/sleepedf/00
|
||||
17
config/random_fixed_borda/sleepedf/user01.yaml
Normal file
17
config/random_fixed_borda/sleepedf/user01.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '01'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_borda/sleepedf/01
|
||||
17
config/random_fixed_borda/sleepedf/user02.yaml
Normal file
17
config/random_fixed_borda/sleepedf/user02.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '02'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_borda/sleepedf/02
|
||||
17
config/random_fixed_sc/sleepedf/user00.yaml
Normal file
17
config/random_fixed_sc/sleepedf/user00.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '00'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_sc/sleepedf/00
|
||||
17
config/random_fixed_sc/sleepedf/user01.yaml
Normal file
17
config/random_fixed_sc/sleepedf/user01.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '01'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_sc/sleepedf/01
|
||||
17
config/random_fixed_sc/sleepedf/user02.yaml
Normal file
17
config/random_fixed_sc/sleepedf/user02.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '02'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_sc/sleepedf/02
|
||||
17
config/random_fixed_single/sleepedf/user00.yaml
Normal file
17
config/random_fixed_single/sleepedf/user00.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '00'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_single/sleepedf/00
|
||||
17
config/random_fixed_single/sleepedf/user01.yaml
Normal file
17
config/random_fixed_single/sleepedf/user01.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '01'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_single/sleepedf/01
|
||||
17
config/random_fixed_single/sleepedf/user02.yaml
Normal file
17
config/random_fixed_single/sleepedf/user02.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
update_size: 3
|
||||
borda_p: 1.0
|
||||
vocab_size: 200064
|
||||
model_paths:
|
||||
- ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: '02'
|
||||
log_path: /mnt/sting/hjyoon/projects/llm_personalization/logs/random_fixed_single/sleepedf/02
|
||||
@@ -1,32 +0,0 @@
|
||||
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
||||
num_seeds: 10
|
||||
num_examples: 1
|
||||
|
||||
# Selection criteria: out_random | in_random | out_similar
|
||||
selection_criteria: "out_random"
|
||||
|
||||
# Embedding path (not required for random criteria)
|
||||
# embedding_path: "/home/ssum/tsllm_personalization_icl/embeddings_20users"
|
||||
|
||||
models:
|
||||
- ollama:url:joy.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:joy.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:joy.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:joy.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:joy.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:joy.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:joy.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:joy.kaist.ac.kr:11444/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
|
||||
|
||||
# chronos_base : "/mnt/sting/ssum/sleepedf_chronos_base_result"
|
||||
# out_random : "/mnt/sting/ssum/sleepedf_chronos_result_outrandom"
|
||||
# in_random : "/mnt/sting/ssum/sleepedf_chronos_result"
|
||||
log_path: "/mnt/sting/ssum/sleepedf_chronos_result"
|
||||
16
config/test.yaml
Normal file
16
config/test.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
log_path: ./temp/log
|
||||
data_path: /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new
|
||||
target_user: "00"
|
||||
queue_size: 5
|
||||
num_shot: 1
|
||||
model_paths:
|
||||
- ollama:url:rose.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:rose.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:rose.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:rose.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:rose.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:rose.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:rose.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:rose.kaist.ac.kr:11444/gpt-oss:20b
|
||||
update_size: 1
|
||||
vocab_size: 200064
|
||||
225
core/agent.py
225
core/agent.py
@@ -1,208 +1,39 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import tiktoken
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
|
||||
from core.logger import Logger
|
||||
from core.model import AsyncModelPool
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
model_pool,
|
||||
log_path,
|
||||
model_pool: AsyncModelPool,
|
||||
logger: Logger,
|
||||
):
|
||||
self.name = name
|
||||
self.model_pool = model_pool
|
||||
self.log_path = log_path
|
||||
self.root_log_path = log_path
|
||||
self.agent_log_path = os.path.join(log_path, name)
|
||||
os.makedirs(self.agent_log_path, exist_ok=True)
|
||||
self.logger = logger
|
||||
self.system_message = None
|
||||
|
||||
self.long_term_memory = []
|
||||
self.short_term_memory = []
|
||||
self.volatile_memory = []
|
||||
def set_system_message(self, system_message: str):
|
||||
self.system_message = system_message
|
||||
log = f"[SYSTEM]\n{system_message}\n\n"
|
||||
log_filename = "llm_log/system_prompt.txt"
|
||||
self.logger.log(log, log_filename, print_log=False)
|
||||
|
||||
self.total_input_tokens = 0
|
||||
self.total_output_tokens = 0
|
||||
self.total_tokens = 0
|
||||
self.total_calls = 0
|
||||
|
||||
def log(self, message, local=True):
|
||||
path = os.path.join(self.root_log_path, "log.txt")
|
||||
with open(path, "a", encoding="utf-8") as f:
|
||||
message_type = "UNKNOWN"
|
||||
if isinstance(message, SystemMessage):
|
||||
message_type = "SYSTEM"
|
||||
if isinstance(message, HumanMessage):
|
||||
message_type = "HUMAN"
|
||||
if isinstance(message, AIMessage):
|
||||
message_type = "AI"
|
||||
content = message.content.strip()
|
||||
name = self.name
|
||||
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
|
||||
if local:
|
||||
local_path = os.path.join(self.agent_log_path, "log.txt")
|
||||
with open(local_path, "a", encoding="utf-8") as f:
|
||||
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
|
||||
|
||||
def log_tokens(self, messages, response):
|
||||
input_tokens = 0
|
||||
for msg in messages:
|
||||
msg_tokens = self.count_tokens(msg.content)
|
||||
input_tokens += msg_tokens
|
||||
self.total_tokens += msg_tokens
|
||||
self.total_input_tokens += msg_tokens
|
||||
output_tokens = self.count_tokens(response.content)
|
||||
self.total_tokens += output_tokens
|
||||
self.total_output_tokens += output_tokens
|
||||
self.total_calls += 1
|
||||
|
||||
path = os.path.join(self.agent_log_path, "tokens.txt")
|
||||
with open(path, "a", encoding="utf-8") as f:
|
||||
f.write(f"Input tokens: {input_tokens}\n")
|
||||
f.write(f"Output tokens: {output_tokens}\n")
|
||||
f.write(f"Total input tokens: {self.total_input_tokens}\n")
|
||||
f.write(f"Total output tokens: {self.total_output_tokens}\n")
|
||||
f.write(f"Total tokens: {self.total_tokens}\n")
|
||||
f.write(f"Total calls: {self.total_calls}\n")
|
||||
f.write("\n")
|
||||
|
||||
def count_tokens(self, text, model="gpt-3.5-turbo"):
|
||||
enc = tiktoken.encoding_for_model(model)
|
||||
return len(enc.encode(text))
|
||||
|
||||
def update_memory(self):
|
||||
self.long_term_memory.extend(self.short_term_memory)
|
||||
self.clean_short_term_memory()
|
||||
self.clean_volatile_memory()
|
||||
|
||||
def clean_short_term_memory(self):
|
||||
self.short_term_memory = []
|
||||
|
||||
def clean_volatile_memory(self):
|
||||
self.volatile_memory = []
|
||||
|
||||
def clean_long_term_memory(self):
|
||||
self.long_term_memory = []
|
||||
|
||||
def clean_json_text(self, text):
|
||||
text = text.strip()
|
||||
text = text.replace("‘", "'").replace("’", "'")
|
||||
text = text.replace("“", "'").replace("”", "'")
|
||||
text = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", text)
|
||||
text = re.sub(r",\s*}", "}", text)
|
||||
text = re.sub(r",\s*]", "]", text)
|
||||
text = "".join(ch for ch in text if ch.isprintable())
|
||||
text = text.replace("][", ",")
|
||||
return text
|
||||
|
||||
def safe_parse_json(self, text):
|
||||
text = text.strip()
|
||||
match = re.search(r"\{.*\}", text, re.DOTALL)
|
||||
if match:
|
||||
text = match.group(0)
|
||||
text = self.clean_json_text(text)
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"[!] JSON parse failed: {e}")
|
||||
return None
|
||||
elif not text.endswith("}"):
|
||||
text += "}"
|
||||
match = re.search(r"\{.*\}", text, re.DOTALL)
|
||||
if match:
|
||||
text = match.group(0)
|
||||
text = self.clean_json_text(text)
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"[!] JSON parse failed: {e}")
|
||||
return None
|
||||
print("[!] JSON parse failed")
|
||||
return None
|
||||
|
||||
def safe_parse_json_list(self, text):
|
||||
text = text.strip()
|
||||
match = re.search(r"\[.*\]", text, re.DOTALL)
|
||||
if match:
|
||||
text = match.group(0)
|
||||
text = self.clean_json_text(text)
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"[!] JSON parse failed: {e}")
|
||||
return None
|
||||
elif not text.endswith("]"):
|
||||
text += "]"
|
||||
match = re.search(r"\[.*\]", text, re.DOTALL)
|
||||
if match:
|
||||
text = match.group(0)
|
||||
text = self.clean_json_text(text)
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"[!] JSON parse failed: {e}")
|
||||
return None
|
||||
print("[!] JSON parse failed")
|
||||
return None
|
||||
|
||||
async def validate_response(self, response, fields, volatile=False):
|
||||
if (
|
||||
not response
|
||||
or not isinstance(response, dict)
|
||||
or not all(field in response for field in fields)
|
||||
):
|
||||
print("[!] The JSON failed to be parsed. Trying again.")
|
||||
content = (
|
||||
"Failed to parse the JSON from the previous response. Please try again."
|
||||
)
|
||||
response = await self.invoke(content, volatile=volatile)
|
||||
response = self.safe_parse_json(response)
|
||||
if (
|
||||
not response
|
||||
or not isinstance(response, dict)
|
||||
or not all(field in response for field in fields)
|
||||
):
|
||||
print("[!] Retry failed.")
|
||||
return None
|
||||
return response
|
||||
|
||||
def get_last_response(self):
|
||||
if len(self.long_term_memory) >= 2:
|
||||
last_msg = self.long_term_memory[-1]
|
||||
if isinstance(last_msg, AIMessage):
|
||||
return self.safe_parse_json(last_msg.content)
|
||||
return None
|
||||
|
||||
def set_system_message(self, content, local=True):
|
||||
system_message = SystemMessage(content=content)
|
||||
self.log(system_message, local)
|
||||
self.long_term_memory.append(system_message)
|
||||
|
||||
async def invoke(self, content, volatile=False, local=True):
|
||||
messages = self.long_term_memory.copy()
|
||||
if volatile:
|
||||
messages.extend(self.volatile_memory)
|
||||
else:
|
||||
messages.extend(self.short_term_memory)
|
||||
messages.append(HumanMessage(content=content))
|
||||
async def solve(self, content: str, sample_idx: int, example_idx: int):
|
||||
messages = []
|
||||
if self.system_message:
|
||||
messages.append({"role": "system", "content": self.system_message})
|
||||
user_message = {"role": "user", "content": content}
|
||||
messages.append(user_message)
|
||||
|
||||
try:
|
||||
response = await self.model_pool.invoke(messages)
|
||||
self.log_tokens(messages, response)
|
||||
|
||||
if volatile:
|
||||
self.volatile_memory.extend([HumanMessage(content=content), response])
|
||||
else:
|
||||
self.short_term_memory.extend([HumanMessage(content=content), response])
|
||||
local_ = not volatile and local
|
||||
self.log(HumanMessage(content=content), local=local_)
|
||||
self.log(response, local=local_)
|
||||
|
||||
return response.content.strip()
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
print(f"[Error] Error occurred while invoking LLM: {e}")
|
||||
response, logprobs = await self.model_pool.invoke(messages)
|
||||
log_filename = f"llm_log/sample_{sample_idx}/example_{example_idx}.txt"
|
||||
self.logger.log(
|
||||
f"[USER]\n{content}\n\n[RESPONSE]\n{response}\n",
|
||||
log_filename,
|
||||
print_log=False,
|
||||
)
|
||||
return response, logprobs
|
||||
except Exception as e:
|
||||
self.logger.log(f"[Agent] invoke failed: {e}")
|
||||
return None, None
|
||||
|
||||
@@ -1,200 +1,81 @@
|
||||
import os
|
||||
import json
|
||||
import datasets
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .embedding_index import EmbeddingIndex
|
||||
from glob import glob
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
class DataLoader:
|
||||
def __init__(
|
||||
self,
|
||||
data_path,
|
||||
user_id,
|
||||
selection_criteria="out_random",
|
||||
num_examples=1,
|
||||
embedding_index: Optional["EmbeddingIndex"] = None,
|
||||
self,
|
||||
data_path: str,
|
||||
target_user: str,
|
||||
shuffle: bool = False,
|
||||
seed: int = 0,
|
||||
):
|
||||
self.is_valid = False
|
||||
self.embedding_index = embedding_index
|
||||
self.data_path = data_path
|
||||
|
||||
if not os.path.exists(os.path.join(data_path, "info.json")):
|
||||
return
|
||||
if not os.path.exists(os.path.join(data_path, f"{user_id}", "1")):
|
||||
return
|
||||
if not os.path.exists(os.path.join(data_path, f"{user_id}", "2")):
|
||||
return
|
||||
|
||||
if selection_criteria in ["out_similar", "in_similar"] and embedding_index is None:
|
||||
print(f"[WARNING] {selection_criteria} requires embedding_index, falling back to random")
|
||||
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
|
||||
self.seed = seed
|
||||
self.task_metadata = self.load_task_metadata(data_path)
|
||||
self.user_metadata = self.load_user_metadata(data_path)
|
||||
self.target_dataset = self.load_target_dataset(data_path, target_user, shuffle)
|
||||
|
||||
self.metadata = json.load(open(os.path.join(data_path, "info.json"), "r", encoding="utf-8"))
|
||||
self.test_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user_id}", "2"))
|
||||
self.example_dataset = datasets.Dataset.from_list([])
|
||||
users = glob(os.path.join(data_path, "*"))
|
||||
users = [path.split("/")[-1] for path in users]
|
||||
if "info.json" in users:
|
||||
users.remove("info.json")
|
||||
for user in users:
|
||||
user_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user}", "1"))
|
||||
self.example_dataset = datasets.concatenate_datasets([self.example_dataset, user_dataset])
|
||||
all_users = glob(os.path.join(data_path, "*"))
|
||||
all_users = [os.path.basename(p) for p in all_users if os.path.isdir(p)]
|
||||
self.source_users = [u for u in all_users if u != target_user]
|
||||
self.source_dataset = self.load_source_dataset(data_path, self.source_users)
|
||||
self.classes = list(self.task_metadata["class"].keys())
|
||||
|
||||
self.test_dataset = self.test_dataset.shuffle(seed=0)
|
||||
self.example_dataset = self.example_dataset.shuffle(seed=0)
|
||||
def load_task_metadata(self, data_path: str):
|
||||
task_metadata_path = os.path.join(data_path, "task_metadata.json")
|
||||
with open(task_metadata_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
self.user_id = user_id
|
||||
self.selection_criteria = selection_criteria
|
||||
self.num_examples = num_examples
|
||||
def load_user_metadata(self, data_path: str):
|
||||
user_metadata_path = os.path.join(data_path, "user_metadata.json")
|
||||
with open(user_metadata_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
self.classes = sorted(list(self.metadata["class"].keys()))
|
||||
|
||||
# Build lookup index for fast example retrieval: (user_id, idx) -> dataset_index
|
||||
self._example_lookup = {}
|
||||
for i, example in enumerate(self.example_dataset):
|
||||
key = (str(example["user_id"]), int(example["idx"]))
|
||||
self._example_lookup[key] = i
|
||||
|
||||
if selection_criteria in ["out_similar", "in_similar"]:
|
||||
self.selected_examples = None
|
||||
else:
|
||||
self.selected_examples = self.sample_examples()
|
||||
if self.selected_examples is None:
|
||||
return
|
||||
def load_target_dataset(
|
||||
self, data_path: str, target_user: str, shuffle: bool = False
|
||||
):
|
||||
target_dataset_path = os.path.join(data_path, target_user)
|
||||
target_dataset = datasets.load_from_disk(target_dataset_path)
|
||||
if shuffle:
|
||||
return target_dataset.shuffle(seed=self.seed)
|
||||
return target_dataset
|
||||
|
||||
self.is_valid = True
|
||||
def load_source_dataset(self, data_path: str, source_users: List[str]):
|
||||
source_dataset = datasets.Dataset.from_list([])
|
||||
for user in source_users:
|
||||
user_dataset = datasets.load_from_disk(os.path.join(data_path, user))
|
||||
source_dataset = datasets.concatenate_datasets(
|
||||
[source_dataset, user_dataset]
|
||||
)
|
||||
source_dataset = source_dataset.shuffle(seed=self.seed)
|
||||
return source_dataset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.test_dataset)
|
||||
return len(self.target_dataset)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
sample = self.test_dataset[idx]
|
||||
if self.selection_criteria in ["out_similar", "in_similar"]:
|
||||
examples = self.sample_similar_examples(sample)
|
||||
else:
|
||||
examples = self.selected_examples
|
||||
return sample, examples
|
||||
def __getitem__(self, idx: int):
|
||||
return self.target_dataset[idx]
|
||||
|
||||
def __iter__(self):
|
||||
for sample in self.test_dataset:
|
||||
if self.selection_criteria in ["out_similar", "in_similar"]:
|
||||
examples = self.sample_similar_examples(sample)
|
||||
else:
|
||||
examples = self.selected_examples
|
||||
yield sample, examples
|
||||
|
||||
def sample_similar_examples(self, sample):
|
||||
"""
|
||||
Sample examples based on embedding similarity to the given sample.
|
||||
|
||||
For each class, finds the most similar example from the example dataset
|
||||
using Chronos-2 embeddings.
|
||||
|
||||
Args:
|
||||
sample: The test sample to find similar examples for
|
||||
|
||||
Returns:
|
||||
HuggingFace Dataset containing similar examples (one per class)
|
||||
"""
|
||||
if self.embedding_index is None:
|
||||
return self.sample_examples()
|
||||
|
||||
query_embedding = self.embedding_index.get_embedding_by_key(
|
||||
user_id=str(sample["user_id"]),
|
||||
session_id=str(sample["session_id"]),
|
||||
idx=int(sample["idx"]),
|
||||
)
|
||||
|
||||
if query_embedding is None:
|
||||
print(f"[WARNING] No embedding found for sample {sample['user_id']}/{sample['session_id']}/{sample['idx']}")
|
||||
return self.sample_examples()
|
||||
|
||||
if self.selection_criteria == "out_similar":
|
||||
exclude_user = str(self.user_id)
|
||||
include_user = None
|
||||
else:
|
||||
exclude_user = None
|
||||
include_user = str(self.user_id)
|
||||
|
||||
similar_per_class = self.embedding_index.find_similar_per_class(
|
||||
query_embedding=query_embedding,
|
||||
classes=self.classes,
|
||||
k_per_class=self.num_examples,
|
||||
exclude_user=exclude_user,
|
||||
include_user=include_user,
|
||||
filter_session="1",
|
||||
)
|
||||
|
||||
example_list = []
|
||||
for cls, similar_samples in similar_per_class.items():
|
||||
for global_idx, similarity, metadata in similar_samples:
|
||||
example = self._find_example_by_metadata(metadata)
|
||||
if example is not None:
|
||||
example_list.append(example)
|
||||
|
||||
if len(example_list) == 0:
|
||||
print(f"[WARNING] No similar examples found, falling back to random")
|
||||
return self.sample_examples()
|
||||
|
||||
return datasets.Dataset.from_list(example_list)
|
||||
|
||||
def _find_example_by_metadata(self, metadata):
|
||||
"""Find an example in the example_dataset by its metadata using O(1) lookup."""
|
||||
user_id = str(metadata["user_id"])
|
||||
idx = int(metadata["idx"])
|
||||
|
||||
key = (user_id, idx)
|
||||
if key not in self._example_lookup:
|
||||
return None
|
||||
|
||||
dataset_idx = self._example_lookup[key]
|
||||
example = self.example_dataset[dataset_idx]
|
||||
return {
|
||||
"user_id": example["user_id"],
|
||||
"session_id": example["session_id"],
|
||||
"idx": example["idx"],
|
||||
"label": example["label"],
|
||||
"features": example["features"],
|
||||
"data": example.get("data", {}),
|
||||
}
|
||||
yield from self.target_dataset
|
||||
|
||||
def sample_examples(self):
|
||||
example_dataset = datasets.Dataset.from_list([])
|
||||
if self.selection_criteria == "out_random":
|
||||
filtered_example_dataset = self.example_dataset.filter(lambda x, user_id=self.user_id: x["user_id"] != user_id)
|
||||
for c in self.classes:
|
||||
class_dataset = filtered_example_dataset.filter(lambda x, c_=c: x["label"] == c_)
|
||||
if len(class_dataset) < self.num_examples:
|
||||
return None
|
||||
sampled_examples = class_dataset.select(np.random.choice(len(class_dataset), self.num_examples, replace=False))
|
||||
example_dataset = datasets.concatenate_datasets([example_dataset, sampled_examples])
|
||||
elif self.selection_criteria == "in_random":
|
||||
filtered_example_dataset = self.example_dataset.filter(lambda x, user_id=self.user_id: x["user_id"] == user_id)
|
||||
for c in self.classes:
|
||||
class_dataset = filtered_example_dataset.filter(lambda x, c_=c: x["label"] == c_)
|
||||
if len(class_dataset) < self.num_examples:
|
||||
return None
|
||||
sampled_examples = class_dataset.select(np.random.choice(len(class_dataset), self.num_examples, replace=False))
|
||||
example_dataset = datasets.concatenate_datasets([example_dataset, sampled_examples])
|
||||
return example_dataset
|
||||
def get_source_dataset(self):
|
||||
return self.source_dataset
|
||||
|
||||
def get_metadata(self):
|
||||
return self.metadata
|
||||
def get_task_metadata(self):
|
||||
return self.task_metadata
|
||||
|
||||
def get_sensor_info(self):
|
||||
return self.metadata["feature"]
|
||||
def get_user_metadata(self, user: Optional[str] = None):
|
||||
if user is None:
|
||||
return self.user_metadata
|
||||
return self.user_metadata[user]
|
||||
|
||||
def get_task_info(self):
|
||||
task_info = f"**Task**:\n{self.metadata['task']}\n\n"
|
||||
classes_info = [f" - {k}: {v}" for k, v in self.metadata["class"].items()]
|
||||
classes_info = "\n".join(classes_info)
|
||||
task_info += f"**Classes**:\n{classes_info}"
|
||||
return task_info
|
||||
def get_source_users(self):
|
||||
return self.source_users
|
||||
|
||||
def get_classes_info(self):
|
||||
classes_info = [k for k in self.metadata["class"].keys()]
|
||||
return classes_info
|
||||
def get_classes(self):
|
||||
return self.classes
|
||||
|
||||
@@ -1,281 +0,0 @@
|
||||
"""
|
||||
Embedding Index for Similarity-based Example Selection
|
||||
|
||||
This module provides functionality to:
|
||||
1. Load pre-computed Chronos-2 embeddings
|
||||
2. Build an index for fast nearest neighbor search
|
||||
3. Find similar examples based on embedding distance
|
||||
|
||||
The embedding index enables ICL (In-Context Learning) example selection
|
||||
based on semantic similarity rather than random sampling.
|
||||
|
||||
Similarity Strategies:
|
||||
- out_similar: Find similar examples from OTHER users (cross-user transfer)
|
||||
- in_similar: Find similar examples from SAME user (personalization)
|
||||
|
||||
Usage:
|
||||
index = EmbeddingIndex(embedding_path)
|
||||
similar_indices = index.find_similar(query_embedding, k=5, exclude_user="01")
|
||||
|
||||
Author: NMSL Research Team
|
||||
Date: 2026-01-16
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
from datasets import load_from_disk, Dataset
|
||||
|
||||
|
||||
class EmbeddingIndex:
|
||||
"""
|
||||
Index for fast similarity search over pre-computed embeddings.
|
||||
|
||||
Uses cosine similarity for finding nearest neighbors in embedding space.
|
||||
Supports filtering by user_id and session_id for controlled experiments.
|
||||
|
||||
Attributes:
|
||||
embeddings: numpy array of shape (num_samples, embedding_dim)
|
||||
user_ids: list of user identifiers for each sample
|
||||
session_ids: list of session identifiers for each sample
|
||||
labels: list of sleep stage labels for each sample
|
||||
indices: list of original indices in the dataset
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_path: str):
|
||||
"""
|
||||
Initialize the embedding index from a HuggingFace dataset.
|
||||
|
||||
Args:
|
||||
embedding_path: Path to directory containing saved embeddings dataset
|
||||
(output from gen_plot.py extract command)
|
||||
"""
|
||||
if not os.path.isdir(embedding_path):
|
||||
raise FileNotFoundError(
|
||||
f"Embedding directory not found: {embedding_path}. "
|
||||
"Run 'python gen_plot.py extract' first to generate embeddings."
|
||||
)
|
||||
|
||||
print(f"[EmbeddingIndex] Loading embeddings from: {embedding_path}")
|
||||
self.dataset = load_from_disk(embedding_path)
|
||||
|
||||
# Extract arrays for fast access
|
||||
self.embeddings = np.array(self.dataset["embedding"], dtype=np.float32)
|
||||
self.user_ids = np.array([str(uid) for uid in self.dataset["user_id"]])
|
||||
self.session_ids = np.array([str(sid) for sid in self.dataset["session_id"]])
|
||||
self.labels = np.array([str(label) for label in self.dataset["label"]])
|
||||
self.indices = np.array(self.dataset["idx"])
|
||||
|
||||
# Normalize embeddings for cosine similarity (dot product of unit vectors)
|
||||
self.embeddings_normalized = self._normalize(self.embeddings)
|
||||
|
||||
# Build lookup indices for fast filtering
|
||||
self._build_lookup_indices()
|
||||
|
||||
print(f"[EmbeddingIndex] Loaded {len(self.embeddings)} samples")
|
||||
print(f"[EmbeddingIndex] Embedding dimension: {self.embeddings.shape[1]}")
|
||||
print(f"[EmbeddingIndex] Unique users: {len(np.unique(self.user_ids))}")
|
||||
print(f"[EmbeddingIndex] Unique sessions: {len(np.unique(self.session_ids))}")
|
||||
|
||||
def _normalize(self, vectors: np.ndarray) -> np.ndarray:
    """Scale each row of ``vectors`` to unit L2 length.

    Rows whose norm is exactly zero are divided by 1 instead, so they stay
    all-zero rather than producing NaN.
    """
    row_lengths = np.linalg.norm(vectors, axis=1, keepdims=True)
    safe_lengths = np.where(row_lengths == 0, 1, row_lengths)
    unit_rows = vectors / safe_lengths
    return unit_rows
|
||||
|
||||
def _build_lookup_indices(self):
    """Create value -> row-position maps for users, sessions and labels.

    Each map sends a distinct value of the corresponding array to the
    array of positions at which that value occurs.
    """

    def rows_by_value(values: np.ndarray) -> Dict[str, np.ndarray]:
        # One boolean scan per distinct value.
        return {
            value: np.where(values == value)[0]
            for value in np.unique(values)
        }

    self.user_to_indices: Dict[str, np.ndarray] = rows_by_value(self.user_ids)
    self.session_to_indices: Dict[str, np.ndarray] = rows_by_value(self.session_ids)
    self.label_to_indices: Dict[str, np.ndarray] = rows_by_value(self.labels)
|
||||
|
||||
def get_embedding_by_key(
    self,
    user_id: str,
    session_id: str,
    idx: int
) -> Optional[np.ndarray]:
    """
    Look up the embedding stored for one (user_id, session_id, idx) triple.

    Args:
        user_id: User identifier (compared as a string)
        session_id: Session identifier (compared as a string)
        idx: Sample index, compared against the stored ``idx`` column

    Returns:
        The un-normalised embedding of the first matching row, or None when
        no row matches all three keys.
    """
    same_user = self.user_ids == str(user_id)
    same_session = self.session_ids == str(session_id)
    same_idx = self.indices == idx

    hits = np.where(same_user & same_session & same_idx)[0]
    if len(hits) > 0:
        return self.embeddings[hits[0]]
    return None
|
||||
|
||||
def cosine_similarity(
    self,
    query: np.ndarray,
    candidates: np.ndarray
) -> np.ndarray:
    """
    Score candidate vectors against a query by dot product with the unit query.

    Only the query is normalised here; the result is a true cosine
    similarity when the rows of ``candidates`` are already unit length.

    Args:
        query: Query vector of shape (embedding_dim,)
        candidates: Candidate matrix of shape (num_candidates, embedding_dim)

    Returns:
        Similarity scores of shape (num_candidates,)
    """
    # The small epsilon keeps an all-zero query from dividing by zero.
    query_length = np.linalg.norm(query) + 1e-8
    unit_query = query / query_length
    return candidates @ unit_query
|
||||
|
||||
def find_similar(
    self,
    query_embedding: np.ndarray,
    k: int = 5,
    exclude_user: Optional[str] = None,
    include_user: Optional[str] = None,
    filter_session: Optional[str] = None,
    filter_label: Optional[str] = None,
) -> List[Tuple[int, float, Dict[str, Any]]]:
    """
    Find the k stored samples most similar to the query embedding.

    Args:
        query_embedding: Query vector of shape (embedding_dim,)
        k: Maximum number of neighbours to return; values <= 0 yield []
        exclude_user: User ID to exclude from the search (for out_similar)
        include_user: User ID to restrict the search to (for in_similar)
        filter_session: Only search rows of this session (e.g., "1")
        filter_label: Only search rows with this label

    Returns:
        List of (index, similarity, metadata) tuples sorted by similarity
        (descending). ``index`` is the row position in the index as a plain
        ``int``; metadata contains: user_id, session_id, idx, label.
        Empty when no row passes the filters.
    """
    # A negative k would otherwise slice "[:k]" from the end and return
    # all-but-|k| rows instead of nothing.
    if k <= 0:
        return []

    # All filters are AND-ed together; filter values are compared as strings.
    candidate_mask = np.ones(len(self.embeddings), dtype=bool)
    if exclude_user is not None:
        candidate_mask &= (self.user_ids != str(exclude_user))
    if include_user is not None:
        candidate_mask &= (self.user_ids == str(include_user))
    if filter_session is not None:
        candidate_mask &= (self.session_ids == str(filter_session))
    if filter_label is not None:
        candidate_mask &= (self.labels == str(filter_label))

    candidate_indices = np.where(candidate_mask)[0]
    if len(candidate_indices) == 0:
        return []

    # Candidates come from the pre-normalised matrix, so the dot product
    # computed by cosine_similarity is a cosine similarity.
    candidate_embeddings = self.embeddings_normalized[candidate_indices]
    similarities = self.cosine_similarity(query_embedding, candidate_embeddings)

    k = min(k, len(similarities))
    top_k_local = np.argsort(similarities)[::-1][:k]

    results = []
    for local_idx in top_k_local:
        # Convert to a built-in int so the tuple matches the annotation and
        # stays serialisable.
        global_idx = int(candidate_indices[local_idx])
        metadata = {
            "user_id": self.user_ids[global_idx],
            "session_id": self.session_ids[global_idx],
            "idx": int(self.indices[global_idx]),
            "label": self.labels[global_idx],
        }
        results.append((global_idx, float(similarities[local_idx]), metadata))

    return results
|
||||
|
||||
def find_similar_per_class(
    self,
    query_embedding: np.ndarray,
    classes: List[str],
    k_per_class: int = 1,
    exclude_user: Optional[str] = None,
    include_user: Optional[str] = None,
    filter_session: Optional[str] = "1",
) -> Dict[str, List[Tuple[int, float, Dict[str, Any]]]]:
    """
    Run one nearest-neighbour search per class label.

    Searching each class separately keeps the retrieved examples balanced
    across classes.

    Args:
        query_embedding: Query vector
        classes: Class labels to search, one search per entry
        k_per_class: Number of neighbours requested for each class
        exclude_user: User to exclude (out_similar mode)
        include_user: User to restrict to (in_similar mode)
        filter_session: Session to search in (default: "1")

    Returns:
        Dictionary mapping each class label to its ``find_similar`` result
    """
    return {
        label: self.find_similar(
            query_embedding=query_embedding,
            k=k_per_class,
            exclude_user=exclude_user,
            include_user=include_user,
            filter_session=filter_session,
            filter_label=label,
        )
        for label in classes
    }
|
||||
|
||||
def get_sample_metadata(self, global_idx: int) -> Dict[str, Any]:
    """Return the stored fields of the row at position ``global_idx``.

    The result holds user_id, session_id, idx (as int), label and the
    un-normalised embedding of that row.
    """
    row = global_idx
    original_idx = int(self.indices[row])
    return dict(
        user_id=self.user_ids[row],
        session_id=self.session_ids[row],
        idx=original_idx,
        label=self.labels[row],
        embedding=self.embeddings[row],
    )
|
||||
|
||||
|
||||
def create_embedding_index(embedding_path: str) -> Optional[EmbeddingIndex]:
    """
    Build an EmbeddingIndex without raising: failures are printed and yield None.

    Args:
        embedding_path: Path to embeddings directory

    Returns:
        EmbeddingIndex instance, or None when the directory is missing or
        loading raised any exception
    """
    index = None
    if os.path.isdir(embedding_path):
        try:
            index = EmbeddingIndex(embedding_path)
        except Exception as exc:
            # Best-effort factory: report the failure and fall back to None.
            print(f"[ERROR] Failed to load embeddings: {exc}")
    else:
        print(f"[WARNING] Embedding path not found: {embedding_path}")
    return index
|
||||
37
core/example_queue.py
Normal file
37
core/example_queue.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from core.logger import Logger
|
||||
|
||||
|
||||
class ExampleQueue:
    """Bounded list of example entries that evicts its lowest-scoring items."""

    def __init__(self, queue_size: int, logger: Logger):
        # queue_size is the capacity enforced when new examples are added.
        self.queue_size = queue_size
        self.logger = logger
        self.queue: List[Dict[str, Any]] = []

    def __iter__(self):
        """Yield the queued examples in their current order."""
        for item in self.queue:
            yield item

    def update(self, results: List[Dict[str, Any]], new_examples: List[Dict[str, Any]]):
        """Evict the worst-scoring entries and append ``new_examples``.

        Args:
            results: Scored outcomes with a "score" key each; position i is
                used as the queue position to drop, so results are presumably
                aligned with the queue order (verify against caller).
            new_examples: Entries to add. As many low-score entries are
                dropped as there are new examples, and the additions are
                truncated to the free capacity.
        """
        incoming = len(new_examples)
        if incoming == 0:
            return

        if results:
            ascending = sorted(range(len(results)), key=lambda pos: results[pos]["score"])
            # Pop from the highest position down so earlier pops do not shift
            # the positions still to be removed.
            for position in sorted(ascending[:incoming], reverse=True):
                self.logger.log(
                    f"[Queue] Dropping index {position} (score={results[position]['score']:.4f})"
                )
                self.queue.pop(position)

        free_slots = self.queue_size - len(self.queue)
        if free_slots < len(new_examples):
            self.logger.log(
                f"[Queue] Capping: {len(new_examples)} new but only {free_slots} slot(s) free"
            )
            new_examples = new_examples[:free_slots]

        self.queue.extend(new_examples)
        self.logger.log(f"[Queue] Updated, queue size = {len(self.queue)}")
|
||||
29
core/json_utils.py
Normal file
29
core/json_utils.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import re
|
||||
import json
|
||||
|
||||
|
||||
def clean_json_text(text: str):
    """Apply heuristic repairs to almost-JSON text (e.g. LLM output).

    The repairs are purely textual and also apply inside string values, so
    the function is meant as a lenient pre-processing step before
    ``json.loads``, not as a validator.

    Args:
        text: Raw JSON-like text.

    Returns:
        The repaired text.
    """

    def _fix_backslash(match):
        # Keep valid escape pairs (\" \\ \/ \b \f \n \r \t \u) untouched and
        # double any other backslash. Consuming the pair as a unit matters:
        # with a lookahead only, the second backslash of "\\x" would be seen
        # as stray and doubled, corrupting a valid escaped backslash.
        return match.group(0) if match.group(1) else "\\\\"

    text = text.strip()
    text = text.replace("\u2018", "'").replace("\u2019", "'")  # smart single quotes
    text = text.replace("\u201c", '"').replace("\u201d", '"')  # smart double quotes
    text = re.sub(r'\\(["\\/bfnrtu])?', _fix_backslash, text)  # escape stray backslashes
    text = re.sub(r",\s*}", "}", text)  # trailing comma in object
    text = re.sub(r",\s*]", "]", text)  # trailing comma in array
    # Drops every non-printable character, which includes newlines and tabs.
    text = "".join(ch for ch in text if ch.isprintable())
    text = text.replace("][", ",")  # merge adjacent arrays
    return text
|
||||
|
||||
|
||||
def safe_parse_json(text: str):
    """Extract and parse the JSON object embedded in ``text``.

    The span from the first "{" to the last "}" is cleaned with
    ``clean_json_text`` and parsed. If no such span exists and the text does
    not end with "}", one closing brace is appended and the search retried,
    which covers output truncated before its final brace.

    Returns:
        The parsed value, or None for empty input, when no object span is
        found, or when parsing fails.
    """
    if not text:
        return None

    stripped = text.strip()
    object_span = re.search(r"\{.*\}", stripped, re.DOTALL)
    if not object_span and not stripped.endswith("}"):
        object_span = re.search(r"\{.*\}", stripped + "}", re.DOTALL)

    if not object_span:
        return None

    try:
        return json.loads(clean_json_text(object_span.group(0)))
    except json.JSONDecodeError:
        return None
|
||||
58
core/logger.py
Normal file
58
core/logger.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import os
|
||||
import yaml
|
||||
|
||||
from datetime import datetime
|
||||
from sklearn.metrics import accuracy_score, f1_score
|
||||
|
||||
|
||||
class Logger:
    """File logger that writes into a timestamped run directory and scores answers."""

    def __init__(self, log_path: str):
        # Every run gets its own directory: "<log_path>_<YYYYmmdd_HHMMSS>".
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_path = f"{log_path}_{stamp}"
        os.makedirs(self.log_path, exist_ok=True)
        # One dict per recorded prediction (idx, answer, ground_truth).
        self.answers = []

    def log(self, message: str, filename: str = "log.txt", print_log: bool = True):
        """Append ``message`` plus a newline to ``filename`` in the run directory.

        The message is also printed unless ``print_log`` is False. Parent
        directories of ``filename`` are created when missing.
        """
        if print_log:
            print(message)
        target = os.path.join(self.log_path, filename)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        with open(target, "a", encoding="utf-8") as handle:
            handle.write(message + "\n")

    def log_config(self, config: dict):
        """Dump ``config`` as block-style YAML into config.yaml, without printing."""
        dumped = yaml.dump(config, default_flow_style=False, sort_keys=False)
        self.log(dumped.strip(), "config.yaml", print_log=False)

    def log_result(self, idx: int, answer: str, ground_truth: str):
        """Remember one prediction and append it to result.txt."""
        record = {
            "idx": idx,
            "answer": answer,
            "ground_truth": ground_truth,
        }
        self.answers.append(record)
        self.log(
            f"[RESULT] {idx}: answer={answer}, ground_truth={ground_truth}",
            "result.txt",
        )

    def report(self, elapsed_seconds: float = None):
        """Write accuracy and macro-F1 over all recorded answers to report.txt.

        Args:
            elapsed_seconds: Optional run time; when given it is appended to
                the report line as hours, minutes and seconds.
        """
        if not self.answers:
            self.log("[REPORT] No valid answers recorded", "report.txt")
            return

        predicted = [entry["answer"] for entry in self.answers]
        expected = [entry["ground_truth"] for entry in self.answers]
        accuracy = accuracy_score(expected, predicted)
        f1 = f1_score(expected, predicted, average="macro")
        n = len(self.answers)

        time_str = ""
        if elapsed_seconds is not None:
            hours, remainder = divmod(int(elapsed_seconds), 3600)
            minutes, seconds = divmod(remainder, 60)
            time_str = f", time={hours}h{minutes:02d}m{seconds:02d}s"

        self.log(
            f"[REPORT] accuracy={accuracy:.4f}, f1={f1:.4f}, n={n}{time_str}",
            "report.txt",
        )
|
||||
182
core/model.py
182
core/model.py
@@ -1,95 +1,133 @@
|
||||
import asyncio
|
||||
import requests
|
||||
import numpy as np
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
|
||||
def load_models(models):
|
||||
model_pool = AsyncModelPool()
|
||||
for model in models:
|
||||
model_pool.add_model(Model(model))
|
||||
model_pool.init_models()
|
||||
return model_pool
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, model, temperature=0.7):
|
||||
if model.startswith("ollama:"):
|
||||
model = model.replace("ollama:", "")
|
||||
if "url:" in model:
|
||||
model = model.replace("url:", "")
|
||||
base_url = model.split("/")[0]
|
||||
if not base_url.startswith("http"):
|
||||
base_url = "http://" + base_url
|
||||
model_type = model.split("/")[1]
|
||||
self.model = ChatOllama(
|
||||
model=model_type,
|
||||
base_url=base_url,
|
||||
temperature=temperature,
|
||||
num_ctx=12000,
|
||||
)
|
||||
else:
|
||||
self.model = ChatOllama(
|
||||
model=model.replace("ollama:", ""),
|
||||
temperature=temperature,
|
||||
num_ctx=12000,
|
||||
)
|
||||
else:
|
||||
self.model = init_chat_model(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
temperature: float = 0.0,
|
||||
num_ctx: int = 131072,
|
||||
max_tokens: int = -1,
|
||||
logprobs: bool = True,
|
||||
top_logprobs: int = 20,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 0,
|
||||
stream: bool = False,
|
||||
think: str = "low",
|
||||
):
|
||||
self.backend = None
|
||||
|
||||
def invoke(self, messages):
|
||||
if model_path.startswith("ollama:"):
|
||||
raw = model_path.split("url:")[1]
|
||||
self.backend = "ollama"
|
||||
self.base_url = f"http://{raw.split('/')[0]}"
|
||||
self.model_name = raw.split("/")[1]
|
||||
self.temperature = temperature
|
||||
self.num_ctx = num_ctx
|
||||
self.max_tokens = max_tokens
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.stream = stream
|
||||
self.think = think
|
||||
else:
|
||||
raise ValueError(f"Unknown model prefix: {model_path}")
|
||||
|
||||
def invoke(self, messages: List[Dict[str, str]]):
|
||||
try:
|
||||
response = self.model.invoke(messages)
|
||||
return response
|
||||
return self._invoke_ollama(messages)
|
||||
except Exception as e:
|
||||
print(f"[Error] Error occurred while invoking LLM: {e}")
|
||||
return e
|
||||
print(f"[Error] invoke failed: {e}")
|
||||
return "", np.array([])
|
||||
|
||||
def _invoke_ollama(self, messages: List[Dict[str, str]]):
    """Send one chat request to the Ollama server and return text plus log-probs.

    Args:
        messages: Chat messages as dicts, passed through unchanged.

    Returns:
        Tuple ``(response, logprobs)``: the assistant message content and an
        array built from the ``top_logprobs`` values reported for each entry
        of the response's ``logprobs`` list (an empty array when the
        response carries none).

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # Ollama reads sampling/runtime parameters from the nested "options"
    # object; as top-level request fields they are not applied.
    options = {
        "temperature": self.temperature,
        "num_ctx": self.num_ctx,
        "num_predict": self.max_tokens,
        "top_p": self.top_p,
        "top_k": self.top_k,
    }
    resp = requests.post(
        f"{self.base_url}/api/chat",
        json={
            "messages": messages,
            "model": self.model_name,
            "options": options,
            "logprobs": self.logprobs,
            "top_logprobs": self.top_logprobs,
            "stream": self.stream,
            "think": self.think,
        },
        timeout=300,
    )
    resp.raise_for_status()
    data = resp.json()
    response = data["message"]["content"]

    # One row per reported position, holding its top-k log-probabilities.
    logprobs = [
        [tp["logprob"] for tp in lp.get("top_logprobs", [])]
        for lp in data.get("logprobs", [])
    ]

    return response, np.array(logprobs) if logprobs else np.array([])
|
||||
|
||||
|
||||
class AsyncModel:
|
||||
def __init__(self, model):
|
||||
def __init__(self, model: Model):
|
||||
self.model = model
|
||||
|
||||
async def invoke(self, content):
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(
|
||||
async def invoke(self, messages: List[Dict[str, str]]):
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.model.invoke(content),
|
||||
lambda: self.model.invoke(messages),
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class AsyncModelPool:
|
||||
def __init__(self):
|
||||
self.models = []
|
||||
self._available_models = None
|
||||
self._model_semaphore = None
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
def add_model(self, model):
|
||||
self.models.append(model)
|
||||
def add_model(self, model: Model):
|
||||
self._queue.put_nowait(AsyncModel(model))
|
||||
|
||||
def init_models(self):
|
||||
# Initialize the queue and semaphore in the current event loop
|
||||
self._available_models = asyncio.Queue()
|
||||
for model in self.models:
|
||||
async_model = AsyncModel(model)
|
||||
self._available_models.put_nowait(async_model)
|
||||
self._model_semaphore = asyncio.Semaphore(len(self.models))
|
||||
|
||||
def warmup(self):
|
||||
for model in self.models:
|
||||
model.invoke([HumanMessage(content="Hello world!")])
|
||||
|
||||
async def invoke(self, content):
|
||||
if self._available_models is None:
|
||||
raise RuntimeError("Model pool not initialized. Call init_models() first.")
|
||||
async_model = await self._available_models.get()
|
||||
async def invoke(self, messages: List[Dict[str, str]]):
|
||||
async_model = await self._queue.get()
|
||||
try:
|
||||
response = await async_model.invoke(content)
|
||||
return response
|
||||
return await async_model.invoke(messages)
|
||||
finally:
|
||||
self._available_models.put_nowait(async_model)
|
||||
self._queue.put_nowait(async_model)
|
||||
|
||||
|
||||
def load_models(
    model_paths: List[str],
    temperature: float = 0.0,
    num_ctx: int = 131072,
    max_tokens: int = -1,
    logprobs: bool = True,
    top_logprobs: int = 20,
    top_p: float = 1.0,
    top_k: int = 0,
    stream: bool = False,
    think: str = "low",
):
    """Create an AsyncModelPool holding one Model per entry of ``model_paths``.

    Every model is constructed with the same generation settings; a line is
    printed after each model is added.

    Returns:
        The populated AsyncModelPool.
    """
    pool = AsyncModelPool()
    # Settings shared by all models in the pool.
    shared_settings = dict(
        temperature=temperature,
        num_ctx=num_ctx,
        max_tokens=max_tokens,
        logprobs=logprobs,
        top_logprobs=top_logprobs,
        top_p=top_p,
        top_k=top_k,
        stream=stream,
        think=think,
    )
    for path in model_paths:
        model = Model(path, **shared_settings)
        pool.add_model(model)
        print(f"[ModelPool] Loaded a model from {path}")
    return pool
|
||||
|
||||
66
core/prompt.py
Normal file
66
core/prompt.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
def gen_system_message(metadata: Dict[str, Any]):
    """Build the system prompt from the task metadata.

    Args:
        metadata: Mapping with the keys "task", "class" (class name ->
            description), "data" and "feature".

    Returns:
        The system message: an introduction, four numbered sections (task,
        classes, data, features) and the closing instructions.
    """
    task_info = metadata["task"]
    class_lines = "\n".join(
        f" - {name}: {description}"
        for name, description in metadata["class"].items()
    )
    data_info = metadata["data"]
    feature_info = metadata["feature"]

    sections = [
        "You are an assistant who interprets sensor data to solve a task.",
        f"1. Task:\n{task_info}",
        f"2. Classes:\n{class_lines}",
        f"3. Data:\n{data_info}",
        f"4. Features:\n{feature_info}",
        "Your goal is to analyze the sensor data and "
        "provide a reasoned answer for the task.\n"
        "Do not output analysis.",
    ]
    # Sections are separated by exactly one blank line.
    return "\n\n".join(sections)
|
||||
|
||||
|
||||
def gen_task_message(
    sample: Dict[str, Any],
    example_set: Dict[str, List[Dict[str, Any]]],
):
    """Build the user prompt holding labelled examples and the current sample.

    Args:
        sample: Record whose "features" mapping is the data to classify.
        example_set: Mapping from class label to the examples of that class;
            each example carries a "features" mapping. The mapping's keys
            are the answer options shown to the model.

    Returns:
        The task message, ending with the required JSON answer format.
    """

    def format_feature(value: Any):
        # Floats get two decimals, switching to scientific notation for very
        # large or very small non-zero magnitudes; anything else is str()-ed.
        if isinstance(value, float):
            if abs(value) >= 1e4 or (abs(value) < 1e-2 and value != 0):
                return f"{value:.2e}"
            return f"{value:.2f}"
        return str(value)

    def format_features(features: Dict[str, Any]) -> str:
        # One " - name: value" line per feature, each newline-terminated.
        return "".join(f" - {k}: {format_feature(v)}\n" for k, v in features.items())

    example_blocks = [
        f"Example {i} of {cls}:\n{format_features(example['features'])}\n"
        for cls, examples in example_set.items()
        for i, example in enumerate(examples, start=1)
    ]
    example_info = "".join(example_blocks).strip()

    test_info = ("Current sensor data:\n" + format_features(sample["features"])).strip()

    classes = list(example_set.keys())
    task_message = (
        "You have a few labeled examples of sensor data:\n"
        f"{example_info}\n\n"
        f"And you have the current sensor data:\n"
        f"{test_info}\n\n"
        f"Please provide your answer among {classes} "
        "and the reasoning for your answer.\n"
        "Respond in the following strict JSON format:\n"
        "{\n"
        ' "REASON": "<reasoning for the answer>",\n'
        f' "ANSWER": "<answer among {classes}>"\n'
        "}"
    )
    return task_message
|
||||
52
core/recruiter.py
Normal file
52
core/recruiter.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from datasets import Dataset
|
||||
from typing import Any, Dict, List
|
||||
from core.logger import Logger
|
||||
|
||||
|
||||
class Recruiter:
    """Draws random few-shot example sets from the source users' data."""

    def __init__(
        self,
        source_dataset: Dataset,
        source_users: List[str],
        classes: List[str],
        num_shot: int,
        logger: Logger,
    ):
        # Rows are selected through their "user_id" and "label" columns.
        self.source_dataset = source_dataset
        self.source_users = source_users
        self.classes = classes
        # Number of examples drawn per class in each example set.
        self.num_shot = num_shot
        self.logger = logger

    def recruit(self, num_example_set: int):
        """Sample ``num_example_set`` example sets, each from one random user.

        For every set a user is chosen at random and, for each class,
        ``num_shot`` distinct rows of that user with that label are drawn.

        Returns:
            List of dicts mapping class label to the selected rows.

        Raises:
            ValueError: If the chosen user has fewer than ``num_shot`` rows
                for some class.
        """
        # Current strategy is uniform random sampling.
        self.logger.log(f"[Recruiter] Recruiting {num_example_set} example set(s)")

        recruited = []
        for _ in range(num_example_set):
            chosen_user = random.choice(self.source_users)
            rows_of_user = self.source_dataset.filter(
                lambda row: row["user_id"] == chosen_user
            )

            per_class = {}
            for label in self.classes:
                rows_of_class = rows_of_user.filter(lambda row: row["label"] == label)
                available = len(rows_of_class)
                if available < self.num_shot:
                    raise ValueError(
                        f"Not enough examples for class {label} in user {chosen_user}"
                    )
                # Draw without replacement so a set never repeats a row.
                picks = np.random.choice(available, self.num_shot, replace=False)
                per_class[label] = rows_of_class.select(picks)

            self.logger.log(
                f"[Recruiter] Recruited {self.num_shot} example(s) from user {chosen_user}"
            )
            recruited.append(per_class)
        return recruited

    def update_strategy(self, results: List[Dict[str, Any]]):
        """Placeholder hook; currently does nothing."""
        # TODO: implement adaptive recruitment strategy based on results
        pass
|
||||
27
core/scores.py
Normal file
27
core/scores.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def self_certainty(logprobs: np.ndarray, vocab_size: int) -> float:
    """Average self-certainty score computed from top-k token log-probabilities.

    For each row the score is ``-(1/V) * sum_i log p_i - log V`` over a
    vocabulary of size ``V``. Only the top-k log-probabilities are given, so
    the remaining probability mass is spread uniformly over the other
    ``V - k`` tokens.

    Args:
        logprobs: Array (or nested sequence) whose last axis holds the
            top-k log-probabilities of one position.
        vocab_size: Vocabulary size ``V``; must be at least ``k``.

    Returns:
        Mean score over all rows; ``-inf`` if logprobs is None or empty.

    Raises:
        ValueError: If ``vocab_size`` is smaller than ``k``.
    """
    if logprobs is None:
        return float("-inf")
    logprobs = np.asarray(logprobs)
    if logprobs.size == 0:
        return float("-inf")

    k = logprobs.shape[-1]
    n_tail = vocab_size - k
    if n_tail < 0:
        raise ValueError(
            f"vocab_size ({vocab_size}) is smaller than the number of "
            f"log-probabilities per position ({k})"
        )

    # log-probability sum over the explicitly given top-k tokens
    logprob_sum = logprobs.sum(axis=-1)

    if n_tail > 0:
        # remaining probability mass, floored so the logarithm stays finite
        tail_mass = np.clip(1.0 - np.exp(logprobs).sum(axis=-1), 1e-12, None)
        # uniform distribution over the remaining tokens
        tail_prob = tail_mass / n_tail
        logprob_sum = logprob_sum + n_tail * np.log(tail_prob)
    # With n_tail == 0 the top-k covers the whole vocabulary and there is no
    # tail term; evaluating it anyway would give 0 * log(inf) = NaN.

    score = (-1.0 / vocab_size) * logprob_sum - math.log(vocab_size)
    return float(np.mean(score))
|
||||
@@ -1,181 +0,0 @@
|
||||
import json
|
||||
import copy
|
||||
import os
|
||||
from .agent import Agent
|
||||
|
||||
|
||||
class SensingAgent(Agent):
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
model_pool,
|
||||
task_info,
|
||||
classes_info,
|
||||
sensor_info,
|
||||
sample,
|
||||
examples,
|
||||
log_path,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
model_pool=model_pool,
|
||||
log_path=log_path,
|
||||
)
|
||||
|
||||
self.task_info = task_info
|
||||
self.classes_info = classes_info
|
||||
self.sensor_info = sensor_info
|
||||
self.sample = sample
|
||||
self.examples = examples
|
||||
self.init_system_message()
|
||||
|
||||
def init_system_message(self):
|
||||
content = (
|
||||
f"You are {self.name} agent that interprets sensor data to solve a task.\n"
|
||||
"You have the following information about the task:\n"
|
||||
f"{self.task_info}\n\n"
|
||||
"You have the following information about the sensor data:\n"
|
||||
f"{self.sensor_info}\n\n"
|
||||
"Your goal is to analyze the features and "
|
||||
"provide a reasoned answer using your knowledge."
|
||||
)
|
||||
self.set_system_message(content)
|
||||
|
||||
def gen_feature_info(self):
|
||||
feature_info = f"{self.name} features:\n"
|
||||
if len(self.examples) > 0:
|
||||
feature_info += f"{self.gen_example_info()}\n\n"
|
||||
feature_info += "**Current sample features**:\n"
|
||||
for k, v in self.sample["features"].items():
|
||||
feature_info += f" - {k}: {self.format_feature(v)}\n"
|
||||
feature_info = feature_info.strip()
|
||||
return feature_info
|
||||
|
||||
def gen_example_info(self):
|
||||
example_info = (
|
||||
"**Examples**\n"
|
||||
"Sensor values might not always align with your inherent "
|
||||
"knowledge due to differences in data collection or processing. "
|
||||
"So, we included a few labeled examples to help your interpretation:\n"
|
||||
)
|
||||
for example in self.examples:
|
||||
example_info += f"*Example of {example['label']}*:\n"
|
||||
for k, v in example["features"].items():
|
||||
example_info += f" - {k}: {self.format_feature(v)}\n"
|
||||
example_info += "\n"
|
||||
example_info = example_info.strip()
|
||||
return example_info
|
||||
|
||||
def format_feature(self, value):
|
||||
if isinstance(value, float):
|
||||
if abs(value) >= 1e4 or abs(value) < 1e-2:
|
||||
return f"{value:.2e}"
|
||||
return f"{value:.2f}"
|
||||
return value
|
||||
|
||||
def log_summary(self, message, print_log=True):
|
||||
path = os.path.join(self.log_path, "summary.txt")
|
||||
with open(path, "a", encoding="utf-8") as f:
|
||||
f.write(f"{message}\n")
|
||||
if print_log:
|
||||
print(message)
|
||||
|
||||
async def solve(self, sample, examples, ground_truth):
|
||||
self.sample = sample
|
||||
self.examples = examples
|
||||
feature_info = self.gen_feature_info()
|
||||
content = (
|
||||
f"You have received sensor features from {self.name} modality:\n"
|
||||
f"{feature_info}\n\n"
|
||||
f"Please provide your answer for the task among {self.classes_info} "
|
||||
"and the reasoning for your answer.\n"
|
||||
"Note that the sensor features might be wrong due to the data collection or processing.\n"
|
||||
"You can evaluate the quality of the features by checking the examples you have.\n"
|
||||
"Respond in the following strict JSON format:\n"
|
||||
"{\n"
|
||||
' "REASON": "<Reasoning for the answer>",\n'
|
||||
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
|
||||
"}\n\n"
|
||||
"Do not include any additional text outside the JSON."
|
||||
)
|
||||
|
||||
response = await self.invoke(content)
|
||||
parsed_response = self.safe_parse_json(response)
|
||||
parsed_response = await self.validate_response(
|
||||
parsed_response, ["REASON", "ANSWER"]
|
||||
)
|
||||
self.clean_short_term_memory()
|
||||
self.clean_long_term_memory()
|
||||
answer = parsed_response["ANSWER"]
|
||||
self.log_summary(f"Answer: {answer} (Ground truth: {ground_truth})", print_log=True)
|
||||
return parsed_response
|
||||
|
||||
async def interpret(self):
|
||||
feature_info = self.gen_feature_info()
|
||||
content = (
|
||||
f"You have received sensor features from {self.name} modality:\n"
|
||||
f"{feature_info}\n\n"
|
||||
f"Please provide your answer for the task among {self.classes_info} "
|
||||
"and the reasoning for your answer.\n"
|
||||
"Note that the sensor features might be wrong due to the data collection or processing.\n"
|
||||
"You can evaluate the quality of the features by checking the examples you have.\n"
|
||||
"Respond in the following strict JSON format:\n"
|
||||
"{\n"
|
||||
' "REASON": "<Reasoning for the answer>",\n'
|
||||
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
|
||||
"}\n\n"
|
||||
"Do not include any additional text outside the JSON."
|
||||
)
|
||||
|
||||
response = await self.invoke(content)
|
||||
parsed_response = self.safe_parse_json(response)
|
||||
parsed_response = await self.validate_response(
|
||||
parsed_response, ["REASON", "ANSWER"]
|
||||
)
|
||||
return parsed_response
|
||||
|
||||
async def evaluate(self, target_name, initial_response):
|
||||
initial_response_info = json.dumps(initial_response, indent=2)
|
||||
content = (
|
||||
f"Other agent, <{target_name}> provided the following answer for the same task:\n"
|
||||
f"{initial_response_info}\n\n"
|
||||
"Please evaluate the given reasoning and answer based on your judgement. "
|
||||
"You may either support with it or disagree.\n"
|
||||
"If you agree, explain why the reasoning and answer are valid. "
|
||||
"If you disagree, explain why the reasoning or answer may be flawed, "
|
||||
f"and provide constructive feedback on how <{target_name}> can improve its response.\n"
|
||||
"Respond in the following strict JSON format:\n"
|
||||
"{\n"
|
||||
f' "EVALUATION": "<Evaluation to <{target_name}>"\n'
|
||||
"}\n\n"
|
||||
"Do not include any additional text outside the JSON."
|
||||
)
|
||||
|
||||
response = await self.invoke(content, volatile=True)
|
||||
parsed_response = self.safe_parse_json(response)
|
||||
parsed_response = await self.validate_response(
|
||||
parsed_response, ["EVALUATION"], volatile=True
|
||||
)
|
||||
self.clean_volatile_memory()
|
||||
return parsed_response
|
||||
|
||||
async def reflect(self, evaluations):
|
||||
evaluations_info = json.dumps(evaluations, indent=2)
|
||||
content = (
|
||||
f"Other agents have evaluated your answer for the same task:\n"
|
||||
f"{evaluations_info}\n\n"
|
||||
"Please reflect on the evaluations and provide a refined answer for the same task.\n"
|
||||
"Respond in the following strict JSON format:\n"
|
||||
"{\n"
|
||||
' "REASON": "<Reasoning for the answer>",\n'
|
||||
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
|
||||
"}\n\n"
|
||||
"Do not include any additional text outside the JSON."
|
||||
)
|
||||
|
||||
response = await self.invoke(content)
|
||||
parsed_response = self.safe_parse_json(response)
|
||||
parsed_response = await self.validate_response(
|
||||
parsed_response, ["REASON", "ANSWER"]
|
||||
)
|
||||
return parsed_response
|
||||
24
core/vote.py
Normal file
24
core/vote.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import Any, Dict, List
|
||||
from collections import Counter
|
||||
|
||||
|
||||
def borda_vote(results: List[Dict[str, Any]], borda_p: float = 1.0):
    """Pick a winning answer by score-ranked Borda counting.

    Results whose "answer" is None are ignored. The rest are ranked by
    "score" (highest first); with ``n`` ranked results, rank ``r`` adds
    ``int((n - r + 1) ** borda_p)`` votes to its answer.

    Args:
        results: Dicts carrying an "answer" and a "score".
        borda_p: Exponent applied to the rank weight before truncation.

    Returns:
        ``(winner, tally)`` where tally maps answers to vote totals in
        descending order, or ``(None, {})`` when no result has an answer.
    """
    ballots = sorted(
        ((r["answer"], r["score"]) for r in results if r.get("answer") is not None),
        key=lambda ballot: ballot[1],
        reverse=True,
    )
    if not ballots:
        return None, {}

    total = len(ballots)
    tally: Counter = Counter()
    for position, (answer, _score) in enumerate(ballots):
        # position is 0-based, so the top ballot weighs total ** borda_p.
        tally[answer] += int((total - position) ** borda_p)

    ranking = tally.most_common()
    return ranking[0][0], dict(ranking)
|
||||
78
experiments/gen_configs.py
Normal file
78
experiments/gen_configs.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Generate per-(method, user) config files.
|
||||
|
||||
Usage:
|
||||
python experiments/gen_configs.py \
|
||||
--dataset sleepedf \
|
||||
--data_path /mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new \
|
||||
--users 00,01,02
|
||||
"""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
from fire import Fire
|
||||
|
||||
|
||||
METHODS = [
|
||||
"random_fixed_single",
|
||||
"random_dynamic_single",
|
||||
"random_fixed_sc",
|
||||
"random_dynamic_sc",
|
||||
"random_dynamic_borda",
|
||||
"random_fixed_borda",
|
||||
"ours",
|
||||
]
|
||||
|
||||
DEFAULTS = {
|
||||
"queue_size": 5,
|
||||
"num_shot": 1,
|
||||
"update_size": 3,
|
||||
"borda_p": 1.0,
|
||||
"vocab_size": 200064,
|
||||
"model_paths": [
|
||||
"ollama:url:chris.kaist.ac.kr:11437/gpt-oss:20b",
|
||||
"ollama:url:chris.kaist.ac.kr:11438/gpt-oss:20b",
|
||||
"ollama:url:chris.kaist.ac.kr:11439/gpt-oss:20b",
|
||||
"ollama:url:chris.kaist.ac.kr:11440/gpt-oss:20b",
|
||||
"ollama:url:chris.kaist.ac.kr:11441/gpt-oss:20b",
|
||||
"ollama:url:chris.kaist.ac.kr:11442/gpt-oss:20b",
|
||||
"ollama:url:chris.kaist.ac.kr:11443/gpt-oss:20b",
|
||||
"ollama:url:chris.kaist.ac.kr:11444/gpt-oss:20b",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def main(
    dataset: str,
    data_path: str,
    users,
    log_root: str = "/mnt/sting/hjyoon/projects/llm_personalization/logs",
):
    """
    Write one YAML config per (method, user) under ``config/<method>/<dataset>/``.

    Each config is a copy of ``DEFAULTS`` plus ``data_path``, ``target_user``
    and ``log_path``; the file is named ``user<ID>.yaml``.

    Args:
        dataset: Dataset name (e.g. "sleepedf").
        data_path: Absolute path to the dataset.
        users: Comma-separated user IDs (e.g. "00,01,02"). A list/tuple of IDs
            or a single non-string ID is accepted as well.
        log_root: Directory under which ``<method>/<dataset>/<user>`` log paths
            are placed.

    Raises:
        ValueError: If ``users`` contains no non-empty ID.
    """
    # Python Fire converts CLI values that look like Python literals, so
    # `--users 10,11` can arrive as a tuple of ints and `--users 10` as an int
    # rather than as a string.
    # NOTE(review): IDs converted that way have presumably already lost any
    # leading zeros before reaching this function -- confirm how IDs are passed.
    if isinstance(users, (list, tuple)):
        raw_ids = [str(u) for u in users]
    else:
        raw_ids = str(users).split(",")
    # Drop empty pieces (e.g. from a trailing comma) so no "user.yaml" is written.
    user_ids = [u.strip() for u in raw_ids if u.strip()]
    if not user_ids:
        raise ValueError("users must contain at least one user ID")

    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    for method in METHODS:
        # The target directory only depends on the method, so create it once.
        config_dir = os.path.join(project_root, "config", method, dataset)
        os.makedirs(config_dir, exist_ok=True)

        for user in user_ids:
            config = dict(DEFAULTS)
            config["data_path"] = data_path
            config["target_user"] = user
            config["log_path"] = f"{log_root}/{method}/{dataset}/{user}"

            config_path = os.path.join(config_dir, f"user{user}.yaml")
            with open(config_path, "w", encoding="utf-8") as f:
                yaml.dump(config, f, default_flow_style=False, sort_keys=False)

            print(f" {config_path}")

    print(f"\nGenerated {len(METHODS) * len(user_ids)} configs.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(main)
|
||||
75
experiments/run.sh
Executable file
75
experiments/run.sh
Executable file
@@ -0,0 +1,75 @@
|
||||
#!/bin/bash
#
# Generic experiment runner. Finds configs under config/<method>/<dataset>/
# and runs the corresponding Python script.
#
# Usage:
#   bash experiments/run.sh <method> <dataset>            # one method, all users
#   bash experiments/run.sh <method> <dataset> <config>   # one config file
#   bash experiments/run.sh all <dataset>                 # all methods
#
# Examples:
#   bash experiments/run.sh ours sleepedf
#   bash experiments/run.sh random_fixed_sc sleepedf
#   bash experiments/run.sh random_fixed_sc sleepedf user00.yaml
#   bash experiments/run.sh all sleepedf
#

set -e
# Without nullglob an unmatched "dir/*.yaml" stays a literal string and would
# be handed to Python as if it were a config path.
shopt -s nullglob

METHOD="${1:?Usage: $0 <method|all> <dataset> [config_file]}"
DATASET="${2:?Usage: $0 <method|all> <dataset> [config_file]}"
CONFIG_FILE="$3"

PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
CONFIG_ROOT="$PROJECT_ROOT/config"

# Run a single experiment.
#   $1: method name ("ours" uses run.py, anything else baselines/<method>.py)
#   $2: path to the config file
run_config() {
    local method="$1"
    local config="$2"

    echo "============================================"
    echo " Method: $method | Config: $(basename "$config")"
    echo "============================================"

    if [ "$method" = "ours" ]; then
        python "$PROJECT_ROOT/run.py" --config_path "$config"
    else
        python "$PROJECT_ROOT/baselines/${method}.py" --config_path "$config"
    fi
    echo ""
}

if [ "$METHOD" != "all" ] && [ -n "$CONFIG_FILE" ]; then
    # Single config
    config="$CONFIG_ROOT/$METHOD/$DATASET/$CONFIG_FILE"
    if [ ! -f "$config" ]; then
        echo "Error: config not found: $config"
        exit 1
    fi
    run_config "$METHOD" "$config"

elif [ "$METHOD" != "all" ]; then
    # All configs for one method
    config_dir="$CONFIG_ROOT/$METHOD/$DATASET"
    if [ ! -d "$config_dir" ]; then
        echo "Error: config directory not found: $config_dir"
        exit 1
    fi
    configs=("$config_dir"/*.yaml)
    if [ "${#configs[@]}" -eq 0 ]; then
        echo "Error: no .yaml configs found in: $config_dir"
        exit 1
    fi
    for config in "${configs[@]}"; do
        run_config "$METHOD" "$config"
    done

else
    # All methods
    for method_dir in "$CONFIG_ROOT"/*/; do
        method="$(basename "$method_dir")"
        dataset_dir="$method_dir$DATASET"
        [ -d "$dataset_dir" ] || continue
        configs=("$dataset_dir"/*.yaml)
        if [ "${#configs[@]}" -eq 0 ]; then
            echo "Warning: no .yaml configs found in: $dataset_dir (skipping)"
            continue
        fi
        for config in "${configs[@]}"; do
            run_config "$method" "$config"
        done
    done
fi

echo "All experiments complete."
|
||||
324
preprocess/preprocess_GLOBEM.py
Normal file
324
preprocess/preprocess_GLOBEM.py
Normal file
@@ -0,0 +1,324 @@
|
||||
import os
|
||||
import json
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from tqdm import tqdm
|
||||
from fire import Fire
|
||||
from datasets import Dataset
|
||||
from datetime import timedelta
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
GLOBEM_PATH = "/mnt/sting/hjyoon/projects/llm_personalization/dataset/GLOBEM/physionet.org/files/globem/1.1"
|
||||
OUT_DIR = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/GLOBEM"
|
||||
|
||||
PHASES = {
|
||||
"INS-W": [1, 2, 3, 4],
|
||||
}
|
||||
|
||||
FEATURE_TYPES = ["f_loc", "f_screen", "f_slp", "f_steps"]
|
||||
|
||||
FEATURE_COLUMNS = {
|
||||
"f_loc": [
|
||||
"phone_locations_barnett_avgflightdur", "phone_locations_barnett_avgflightlen",
|
||||
"phone_locations_barnett_circdnrtn", "phone_locations_barnett_disttravelled",
|
||||
"phone_locations_barnett_hometime", "phone_locations_barnett_maxdiam",
|
||||
"phone_locations_barnett_maxhomedist", "phone_locations_barnett_probpause",
|
||||
"phone_locations_barnett_rog", "phone_locations_barnett_siglocentropy",
|
||||
"phone_locations_barnett_siglocsvisited", "phone_locations_barnett_stdflightdur",
|
||||
"phone_locations_barnett_stdflightlen", "phone_locations_barnett_wkenddayrtn",
|
||||
"phone_locations_doryab_avglengthstayatclusters", "phone_locations_doryab_avgspeed",
|
||||
"phone_locations_doryab_homelabel", "phone_locations_doryab_locationentropy",
|
||||
"phone_locations_doryab_locationvariance", "phone_locations_doryab_loglocationvariance",
|
||||
"phone_locations_doryab_maxlengthstayatclusters", "phone_locations_doryab_minlengthstayatclusters",
|
||||
"phone_locations_doryab_movingtostaticratio", "phone_locations_doryab_normalizedlocationentropy",
|
||||
"phone_locations_doryab_numberlocationtransitions", "phone_locations_doryab_numberofsignificantplaces",
|
||||
"phone_locations_doryab_outlierstimepercent", "phone_locations_doryab_radiusgyration",
|
||||
"phone_locations_doryab_stdlengthstayatclusters", "phone_locations_doryab_timeathome",
|
||||
"phone_locations_doryab_timeattop1location", "phone_locations_doryab_timeattop2location",
|
||||
"phone_locations_doryab_timeattop3location", "phone_locations_doryab_totaldistance",
|
||||
"phone_locations_doryab_varspeed",
|
||||
"phone_locations_locmap_duration_in_locmap_study", "phone_locations_locmap_percent_in_locmap_study",
|
||||
"phone_locations_locmap_duration_in_locmap_exercise", "phone_locations_locmap_percent_in_locmap_exercise",
|
||||
"phone_locations_locmap_duration_in_locmap_greens", "phone_locations_locmap_percent_in_locmap_greens",
|
||||
],
|
||||
"f_screen": [
|
||||
"phone_screen_rapids_countepisodeunlock", "phone_screen_rapids_sumdurationunlock",
|
||||
"phone_screen_rapids_maxdurationunlock", "phone_screen_rapids_mindurationunlock",
|
||||
"phone_screen_rapids_avgdurationunlock", "phone_screen_rapids_stddurationunlock",
|
||||
"phone_screen_rapids_firstuseafter00unlock",
|
||||
"phone_screen_rapids_countepisodeunlock_locmap_exercise", "phone_screen_rapids_sumdurationunlock_locmap_exercise",
|
||||
"phone_screen_rapids_maxdurationunlock_locmap_exercise", "phone_screen_rapids_mindurationunlock_locmap_exercise",
|
||||
"phone_screen_rapids_avgdurationunlock_locmap_exercise", "phone_screen_rapids_stddurationunlock_locmap_exercise",
|
||||
"phone_screen_rapids_firstuseafter00unlock_locmap_exercise",
|
||||
"phone_screen_rapids_countepisodeunlock_locmap_greens", "phone_screen_rapids_sumdurationunlock_locmap_greens",
|
||||
"phone_screen_rapids_maxdurationunlock_locmap_greens", "phone_screen_rapids_mindurationunlock_locmap_greens",
|
||||
"phone_screen_rapids_avgdurationunlock_locmap_greens", "phone_screen_rapids_stddurationunlock_locmap_greens",
|
||||
"phone_screen_rapids_firstuseafter00unlock_locmap_greens",
|
||||
"phone_screen_rapids_countepisodeunlock_locmap_living", "phone_screen_rapids_sumdurationunlock_locmap_living",
|
||||
"phone_screen_rapids_maxdurationunlock_locmap_living", "phone_screen_rapids_mindurationunlock_locmap_living",
|
||||
"phone_screen_rapids_avgdurationunlock_locmap_living", "phone_screen_rapids_stddurationunlock_locmap_living",
|
||||
"phone_screen_rapids_firstuseafter00unlock_locmap_living",
|
||||
"phone_screen_rapids_countepisodeunlock_locmap_study", "phone_screen_rapids_sumdurationunlock_locmap_study",
|
||||
"phone_screen_rapids_maxdurationunlock_locmap_study", "phone_screen_rapids_mindurationunlock_locmap_study",
|
||||
"phone_screen_rapids_avgdurationunlock_locmap_study", "phone_screen_rapids_stddurationunlock_locmap_study",
|
||||
"phone_screen_rapids_firstuseafter00unlock_locmap_study",
|
||||
"phone_screen_rapids_countepisodeunlock_locmap_home", "phone_screen_rapids_sumdurationunlock_locmap_home",
|
||||
"phone_screen_rapids_maxdurationunlock_locmap_home", "phone_screen_rapids_mindurationunlock_locmap_home",
|
||||
"phone_screen_rapids_avgdurationunlock_locmap_home", "phone_screen_rapids_stddurationunlock_locmap_home",
|
||||
"phone_screen_rapids_firstuseafter00unlock_locmap_home",
|
||||
],
|
||||
"f_slp": [
|
||||
"fitbit_sleep_summary_rapids_sumdurationafterwakeupmain", "fitbit_sleep_summary_rapids_sumdurationasleepmain",
|
||||
"fitbit_sleep_summary_rapids_sumdurationawakemain", "fitbit_sleep_summary_rapids_sumdurationtofallasleepmain",
|
||||
"fitbit_sleep_summary_rapids_sumdurationinbedmain", "fitbit_sleep_summary_rapids_avgefficiencymain",
|
||||
"fitbit_sleep_summary_rapids_avgdurationafterwakeupmain", "fitbit_sleep_summary_rapids_avgdurationasleepmain",
|
||||
"fitbit_sleep_summary_rapids_avgdurationawakemain", "fitbit_sleep_summary_rapids_avgdurationtofallasleepmain",
|
||||
"fitbit_sleep_summary_rapids_avgdurationinbedmain", "fitbit_sleep_summary_rapids_countepisodemain",
|
||||
"fitbit_sleep_summary_rapids_firstbedtimemain", "fitbit_sleep_summary_rapids_lastbedtimemain",
|
||||
"fitbit_sleep_summary_rapids_firstwaketimemain", "fitbit_sleep_summary_rapids_lastwaketimemain",
|
||||
"fitbit_sleep_intraday_rapids_avgdurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_avgdurationawakeunifiedmain",
|
||||
"fitbit_sleep_intraday_rapids_maxdurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_maxdurationawakeunifiedmain",
|
||||
"fitbit_sleep_intraday_rapids_sumdurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_sumdurationawakeunifiedmain",
|
||||
"fitbit_sleep_intraday_rapids_countepisodeasleepunifiedmain", "fitbit_sleep_intraday_rapids_countepisodeawakeunifiedmain",
|
||||
"fitbit_sleep_intraday_rapids_stddurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_stddurationawakeunifiedmain",
|
||||
"fitbit_sleep_intraday_rapids_mindurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_mindurationawakeunifiedmain",
|
||||
"fitbit_sleep_intraday_rapids_mediandurationasleepunifiedmain", "fitbit_sleep_intraday_rapids_mediandurationawakeunifiedmain",
|
||||
"fitbit_sleep_intraday_rapids_ratiocountasleepunifiedwithinmain", "fitbit_sleep_intraday_rapids_ratiocountawakeunifiedwithinmain",
|
||||
"fitbit_sleep_intraday_rapids_ratiodurationasleepunifiedwithinmain", "fitbit_sleep_intraday_rapids_ratiodurationawakeunifiedwithinmain",
|
||||
],
|
||||
"f_steps": [
|
||||
"fitbit_steps_summary_rapids_maxsumsteps", "fitbit_steps_summary_rapids_minsumsteps",
|
||||
"fitbit_steps_summary_rapids_avgsumsteps", "fitbit_steps_summary_rapids_mediansumsteps",
|
||||
"fitbit_steps_summary_rapids_stdsumsteps",
|
||||
"fitbit_steps_intraday_rapids_sumsteps", "fitbit_steps_intraday_rapids_maxsteps",
|
||||
"fitbit_steps_intraday_rapids_minsteps", "fitbit_steps_intraday_rapids_avgsteps",
|
||||
"fitbit_steps_intraday_rapids_stdsteps",
|
||||
"fitbit_steps_intraday_rapids_countepisodesedentarybout", "fitbit_steps_intraday_rapids_sumdurationsedentarybout",
|
||||
"fitbit_steps_intraday_rapids_maxdurationsedentarybout", "fitbit_steps_intraday_rapids_mindurationsedentarybout",
|
||||
"fitbit_steps_intraday_rapids_avgdurationsedentarybout", "fitbit_steps_intraday_rapids_stddurationsedentarybout",
|
||||
"fitbit_steps_intraday_rapids_countepisodeactivebout", "fitbit_steps_intraday_rapids_sumdurationactivebout",
|
||||
"fitbit_steps_intraday_rapids_maxdurationactivebout", "fitbit_steps_intraday_rapids_mindurationactivebout",
|
||||
"fitbit_steps_intraday_rapids_avgdurationactivebout", "fitbit_steps_intraday_rapids_stddurationactivebout",
|
||||
],
|
||||
}
|
||||
|
||||
TIME_SEGMENT = "allday"
|
||||
WINDOW_DAYS = 28
|
||||
LABEL_COL = "dep"
|
||||
PREDICTION_TARGET = "dep_weekly"
|
||||
CLASS_DICT = {True: "depressed", False: "not_depressed"}
|
||||
|
||||
PRE_SURVEY_COLS = [
|
||||
"UCLA_10items_PRE", "SocialFit_PRE",
|
||||
"2waySSS_receiving_emotional_PRE", "2waySSS_giving_emotional_PRE",
|
||||
"2waySSS_giving_instrumental_PRE", "2waySSS_receiving_instrumental_PRE",
|
||||
"ERQ_reappraisal_PRE", "ERQ_suppression_PRE",
|
||||
"BRS_PRE", "CHIPS_PRE", "PSS_10items_PRE", "STAIS_PRE", "MAAS_7items_PRE",
|
||||
"CESD_9items_PRE", "CESD_10items_PRE",
|
||||
"BFI10_extroversion_PRE", "BFI10_agreeableness_PRE",
|
||||
"BFI10_conscientiousness_PRE", "BFI10_neuroticism_PRE", "BFI10_openness_PRE",
|
||||
]
|
||||
|
||||
|
||||
def get_feature_col_names():
    """Return every configured feature as ``f_type:feature_name:time_segment``.

    Names are produced in ``FEATURE_TYPES`` order and, within each type, in the
    order listed in ``FEATURE_COLUMNS``; all use ``TIME_SEGMENT``.
    """
    return [
        f"{sensor}:{name}:{TIME_SEGMENT}"
        for sensor in FEATURE_TYPES
        for name in FEATURE_COLUMNS[sensor]
    ]
|
||||
|
||||
|
||||
def store_task_metadata(path):
    """Write the GLOBEM task description to ``path`` as indented JSON.

    The JSON object has four keys: ``task`` (the classification prompt),
    ``class`` (one description per label), ``data`` (how the data were
    collected and windowed) and ``feature`` (what each feature family holds).
    """
    task_text = (
        'Classify the user\'s depression status: ["depressed", "not_depressed"], '
        "based on passive sensing data collected from a smartphone and a wearable fitness tracker."
    )

    class_text = {
        "depressed": "The user shows depressive symptoms based on self-reported weekly survey responses.",
        "not_depressed": "The user does not show depressive symptoms based on self-reported weekly survey responses.",
    }

    # Each list item is one sentence; joining with a single space yields the
    # final paragraph.
    data_sentences = [
        "Data were collected over a three-month study period from college students at a university.",
        "Participants carried a smartphone and wore a Fitbit fitness tracker 24x7.",
        "Passive sensing data includes GPS location, phone screen usage, Fitbit sleep, and Fitbit physical activity.",
        "Features were extracted using the RAPIDS toolkit and computed daily over multiple time segments.",
        "Each sample represents the last day of a 28-day observation window preceding a depression label date.",
        "Each feature is named using the format 'sensor_type:feature_name:time_segment'.",
    ]

    feature_sentences = [
        "Location features (f_loc) include GPS-based metrics such as home time, distance travelled, "
        "radius of gyration, location entropy, number of significant places, and time spent at various locations.",
        "Phone usage features (f_screen) include unlock episode counts, durations, and location-specific phone usage patterns.",
        "Sleep features (f_slp) include Fitbit-derived metrics such as sleep duration, efficiency, "
        "time to fall asleep, bedtime/waketime, and intraday sleep/wake episode statistics.",
        "Physical activity features (f_steps) include step counts, sedentary bout statistics, and active bout statistics.",
        "All features use the 'allday' time segment (24 hours from midnight to midnight).",
    ]

    task_metadata = {
        "task": task_text,
        "class": class_text,
        "data": " ".join(data_sentences),
        "feature": " ".join(feature_sentences),
    }

    with open(path, "w", encoding="utf-8") as out_file:
        json.dump(task_metadata, out_file, indent=2)
|
||||
|
||||
|
||||
def store_user_metadata(globem_path, out_path):
    """Collect per-user metadata for every phase in ``PHASES`` and save it as JSON.

    For each ``<institution>_<phase>`` folder under ``globem_path`` this reads
    ``ParticipantsInfoData/platform.csv`` (required; the phase is skipped when
    it is absent) and ``SurveyData/pre.csv`` (optional). Every participant in
    platform.csv gets an entry keyed ``"<pid>#<institution>_<phase>"`` holding
    the first ``;``-separated platform value, the phase, and any
    ``PRE_SURVEY_COLS`` present in pre.csv (rounded to 4 decimals, ``None``
    when missing).

    Args:
        globem_path: Root directory containing the per-phase dataset folders.
        out_path: Destination JSON file.
    """

    def _read_by_pid(csv_path):
        # Drop the "Unnamed: 0" column when present, then index by participant id.
        frame = pd.read_csv(csv_path)
        if "Unnamed: 0" in frame.columns:
            frame = frame.drop(columns=["Unnamed: 0"])
        return frame.set_index("pid")

    user_metadata = {}

    for institution, phases in PHASES.items():
        for phase in phases:
            tag = f"{institution}_{phase}"
            ds_dir = os.path.join(globem_path, tag)
            platform_path = os.path.join(ds_dir, "ParticipantsInfoData", "platform.csv")
            pre_path = os.path.join(ds_dir, "SurveyData", "pre.csv")

            if not os.path.exists(platform_path):
                print(f" Skipping {tag}: platform.csv not found")
                continue

            df_platform = _read_by_pid(platform_path)
            df_pre = _read_by_pid(pre_path) if os.path.exists(pre_path) else None

            for pid in df_platform.index:
                raw_platform = str(df_platform.loc[pid, "platform"])
                meta = {"platform": raw_platform.split(";")[0], "phase": phase}

                has_survey = df_pre is not None and pid in df_pre.index
                if has_survey:
                    survey = df_pre.loc[pid]
                    for col in PRE_SURVEY_COLS:
                        if col not in survey.index:
                            continue
                        answer = survey[col]
                        if pd.notna(answer):
                            meta[col] = round(float(answer), 4)
                        else:
                            meta[col] = None

                user_metadata[f"{pid}#{tag}"] = meta

    with open(out_path, "w", encoding="utf-8") as out_file:
        json.dump(user_metadata, out_file, indent=2)
|
||||
|
||||
|
||||
def process_single_dataset(globem_path, institution, phase, feature_cols):
    """Process one dataset (institution + phase) and return per-user sample lists.

    One sample is built per labelled survey date: the features are taken from
    the most recent day that has a row in ``rapids.csv`` within the
    ``WINDOW_DAYS``-day window ending on the label date. Participants with
    fewer than two labelled dates are skipped, as are label dates with no
    feature rows in their window.

    Args:
        globem_path: Root directory containing the per-phase dataset folders.
        institution: Institution prefix of the dataset folder (e.g. "INS-W").
        phase: Phase number; the folder is ``<institution>_<phase>``.
        feature_cols: Full feature column names to extract; columns absent
            from rapids.csv are ignored.

    Returns:
        Dict mapping ``"<pid>#<institution>_<phase>"`` to a list of sample
        dicts (``user_id``, ``session_id``, ``idx``, ``date``, ``label``,
        ``features``). Empty when a required file or every feature column is
        missing.
    """
    ds_dir = os.path.join(globem_path, f"{institution}_{phase}")
    rapids_path = os.path.join(ds_dir, "FeatureData", "rapids.csv")
    label_path = os.path.join(ds_dir, "SurveyData", "dep_weekly.csv")
    platform_path = os.path.join(ds_dir, "ParticipantsInfoData", "platform.csv")

    # platform.csv is not read here; it is only required to exist.
    for p in [rapids_path, label_path, platform_path]:
        if not os.path.exists(p):
            print(f" Skipping {institution}_{phase}: {os.path.basename(p)} not found")
            return {}

    print(f" Loading {institution}_{phase}...")
    df_rapids = pd.read_csv(rapids_path, low_memory=False)
    if "Unnamed: 0" in df_rapids.columns:
        df_rapids = df_rapids.drop(columns=["Unnamed: 0"])
    df_rapids["date"] = pd.to_datetime(df_rapids["date"])

    df_label = pd.read_csv(label_path)
    if "Unnamed: 0" in df_label.columns:
        df_label = df_label.drop(columns=["Unnamed: 0"])
    df_label["date"] = pd.to_datetime(df_label["date"])
    df_label = df_label.dropna(subset=[LABEL_COL])
    df_label = df_label.drop_duplicates(["pid", "date"], keep="last")

    available_cols = [c for c in feature_cols if c in df_rapids.columns]
    if not available_cols:
        print(f" Skipping {institution}_{phase}: no matching feature columns")
        return {}

    # Only participants with at least two labelled dates are kept.
    responses_per_pid = df_label.groupby("pid")["date"].count()
    valid_pids = set(responses_per_pid[responses_per_pid >= 2].index)

    # Split the feature table by participant once, instead of re-scanning the
    # whole table with a boolean mask for every label row. Rows keep their
    # original relative order inside each group.
    rapids_by_pid = {
        pid: frame for pid, frame in df_rapids.groupby("pid", sort=False)
    }

    user_data = {}
    phase_str = str(phase)

    for _, row in tqdm(df_label.iterrows(), total=len(df_label),
                       desc=f" {institution}_{phase}", leave=False):
        pid = row["pid"]
        if pid not in valid_pids:
            continue

        df_user = rapids_by_pid.get(pid)
        if df_user is None:
            # No feature rows at all for this participant.
            continue

        date_end = row["date"]
        # Inclusive window of WINDOW_DAYS days ending on the label date.
        date_start = date_end - timedelta(days=WINDOW_DAYS - 1)
        label_val = row[LABEL_COL]

        in_window = (df_user["date"] >= date_start) & (df_user["date"] <= date_end)
        df_window = df_user[in_window]
        if df_window.empty:
            continue

        last_day = df_window.sort_values("date").iloc[-1]

        # Missing values are stored as None rather than NaN.
        features = {
            col: float(last_day[col]) if pd.notna(last_day[col]) else None
            for col in available_cols
        }

        user_key = f"{pid}#{institution}_{phase}"
        samples = user_data.setdefault(user_key, [])
        samples.append(dict(
            user_id=user_key,
            session_id=phase_str,
            idx=len(samples),
            date=str(date_end.date()),
            label=CLASS_DICT[label_val],
            features=features,
        ))

    return user_data
|
||||
|
||||
|
||||
def run(path=GLOBEM_PATH, out_dir=OUT_DIR):
    """Preprocess GLOBEM and save one HuggingFace dataset per user in ``out_dir``.

    Writes ``task_metadata.json`` and ``user_metadata.json`` first, then
    processes every (institution, phase) in ``PHASES`` and saves each user's
    samples with ``Dataset.save_to_disk`` into a folder named after the user
    key with ``#`` replaced by ``_``.

    Args:
        path: Root directory of the raw GLOBEM release.
        out_dir: Output directory; created when it does not exist.
    """
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    feature_cols = get_feature_col_names()
    print(f"Using {len(feature_cols)} feature columns ({TIME_SEGMENT} time segment)")

    print("Saving task metadata...")
    task_meta_path = os.path.join(out_dir, "task_metadata.json")
    store_task_metadata(task_meta_path)

    print("Saving user metadata...")
    user_meta_path = os.path.join(out_dir, "user_metadata.json")
    store_user_metadata(path, user_meta_path)

    # Merge the per-phase results into a single {user_key: samples} mapping.
    all_user_data = {}
    for institution, phases in PHASES.items():
        for phase in phases:
            per_user = process_single_dataset(path, institution, phase, feature_cols)
            for key, samples in per_user.items():
                if key not in all_user_data:
                    all_user_data[key] = []
                all_user_data[key].extend(samples)

    n_users = len(all_user_data)
    print(f"\nSaving datasets for {n_users} users...")

    total_samples = 0
    for user_key in sorted(all_user_data):
        samples = all_user_data[user_key]
        # "#" is swapped for "_" in the on-disk folder name.
        user_dir = os.path.join(out_dir, user_key.replace("#", "_"))
        Dataset.from_list(samples).save_to_disk(user_dir)
        total_samples += len(samples)

    print(f"\nDone. {total_samples} total samples from {n_users} users saved to {out_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(run)
|
||||
@@ -4,6 +4,7 @@ import json
|
||||
import warnings
|
||||
import numpy as np
|
||||
import neurokit2 as nk
|
||||
import pandas as pd
|
||||
|
||||
from tqdm import tqdm
|
||||
from fire import Fire
|
||||
@@ -21,7 +22,7 @@ warnings.simplefilter("ignore", NeuroKitWarning)
|
||||
warnings.filterwarnings("ignore", message="All-NaN slice encountered")
|
||||
|
||||
SLEEPEDF_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/sleep-cassette/"
|
||||
OUT_DIR = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
||||
OUT_DIR = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_new"
|
||||
EPOCH_SEC_SIZE = 30
|
||||
SAMPLING_RATE = 100
|
||||
|
||||
@@ -45,8 +46,8 @@ ann2label = {
|
||||
}
|
||||
|
||||
|
||||
def store_info(info_path):
|
||||
info = {
|
||||
def store_task_metadata(path):
|
||||
task_metadata = {
|
||||
"task": 'Classify the user\'s sleep stage: ["W", "N1", "N2", "N3", "REM"], based on physiological signals collected from wearable sensors.',
|
||||
"class": {
|
||||
"W": "Wakefulness. This includes periods before sleep onset or after final awakening, and short awakenings during the night.",
|
||||
@@ -72,8 +73,21 @@ def store_info(info_path):
|
||||
"Ratio features such as delta/theta, theta/alpha, alpha/beta, and (delta+theta)/(alpha+beta) were also included."
|
||||
),
|
||||
}
|
||||
with open(info_path, "w", encoding="utf-8") as f:
|
||||
json.dump(info, f, indent=2)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(task_metadata, f, indent=2)
|
||||
|
||||
|
||||
def store_user_metadata(src_path, out_path):
    """Read the subject spreadsheet and save per-user age/sex metadata as JSON.

    Args:
        src_path: Spreadsheet with ``subject``, ``age`` and ``sex (F=1)`` columns.
        out_path: Destination JSON file mapping ``str(subject)`` to
            ``{"age": int, "sex": int}``.
    """
    subjects = pd.read_excel(src_path)

    # NOTE(review): keys are str(subject) exactly as read from the sheet;
    # confirm they match the user IDs used for the per-user dataset folders
    # (e.g. zero-padded "00").
    user_metadata = {}
    for _, record in subjects.iterrows():
        user_id = str(record["subject"])
        age = int(record["age"])
        # Shift the sheet's coding down by one -> 0: female, 1: male.
        sex = int(record["sex (F=1)"]) - 1
        user_metadata[user_id] = {"age": age, "sex": sex}

    with open(out_path, "w", encoding="utf-8") as out_file:
        json.dump(user_metadata, out_file, indent=2)
|
||||
|
||||
|
||||
def lowpass_filter(data, cutoff=50, fs=1000, order=4):
|
||||
@@ -197,122 +211,7 @@ def process_by_mod(modality, data):
|
||||
features[f"{modality}_theta/alpha_ratio"] = theta_alpha_ratio
|
||||
features[f"{modality}_alpha/beta_ratio"] = alpha_beta_ratio
|
||||
features[f"{modality}_(delta+theta)/(alpha+beta)_ratio"] = slow_fast_ratio
|
||||
# elif "EOG" in modality:
|
||||
# eog_cleaned = data
|
||||
# try:
|
||||
# eog_cleaned = nk.eog_clean(data, sr)
|
||||
# except IndexError as e:
|
||||
# print(f"Error processing EOG data for {modality}: {e}")
|
||||
# return None
|
||||
# eog_mean = np.mean(eog_cleaned)
|
||||
# eog_std = np.std(eog_cleaned)
|
||||
# eog_var = np.var(eog_cleaned)
|
||||
# features[f"{modality}_mean"] = eog_mean
|
||||
# features[f"{modality}_std"] = eog_std
|
||||
# features[f"{modality}_variance"] = eog_var
|
||||
# dynamic_range = np.max(eog_cleaned) - np.min(eog_cleaned)
|
||||
# features[f"{modality}_dynamic_range"] = dynamic_range
|
||||
# peaks = signal.find_peaks(eog_cleaned - eog_mean, height=3 * eog_std)[0]
|
||||
# features[f"{modality}_num_peaks"] = len(peaks)
|
||||
# zero_crossings = np.where(np.diff(np.sign(eog_cleaned - eog_mean)))[0]
|
||||
# features[f"{modality}_num_zero_crossings"] = len(zero_crossings)
|
||||
# differences = eog_cleaned[1:] - eog_cleaned[:-1]
|
||||
# difference_variance = np.var(differences)
|
||||
# features[f"{modality}_difference_variance"] = difference_variance
|
||||
# features[f"{modality}_num_large_eye_movements"] = count_large_eye_movements(
|
||||
# eog_cleaned, sr, amp_thresh=120, time_thresh=1.5
|
||||
# )
|
||||
# eog_large_movement_removed = remove_large_eye_movements(
|
||||
# eog_cleaned, fs=sr, amp_thresh=120, time_thresh=1.5, pad=0.75
|
||||
# )
|
||||
# differences = eog_large_movement_removed[1:] - eog_large_movement_removed[:-1]
|
||||
# difference_variance = np.var(differences)
|
||||
# features[f"{modality}_difference_variance_without_large_movements"] = (
|
||||
# difference_variance
|
||||
# )
|
||||
# freqs, psd = signal.welch(eog_cleaned, fs=sr, nperseg=sr * 2)
|
||||
# total_idx = np.logical_and(freqs >= 0.5, freqs <= 30)
|
||||
# total_power = np.trapezoid(psd[total_idx], freqs[total_idx])
|
||||
# slow_idx = np.logical_and(freqs >= 0.5, freqs <= 2)
|
||||
# rapid_idx = np.logical_and(freqs >= 2, freqs <= 5)
|
||||
# slow_power = np.trapezoid(psd[slow_idx], freqs[slow_idx])
|
||||
# rapid_power = np.trapezoid(psd[rapid_idx], freqs[rapid_idx])
|
||||
# slow_power_ratio = slow_power / total_power if total_power > 0 else 0
|
||||
# rapid_power_ratio = rapid_power / total_power if total_power > 0 else 0
|
||||
# features[f"{modality}_slow_movement_power_ratio"] = slow_power_ratio
|
||||
# features[f"{modality}_rapid_movement_power_ratio"] = rapid_power_ratio
|
||||
# elif "Resp" in modality:
|
||||
# rsp_signals = data
|
||||
# try:
|
||||
# rsp_signals, _ = nk.rsp_process(data, sampling_rate=sr, method="biosppy")
|
||||
# except IndexError as e:
|
||||
# print(f"Error processing respiration data for {modality}: {e}")
|
||||
# return None
|
||||
# clean = rsp_signals["RSP_Clean"]
|
||||
# phase = rsp_signals["RSP_Phase"]
|
||||
# rate = rsp_signals["RSP_Rate"]
|
||||
# amplitude = rsp_signals["RSP_Amplitude"]
|
||||
# peaks = np.where(rsp_signals["RSP_Peaks"] == 1)[0]
|
||||
# troughs = np.where(rsp_signals["RSP_Troughs"] == 1)[0]
|
||||
# inhale_durations = []
|
||||
# for t in troughs:
|
||||
# next_peaks = peaks[peaks > t]
|
||||
# if len(next_peaks) == 0:
|
||||
# continue
|
||||
# inhale_durations.append((next_peaks[0] - t) / sr)
|
||||
# inhale_durations = np.array(inhale_durations)
|
||||
# exhale_durations = []
|
||||
# for p in peaks:
|
||||
# next_troughs = troughs[troughs > p]
|
||||
# if len(next_troughs) == 0:
|
||||
# continue
|
||||
# exhale_durations.append((next_troughs[0] - p) / sr)
|
||||
# exhale_durations = np.array(exhale_durations)
|
||||
# features[f"{modality}_inhale_duration_mean"] = np.mean(inhale_durations)
|
||||
# features[f"{modality}_inhale_duration_std"] = np.std(inhale_durations)
|
||||
# features[f"{modality}_exhale_duration_mean"] = np.mean(exhale_durations)
|
||||
# features[f"{modality}_exhale_duration_std"] = np.std(exhale_durations)
|
||||
# features[f"{modality}_inhale_exhale_ratio"] = (
|
||||
# np.mean(inhale_durations) / np.mean(exhale_durations)
|
||||
# if np.mean(exhale_durations) > 0
|
||||
# else np.nan
|
||||
# )
|
||||
# features[f"{modality}_stretch"] = np.max(clean) - np.min(clean)
|
||||
# inhale_mask = phase == 1
|
||||
# features[f"{modality}_inspiration_volume"] = np.trapezoid(
|
||||
# amplitude[inhale_mask], dx=1 / sr
|
||||
# )
|
||||
# features[f"{modality}_respiration_rate"] = np.mean(rate)
|
||||
# resp_durations = np.diff(troughs) / sr
|
||||
# features[f"{modality}_respiration_duration"] = np.mean(resp_durations)
|
||||
# elif "EMG" in modality:
|
||||
# emg_mean = np.mean(data)
|
||||
# emg_std = np.std(data)
|
||||
# features[f"{modality}_mean"] = emg_mean
|
||||
# features[f"{modality}_std"] = emg_std
|
||||
# features[f"{modality}_dynamic_range"] = np.max(data) - np.min(data)
|
||||
# features[f"{modality}_absolute_integral"] = np.sum(np.abs(data)) / sr
|
||||
# features[f"{modality}_median"] = np.median(data)
|
||||
# features[f"{modality}_10th_percentile"] = np.percentile(data, 10)
|
||||
# features[f"{modality}_90th_percentile"] = np.percentile(data, 90)
|
||||
|
||||
# peaks, _ = signal.find_peaks(data, height=3 * emg_std)
|
||||
# peak_values = data[peaks]
|
||||
# features[f"{modality}_num_peaks"] = len(peaks)
|
||||
# features[f"{modality}_peak_amplitude_mean"] = (
|
||||
# np.mean(peak_values) if len(peak_values) > 0 else 0
|
||||
# )
|
||||
# features[f"{modality}_peak_amplitude_std"] = (
|
||||
# np.std(peak_values) if len(peak_values) > 0 else 0
|
||||
# )
|
||||
# features[f"{modality}_peak_amplitude_sum"] = (
|
||||
# np.sum(peak_values) if len(peak_values) > 0 else 0
|
||||
# )
|
||||
# features[f"{modality}_peak_amplitude_norm_sum"] = (
|
||||
# np.sum(peak_values) / np.sum(np.abs(data))
|
||||
# if np.sum(np.abs(data)) > 0
|
||||
# else 0
|
||||
# )
|
||||
return features
|
||||
|
||||
|
||||
@@ -450,9 +349,17 @@ def run(path=SLEEPEDF_PATH, out_dir=OUT_DIR, num_examples=1, num_workers=32, see
|
||||
np.random.seed(seed)
|
||||
if not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir)
|
||||
info_path = os.path.join(out_dir, "info.json")
|
||||
store_info(info_path)
|
||||
print(f"Saved info to {info_path}")
|
||||
|
||||
task_metadata_out_path = os.path.join(out_dir, "task_metadata.json")
|
||||
store_task_metadata(task_metadata_out_path)
|
||||
print(f"Saved info to {task_metadata_out_path}")
|
||||
|
||||
user_metadata_out_path = os.path.join(out_dir, "user_metadata.json")
|
||||
user_metadata_src_path = os.path.join(path, "..", "SC-subjects.xls")
|
||||
store_user_metadata(user_metadata_src_path, user_metadata_out_path)
|
||||
print(f"Saved info to {user_metadata_out_path}")
|
||||
|
||||
|
||||
psg_file_paths = glob(os.path.join(path, "*PSG.edf"))
|
||||
ann_file_paths = glob(os.path.join(path, "*Hypnogram.edf"))
|
||||
psg_file_paths.sort()
|
||||
@@ -466,16 +373,20 @@ def run(path=SLEEPEDF_PATH, out_dir=OUT_DIR, num_examples=1, num_workers=32, see
|
||||
elif basename.startswith("SC41"):
|
||||
filtered_2013_file_paths.append(file_path)
|
||||
|
||||
user_data = {}
|
||||
with Pool(processes=num_workers) as pool:
|
||||
for data in pool.imap_unordered(preprocess, filtered_2013_file_paths):
|
||||
if len(data) == 0:
|
||||
continue
|
||||
user_id = data[0]["user_id"]
|
||||
session_id = data[0]["session_id"]
|
||||
dataset = Dataset.from_list(data)
|
||||
test_dir = os.path.join(out_dir, f"{user_id}", f"{session_id}")
|
||||
dataset.save_to_disk(test_dir)
|
||||
print(f"Saved dataset to {test_dir}")
|
||||
user_data.setdefault(user_id, []).extend(data)
|
||||
|
||||
for user_id, data in user_data.items():
|
||||
dataset = Dataset.from_list(data)
|
||||
test_dir = os.path.join(out_dir, f"{user_id}")
|
||||
dataset.save_to_disk(test_dir)
|
||||
print(f"Saved dataset to {test_dir} ({len(data)} samples, "
|
||||
f"{len(set(d['session_id'] for d in data))} session(s))")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
370
run.py
370
run.py
@@ -1,294 +1,114 @@
|
||||
import os
|
||||
import re
|
||||
import asyncio
|
||||
import yaml
|
||||
import json
|
||||
import numpy as np
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Dict, Any, Optional
|
||||
import time
|
||||
|
||||
from glob import glob
|
||||
import yaml
|
||||
from fire import Fire
|
||||
|
||||
from core.model import load_models
|
||||
from core.data_loader import DataLoader
|
||||
from core.sensing_agent import SensingAgent
|
||||
from core.embedding_index import EmbeddingIndex, create_embedding_index
|
||||
|
||||
_EMBEDDING_INDEX: Optional[EmbeddingIndex] = None
|
||||
|
||||
def init_embedding_index(embedding_path: Optional[str]) -> Optional[EmbeddingIndex]:
    """Return the module-wide embedding index, creating it on first use.

    Args:
        embedding_path: Location handed to ``create_embedding_index``. Passing
            ``None`` clears the cached index.

    Returns:
        ``None`` when ``embedding_path`` is ``None``; otherwise the cached
        index, which is built with ``create_embedding_index(embedding_path)``
        only if nothing is cached yet.

    Note:
        Once an index is cached it is returned as-is, even when a different
        ``embedding_path`` is passed on a later call.
    """
    global _EMBEDDING_INDEX

    if embedding_path is not None:
        # Build lazily; later calls reuse whatever was built first.
        if _EMBEDDING_INDEX is None:
            _EMBEDDING_INDEX = create_embedding_index(embedding_path)
        return _EMBEDDING_INDEX

    # No path requested: drop any previously cached index.
    _EMBEDDING_INDEX = None
    return None
|
||||
from core.example_queue import ExampleQueue
|
||||
from core.model import load_models
|
||||
from core.recruiter import Recruiter
|
||||
from core.agent import Agent
|
||||
from core.logger import Logger
|
||||
from core.prompt import gen_system_message, gen_task_message
|
||||
from core.json_utils import safe_parse_json
|
||||
from core.scores import self_certainty
|
||||
from core.vote import borda_vote
|
||||
|
||||
|
||||
def load_user_data_sync(
|
||||
data_path: str,
|
||||
user: str,
|
||||
seed: int,
|
||||
log_path_base: str,
|
||||
selection_criteria: str = "out_random",
|
||||
num_examples: int = 1,
|
||||
embedding_index: Optional[EmbeddingIndex] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Load data for a single user with specified selection criteria."""
|
||||
print(f"[DATA LOADING] Starting: user={user}, seed={seed}, criteria={selection_criteria}")
|
||||
|
||||
# Set random seed for reproducibility
|
||||
np.random.seed(seed)
|
||||
|
||||
data_loader = DataLoader(
|
||||
data_path,
|
||||
user,
|
||||
selection_criteria=selection_criteria,
|
||||
num_examples=num_examples,
|
||||
embedding_index=embedding_index,
|
||||
async def run(config_path: str):
|
||||
print("[Main] Loading config")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = yaml.load(f, Loader=yaml.SafeLoader)
|
||||
logger = Logger(config.get("log_path"))
|
||||
logger.log_config(config)
|
||||
logger.log("[Main] Loaded config")
|
||||
|
||||
logger.log("[Main] Loading data loader")
|
||||
dataloader = DataLoader(config.get("data_path"), config.get("target_user"))
|
||||
logger.log("[Main] Loaded data loader")
|
||||
|
||||
logger.log("[Main] Initializing example queue")
|
||||
example_q = ExampleQueue(
|
||||
queue_size=config.get("queue_size"),
|
||||
logger=logger,
|
||||
)
|
||||
if not data_loader.is_valid:
|
||||
print(f"[DATA LOADING] Skipping invalid user: {user}")
|
||||
return []
|
||||
|
||||
tasks = []
|
||||
idx = 0
|
||||
dataset_size = len(data_loader)
|
||||
print(f"[DATA LOADING] User {user} has {dataset_size} samples")
|
||||
|
||||
for sample, examples in data_loader:
|
||||
if idx % 10 != 0:
|
||||
idx += 1
|
||||
continue
|
||||
|
||||
log_path = os.path.join(log_path_base, user, f"{idx:02d}", str(seed))
|
||||
os.makedirs(log_path, exist_ok=True)
|
||||
|
||||
kwargs = {
|
||||
"task_info": data_loader.get_task_info(),
|
||||
"classes_info": data_loader.get_classes_info(),
|
||||
"sensor_info": data_loader.get_sensor_info(),
|
||||
"sample": sample,
|
||||
"examples": examples,
|
||||
"log_path": log_path,
|
||||
"ground_truth": sample["label"],
|
||||
}
|
||||
tasks.append(kwargs)
|
||||
idx += 1
|
||||
|
||||
print(f"[DATA LOADING] Completed user {user}, seed {seed}: {len(tasks)} tasks")
|
||||
return tasks
|
||||
recruiter = Recruiter(
|
||||
source_dataset=dataloader.get_source_dataset(),
|
||||
source_users=dataloader.get_source_users(),
|
||||
num_shot=config.get("num_shot"),
|
||||
classes=dataloader.get_classes(),
|
||||
logger=logger,
|
||||
)
|
||||
list_examples = recruiter.recruit(config.get("queue_size"))
|
||||
example_q.update([], list_examples)
|
||||
logger.log("[Main] Initialized example queue")
|
||||
|
||||
logger.log("[Main] Loading model pool")
|
||||
model_pool = load_models(config.get("model_paths"))
|
||||
logger.log("[Main] Loaded model pool")
|
||||
|
||||
async def load_data_parallel(config: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Load data for all users in parallel with configurable selection criteria."""
|
||||
print("[DATA LOADING] Starting parallel data loading...")
|
||||
|
||||
# Get configuration parameters
|
||||
selection_criteria = config.get("selection_criteria", "out_random")
|
||||
embedding_path = config.get("embedding_path", None)
|
||||
num_examples = config.get("num_examples", 1)
|
||||
|
||||
# Initialize embedding index if needed for similarity-based selection
|
||||
embedding_index = None
|
||||
if selection_criteria in ["out_similar", "in_similar"]:
|
||||
if embedding_path is None:
|
||||
print(f"[WARNING] {selection_criteria} requires embedding_path in config")
|
||||
print("[WARNING] Falling back to random selection")
|
||||
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
|
||||
else:
|
||||
embedding_index = init_embedding_index(embedding_path)
|
||||
if embedding_index is None:
|
||||
print(f"[WARNING] Failed to load embeddings, falling back to random")
|
||||
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
|
||||
|
||||
user_paths = glob(os.path.join(config["data_path"], "*"))
|
||||
user_paths = [path for path in user_paths if os.path.isdir(path)]
|
||||
users = [path.split("/")[-1] for path in user_paths]
|
||||
|
||||
print(f"[DATA LOADING] Found {len(users)} users: {users}")
|
||||
print(f"[DATA LOADING] Selection criteria: {selection_criteria}")
|
||||
print(f"[DATA LOADING] Num examples per class: {num_examples}")
|
||||
|
||||
max_workers = config.get("data_workers", 96)
|
||||
print(f"[DATA LOADING] Using {max_workers} workers for data loading")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
loop = asyncio.get_event_loop()
|
||||
futures = []
|
||||
|
||||
for seed in range(config["num_seeds"]):
|
||||
for user in users:
|
||||
future = loop.run_in_executor(
|
||||
executor,
|
||||
load_user_data_sync,
|
||||
config["data_path"],
|
||||
user,
|
||||
seed,
|
||||
config["log_path"],
|
||||
selection_criteria,
|
||||
num_examples,
|
||||
embedding_index,
|
||||
)
|
||||
futures.append(future)
|
||||
|
||||
print(f"[DATA LOADING] Created {len(futures)} parallel data loading tasks")
|
||||
results = await asyncio.gather(*futures, return_exceptions=True)
|
||||
|
||||
# Flatten results and filter out exceptions
|
||||
all_tasks = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
print(f"[DATA LOADING] Error loading data: {result}")
|
||||
else:
|
||||
all_tasks.extend(result)
|
||||
logger.log("[Main] Initializing agent")
|
||||
agent = Agent(
|
||||
model_pool=model_pool,
|
||||
logger=logger,
|
||||
)
|
||||
system_message = gen_system_message(metadata=dataloader.get_task_metadata())
|
||||
agent.set_system_message(system_message)
|
||||
logger.log("[Main] Initialized agent")
|
||||
|
||||
print(f"[DATA LOADING] Total tasks: {len(all_tasks)}")
|
||||
return all_tasks
|
||||
async def process_example_set(sample_idx, example_idx, example_set, sample):
    """Solve one sample with one example set and score the model's response.

    Builds the task message from ``sample`` and ``example_set``, awaits
    ``agent.solve``, scores the returned logprobs with ``self_certainty`` and
    extracts the ``"ANSWER"`` field from the parsed JSON response.

    Returns:
        A dict with keys ``"example_set"``, ``"answer"`` and ``"score"``. If any
        step raises, the error is logged to ``errors.txt`` and the dict carries
        ``answer=None`` and ``score=float("-inf")``.
    """
    logger.log(f"[Main] Processing {sample_idx} - {example_idx} (queue index)")
    try:
        prompt = gen_task_message(sample, example_set)
        response, logprobs = await agent.solve(prompt, sample_idx, example_idx)
        score = self_certainty(logprobs, vocab_size=config.get("vocab_size"))

        parsed = safe_parse_json(response)
        if parsed:
            answer = parsed.get("ANSWER")
        else:
            # Unparseable (or empty) response: no answer to vote with.
            answer = None

        logger.log(
            f"[Main] Done {sample_idx} - {example_idx}: "
            f"answer={answer}, score={score:.4f}"
        )
    except Exception as e:
        # Best-effort: one failing example set must not abort the whole sample.
        logger.log(
            f"[Main] Error {sample_idx} - {example_idx}: {e}",
            filename="errors.txt",
        )
        # NOTE(review): -inf presumably ranks failed runs last downstream; confirm.
        return {
            "example_set": example_set,
            "answer": None,
            "score": float("-inf"),
        }

    return {"example_set": example_set, "answer": answer, "score": score}
|
||||
|
||||
|
||||
async def run_parallel(kwargs_list: List[Dict[str, Any]], model_pool, config: Dict[str, Any]) -> None:
    """Create one ``SensingAgent`` per task and await all of them together.

    Args:
        kwargs_list: One dict per task; each must provide ``task_info``,
            ``classes_info``, ``sensor_info``, ``sample``, ``examples``,
            ``log_path`` and ``ground_truth``.
        model_pool: Passed unchanged to every ``SensingAgent``.
        config: Accepted for signature compatibility; not read by this function.
    """
    print(f"[EXECUTION] Starting {len(kwargs_list)} classification tasks...")

    def _launch(task_kwargs):
        # Build the agent for this task and schedule its solve() coroutine.
        agent = SensingAgent(
            name="EEG sensing",
            model_pool=model_pool,
            task_info=task_kwargs["task_info"],
            classes_info=task_kwargs["classes_info"],
            sensor_info=task_kwargs["sensor_info"],
            sample=task_kwargs["sample"],
            examples=task_kwargs["examples"],
            log_path=task_kwargs["log_path"],
        )
        coroutine = agent.solve(
            task_kwargs["sample"],
            task_kwargs["examples"],
            task_kwargs["ground_truth"],
        )
        return asyncio.create_task(coroutine)

    pending = [_launch(task_kwargs) for task_kwargs in kwargs_list]
    await asyncio.gather(*pending)

    print("[EXECUTION] All tasks completed")
|
||||
|
||||
|
||||
def run(config_path: str) -> None:
|
||||
"""
|
||||
Main entry point for running experiments.
|
||||
|
||||
Args:
|
||||
config_path: Path to YAML configuration file
|
||||
"""
|
||||
print(f"[MAIN] Loading config from: {config_path}")
|
||||
config = yaml.load(open(config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
|
||||
|
||||
print("=" * 60)
|
||||
print("EXPERIMENT CONFIGURATION")
|
||||
print("=" * 60)
|
||||
print(f" Data path: {config.get('data_path', 'N/A')}")
|
||||
print(f" Log path: {config.get('log_path', 'N/A')}")
|
||||
print(f" Selection criteria: {config.get('selection_criteria', 'out_random')}")
|
||||
print(f" Num examples: {config.get('num_examples', 1)}")
|
||||
print(f" Num seeds: {config.get('num_seeds', 1)}")
|
||||
print(f" Embedding path: {config.get('embedding_path', 'N/A')}")
|
||||
print(f" Num models: {len(config.get('models', []))}")
|
||||
print("=" * 60)
|
||||
|
||||
model_pool = load_models(config["models"])
|
||||
|
||||
kwargs_list = asyncio.run(load_data_parallel(config))
|
||||
|
||||
if len(kwargs_list) == 0:
|
||||
print("[ERROR] No valid tasks to run. Check data paths and configuration.")
|
||||
return
|
||||
|
||||
# Warmup models
|
||||
print("[MAIN] Warming up models...")
|
||||
model_pool.warmup()
|
||||
|
||||
# Run experiments
|
||||
print("[MAIN] Starting experiments...")
|
||||
logger.log("[Main] Starting main loop")
|
||||
start_time = time.time()
|
||||
asyncio.run(run_parallel(kwargs_list, model_pool, config))
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
print(f"[MAIN] Experiment completed in {elapsed:.2f} seconds")
|
||||
print(f"[MAIN] Results saved to: {config['log_path']}")
|
||||
for idx, sample in enumerate(dataloader):
|
||||
logger.log(f"[Main] Processing sample {idx} / {len(dataloader)} (sample index)")
|
||||
tasks = [
|
||||
process_example_set(idx, example_idx, example_set, sample)
|
||||
for example_idx, example_set in enumerate(example_q)
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
winner, tally = borda_vote(results, config.get("borda_p", 1.0))
|
||||
ground_truth = sample["label"]
|
||||
if winner is not None:
|
||||
tally_str = ", ".join(f"{ans}: {v}" for ans, v in tally.items())
|
||||
logger.log(f"[Vote] votes={{ {tally_str} }}")
|
||||
logger.log_result(idx, winner, ground_truth)
|
||||
else:
|
||||
logger.log(f"[Vote] Sample {idx} | no valid answer parsed, skipping")
|
||||
recruiter.update_strategy(results)
|
||||
list_examples = recruiter.recruit(num_example_set=config.get("update_size"))
|
||||
example_q.update(results, list_examples)
|
||||
logger.report(elapsed_seconds=time.time() - start_time)
|
||||
|
||||
|
||||
def run_comparison(
    base_config_path: str,
    criteria_list: str = "out_random,in_random,out_similar,in_similar",
    embedding_path: str = None,
) -> None:
    """
    Run experiments comparing multiple selection criteria.

    For each criterion the base config is copied, ``selection_criteria`` is
    overridden, ``log_path`` gets a ``_<criterion>`` suffix (unless it already
    ends with the criterion), and the full load / warmup / run pipeline is
    executed.

    Args:
        base_config_path: Path to base YAML config file
        criteria_list: Comma-separated list of selection criteria to compare.
            A list/tuple of names is accepted as well, because python-fire can
            deliver a comma-separated flag value already split.
        embedding_path: Path to embeddings (required for *_similar criteria);
            such criteria are skipped with a warning when it is missing.

    Example:
        python run.py compare config/sleepedf.yaml \\
        --criteria_list="out_random,out_similar" \\
        --embedding_path="./embeddings_full"
    """
    # Close the config file deterministically instead of leaking the handle.
    with open(base_config_path, "r", encoding="utf-8") as f:
        base_config = yaml.load(f, Loader=yaml.SafeLoader)

    if isinstance(criteria_list, str):
        raw_criteria = criteria_list.split(",")
    else:
        # Already split by the CLI layer (tuple/list of names).
        raw_criteria = [str(item) for item in criteria_list]
    # Drop blanks so a stray comma does not launch a run with criterion "".
    criteria = [name.strip() for name in raw_criteria if name.strip()]

    print("=" * 60)
    print("COMPARISON EXPERIMENT")
    print("=" * 60)
    print(f" Selection criteria to compare: {criteria}")
    print(f" Embedding path: {embedding_path}")
    print("=" * 60)

    for criterion in criteria:
        print(f"\n{'='*60}")
        print(f"Running experiment: {criterion}")
        print(f"{'='*60}")

        # Shallow copy is enough: only top-level keys are reassigned below.
        config = base_config.copy()
        config["selection_criteria"] = criterion

        base_log_path = config["log_path"]
        if not base_log_path.endswith(criterion):
            config["log_path"] = f"{base_log_path}_{criterion}"

        if criterion in ["out_similar", "in_similar"]:
            if embedding_path:
                config["embedding_path"] = embedding_path
            else:
                print(f"[WARNING] {criterion} requires --embedding_path argument")
                continue

        model_pool = load_models(config["models"])

        kwargs_list = asyncio.run(load_data_parallel(config))

        if len(kwargs_list) == 0:
            print(f"[WARNING] No tasks for {criterion}, skipping")
            continue

        model_pool.warmup()
        asyncio.run(run_parallel(kwargs_list, model_pool, config))

        print(f"[DONE] Results for {criterion} saved to: {config['log_path']}")

    print("\n" + "=" * 60)
    print("ALL COMPARISON EXPERIMENTS COMPLETED")
    print("=" * 60)
|
||||
def main(config_path: str):
    """Synchronous entry point: hand ``run(config_path)`` to ``asyncio.run``."""
    pipeline = run(config_path)
    asyncio.run(pipeline)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire({
|
||||
"run": run,
|
||||
"compare": run_comparison,
|
||||
})
|
||||
Fire(main)
|
||||
|
||||
@@ -1,412 +0,0 @@
|
||||
"""
|
||||
Self-Consistency Results Analysis Script
|
||||
|
||||
This module provides analysis tools for Self-Consistency experiment results:
|
||||
- Load and aggregate results from experiment directories
|
||||
- Compute detailed statistics and metrics
|
||||
- Generate visualizations and reports
|
||||
- Compare results across different configurations
|
||||
|
||||
Usage:
|
||||
# Analyze single experiment
|
||||
python -m sc.analysis.analyze_sc_results analyze /path/to/results
|
||||
|
||||
# Compare multiple experiments
|
||||
python -m sc.analysis.analyze_sc_results compare /path/to/exp1 /path/to/exp2
|
||||
|
||||
# Generate summary report
|
||||
python -m sc.analysis.analyze_sc_results report /path/to/results --output report.md
|
||||
|
||||
Author: NMSL Research Team
|
||||
Date: 2026-01-21
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import yaml
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from glob import glob
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from collections import defaultdict
|
||||
from fire import Fire
|
||||
|
||||
|
||||
class SCResultsAnalyzer:
|
||||
"""
|
||||
Analyzer for Self-Consistency experiment results.
|
||||
|
||||
Loads experiment results and provides various analysis methods
|
||||
including accuracy computation, consistency analysis, and
|
||||
per-class/per-user breakdowns.
|
||||
|
||||
Attributes:
|
||||
results_path: Path to experiment results directory
|
||||
results: List of result dictionaries
|
||||
config: Experiment configuration
|
||||
stats: Computed statistics
|
||||
"""
|
||||
|
||||
def __init__(self, results_path: str):
|
||||
"""
|
||||
Initialize analyzer with results directory.
|
||||
|
||||
Args:
|
||||
results_path: Path to experiment results directory
|
||||
Should contain all_results.json and statistics.json
|
||||
"""
|
||||
self.results_path = results_path
|
||||
self.results = []
|
||||
self.config = {}
|
||||
self.stats = {}
|
||||
|
||||
self._load_results()
|
||||
|
||||
def _load_results(self) -> None:
|
||||
"""Load results, config, and statistics from files."""
|
||||
# Load all results
|
||||
results_file = os.path.join(self.results_path, "all_results.json")
|
||||
if os.path.exists(results_file):
|
||||
with open(results_file, "r", encoding="utf-8") as f:
|
||||
self.results = json.load(f)
|
||||
print(f"[LOAD] Loaded {len(self.results)} results from {results_file}")
|
||||
else:
|
||||
print(f"[WARNING] Results file not found: {results_file}")
|
||||
|
||||
# Load config
|
||||
config_file = os.path.join(self.results_path, "config.yaml")
|
||||
if os.path.exists(config_file):
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
self.config = yaml.safe_load(f)
|
||||
print(f"[LOAD] Loaded config from {config_file}")
|
||||
|
||||
# Load pre-computed statistics
|
||||
stats_file = os.path.join(self.results_path, "statistics.json")
|
||||
if os.path.exists(stats_file):
|
||||
with open(stats_file, "r", encoding="utf-8") as f:
|
||||
self.stats = json.load(f)
|
||||
print(f"[LOAD] Loaded statistics from {stats_file}")
|
||||
|
||||
def get_dataframe(self) -> pd.DataFrame:
|
||||
"""
|
||||
Convert results to pandas DataFrame.
|
||||
|
||||
Returns:
|
||||
DataFrame with one row per result
|
||||
"""
|
||||
if not self.results:
|
||||
return pd.DataFrame()
|
||||
|
||||
return pd.DataFrame(self.results)
|
||||
|
||||
def compute_accuracy(self) -> Dict[str, float]:
|
||||
"""
|
||||
Compute overall and per-class accuracy.
|
||||
|
||||
Returns:
|
||||
Dictionary with 'overall' and per-class accuracies
|
||||
"""
|
||||
if not self.results:
|
||||
return {"overall": 0.0}
|
||||
|
||||
df = self.get_dataframe()
|
||||
|
||||
# Overall accuracy
|
||||
overall = df["is_correct"].mean()
|
||||
|
||||
# Per-class accuracy
|
||||
accuracy = {"overall": overall}
|
||||
for cls in df["ground_truth"].unique():
|
||||
cls_df = df[df["ground_truth"] == cls]
|
||||
accuracy[cls] = cls_df["is_correct"].mean()
|
||||
|
||||
return accuracy
|
||||
|
||||
def compute_consistency_analysis(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze relationship between consistency and accuracy.
|
||||
|
||||
Returns:
|
||||
Dictionary with consistency statistics and correlations
|
||||
"""
|
||||
if not self.results:
|
||||
return {}
|
||||
|
||||
df = self.get_dataframe()
|
||||
|
||||
# Consistency distribution
|
||||
consistency_mean = df["consistency"].mean()
|
||||
consistency_std = df["consistency"].std()
|
||||
|
||||
# High consistency accuracy
|
||||
high_cons = df[df["consistency"] >= 0.8]
|
||||
low_cons = df[df["consistency"] < 0.8]
|
||||
|
||||
high_cons_acc = high_cons["is_correct"].mean() if len(high_cons) > 0 else 0
|
||||
low_cons_acc = low_cons["is_correct"].mean() if len(low_cons) > 0 else 0
|
||||
|
||||
# Consistency bins
|
||||
bins = [0.0, 0.4, 0.6, 0.8, 1.0]
|
||||
labels = ["0.0-0.4", "0.4-0.6", "0.6-0.8", "0.8-1.0"]
|
||||
df["consistency_bin"] = pd.cut(df["consistency"], bins=bins, labels=labels)
|
||||
|
||||
bin_stats = {}
|
||||
for label in labels:
|
||||
bin_df = df[df["consistency_bin"] == label]
|
||||
bin_stats[label] = {
|
||||
"count": len(bin_df),
|
||||
"accuracy": bin_df["is_correct"].mean() if len(bin_df) > 0 else 0,
|
||||
}
|
||||
|
||||
return {
|
||||
"consistency_mean": consistency_mean,
|
||||
"consistency_std": consistency_std,
|
||||
"high_consistency_count": len(high_cons),
|
||||
"high_consistency_accuracy": high_cons_acc,
|
||||
"low_consistency_count": len(low_cons),
|
||||
"low_consistency_accuracy": low_cons_acc,
|
||||
"bin_statistics": bin_stats,
|
||||
}
|
||||
|
||||
def compute_per_user_accuracy(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Compute accuracy breakdown by user.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping user_id to accuracy metrics
|
||||
"""
|
||||
if not self.results:
|
||||
return {}
|
||||
|
||||
df = self.get_dataframe()
|
||||
|
||||
user_stats = {}
|
||||
for user_id in df["user_id"].unique():
|
||||
user_df = df[df["user_id"] == user_id]
|
||||
user_stats[user_id] = {
|
||||
"count": len(user_df),
|
||||
"accuracy": user_df["is_correct"].mean(),
|
||||
"avg_consistency": user_df["consistency"].mean(),
|
||||
"avg_confidence": user_df["confidence"].mean(),
|
||||
}
|
||||
|
||||
return user_stats
|
||||
|
||||
def compute_confusion_matrix(self) -> Tuple[np.ndarray, List[str]]:
|
||||
"""
|
||||
Compute confusion matrix.
|
||||
|
||||
Returns:
|
||||
Tuple of (confusion_matrix, class_labels)
|
||||
"""
|
||||
if not self.results:
|
||||
return np.array([]), []
|
||||
|
||||
df = self.get_dataframe()
|
||||
classes = sorted(df["ground_truth"].unique())
|
||||
|
||||
matrix = np.zeros((len(classes), len(classes)), dtype=int)
|
||||
class_to_idx = {cls: i for i, cls in enumerate(classes)}
|
||||
|
||||
for _, row in df.iterrows():
|
||||
gt_idx = class_to_idx[row["ground_truth"]]
|
||||
pred = row["answer"]
|
||||
if pred in class_to_idx:
|
||||
pred_idx = class_to_idx[pred]
|
||||
matrix[gt_idx, pred_idx] += 1
|
||||
|
||||
return matrix, classes
|
||||
|
||||
def generate_report(self) -> str:
|
||||
"""
|
||||
Generate comprehensive markdown report.
|
||||
|
||||
Returns:
|
||||
Markdown formatted report string
|
||||
"""
|
||||
report = []
|
||||
report.append("# Self-Consistency Experiment Results\n")
|
||||
|
||||
# Config summary
|
||||
report.append("## Experiment Configuration\n")
|
||||
if self.config:
|
||||
report.append(f"- Selection Criteria: {self.config.get('selection_criteria', 'N/A')}")
|
||||
report.append(f"- Num ICL Examples: {self.config.get('num_examples', 'N/A')}")
|
||||
report.append(f"- Num SC Samples: {self.config.get('num_sc_samples', 'N/A')}")
|
||||
report.append(f"- Temperature: {self.config.get('temperature', 'N/A')}")
|
||||
report.append(f"- Num Seeds: {self.config.get('num_seeds', 'N/A')}")
|
||||
report.append("")
|
||||
|
||||
# Overall statistics
|
||||
report.append("## Overall Statistics\n")
|
||||
accuracy = self.compute_accuracy()
|
||||
report.append(f"- **Overall Accuracy**: {accuracy['overall']:.4f}")
|
||||
report.append(f"- **Total Samples**: {len(self.results)}")
|
||||
|
||||
if self.stats:
|
||||
report.append(f"- **Avg Confidence**: {self.stats.get('avg_confidence', 0):.4f}")
|
||||
report.append(f"- **Avg Consistency**: {self.stats.get('avg_consistency', 0):.4f}")
|
||||
report.append("")
|
||||
|
||||
# Per-class accuracy
|
||||
report.append("## Per-Class Accuracy\n")
|
||||
report.append("| Class | Accuracy |")
|
||||
report.append("|-------|----------|")
|
||||
for cls, acc in sorted(accuracy.items()):
|
||||
if cls != "overall":
|
||||
report.append(f"| {cls} | {acc:.4f} |")
|
||||
report.append("")
|
||||
|
||||
# Consistency analysis
|
||||
report.append("## Consistency Analysis\n")
|
||||
cons_analysis = self.compute_consistency_analysis()
|
||||
if cons_analysis:
|
||||
report.append(f"- **Mean Consistency**: {cons_analysis['consistency_mean']:.4f}")
|
||||
report.append(f"- **Consistency Std**: {cons_analysis['consistency_std']:.4f}")
|
||||
report.append(f"- **High Consistency (≥0.8) Accuracy**: {cons_analysis['high_consistency_accuracy']:.4f}")
|
||||
report.append(f"- **Low Consistency (<0.8) Accuracy**: {cons_analysis['low_consistency_accuracy']:.4f}")
|
||||
|
||||
report.append("\n### Accuracy by Consistency Bin\n")
|
||||
report.append("| Consistency Range | Count | Accuracy |")
|
||||
report.append("|-------------------|-------|----------|")
|
||||
for bin_label, stats in cons_analysis["bin_statistics"].items():
|
||||
report.append(f"| {bin_label} | {stats['count']} | {stats['accuracy']:.4f} |")
|
||||
report.append("")
|
||||
|
||||
# Confusion matrix
|
||||
report.append("## Confusion Matrix\n")
|
||||
matrix, classes = self.compute_confusion_matrix()
|
||||
if len(classes) > 0:
|
||||
header = "| | " + " | ".join(classes) + " |"
|
||||
separator = "|---" * (len(classes) + 1) + "|"
|
||||
report.append(header)
|
||||
report.append(separator)
|
||||
for i, cls in enumerate(classes):
|
||||
row = f"| **{cls}** | " + " | ".join(str(x) for x in matrix[i]) + " |"
|
||||
report.append(row)
|
||||
report.append("")
|
||||
|
||||
return "\n".join(report)
|
||||
|
||||
def save_report(self, output_path: str) -> None:
|
||||
"""
|
||||
Save report to file.
|
||||
|
||||
Args:
|
||||
output_path: Path to save the report
|
||||
"""
|
||||
report = self.generate_report()
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(report)
|
||||
print(f"[SAVE] Report saved to: {output_path}")
|
||||
|
||||
|
||||
def compare_experiments(
    paths: List[str],
    output_path: Optional[str] = None
) -> pd.DataFrame:
    """
    Compare results across multiple experiments.

    Args:
        paths: List of paths to experiment result directories
        output_path: Optional path to save comparison CSV

    Returns:
        DataFrame with one row per experiment: experiment name, selection
        criteria, sample count, overall accuracy, mean consistency and
        high-consistency accuracy
    """
    comparison = []

    for path in paths:
        analyzer = SCResultsAnalyzer(path)
        accuracy = analyzer.compute_accuracy()
        cons = analyzer.compute_consistency_analysis()

        comparison.append({
            # normpath strips a trailing separator, which would otherwise make
            # basename() return "" and leave the experiment unnamed.
            "experiment": os.path.basename(os.path.normpath(path)),
            "selection_criteria": analyzer.config.get("selection_criteria", "N/A"),
            "num_samples": len(analyzer.results),
            "accuracy": accuracy["overall"],
            # cons is {} when the experiment has no results, hence the defaults.
            "avg_consistency": cons.get("consistency_mean", 0),
            "high_cons_accuracy": cons.get("high_consistency_accuracy", 0),
        })

    df = pd.DataFrame(comparison)

    if output_path:
        df.to_csv(output_path, index=False)
        print(f"[SAVE] Comparison saved to: {output_path}")

    return df
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CLI Commands
|
||||
# =============================================================================
|
||||
|
||||
def analyze(results_path: str) -> None:
    """
    Analyze single experiment results.

    Prints overall and per-class accuracy followed by a consistency summary.

    Args:
        results_path: Path to experiment results directory
    """
    analyzer = SCResultsAnalyzer(results_path)

    print("\n" + "=" * 60)
    print("EXPERIMENT ANALYSIS")
    print("=" * 60)

    # Accuracy
    accuracy = analyzer.compute_accuracy()
    print(f"\nOverall Accuracy: {accuracy['overall']:.4f}")
    print("\nPer-Class Accuracy:")
    for cls, acc in sorted(accuracy.items()):
        if cls != "overall":
            print(f" {cls}: {acc:.4f}")

    # Consistency
    cons = analyzer.compute_consistency_analysis()
    print("\nConsistency Analysis:")
    if cons:
        print(f" Mean: {cons['consistency_mean']:.4f}")
        print(f" Std: {cons['consistency_std']:.4f}")
        print(f" High (≥0.8) Accuracy: {cons['high_consistency_accuracy']:.4f}")
    else:
        # compute_consistency_analysis() returns {} when no results were
        # loaded; indexing it would raise KeyError.
        print(" No results available")

    print("=" * 60)
|
||||
|
||||
|
||||
def report(results_path: str, output: str = None) -> None:
    """
    Write the markdown analysis report for one experiment.

    Args:
        results_path: Path to experiment results directory
        output: Destination file; when omitted the report is written to
            ``report.md`` inside ``results_path``
    """
    analyzer = SCResultsAnalyzer(results_path)

    destination = os.path.join(results_path, "report.md") if output is None else output

    analyzer.save_report(destination)
    print(f"Report generated: {destination}")
|
||||
|
||||
|
||||
def compare(*paths: str, output: str = None) -> None:
    """
    Print a side-by-side comparison table for several experiments.

    Args:
        paths: Paths to experiment result directories
        output: Optional path to save comparison CSV
    """
    comparison_df = compare_experiments(list(paths), output)
    table = comparison_df.to_string(index=False)
    print("\n" + table)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Expose the three CLI sub-commands through python-fire.
    cli_commands = {
        "analyze": analyze,
        "report": report,
        "compare": compare,
    }
    Fire(cli_commands)
|
||||
@@ -1,73 +0,0 @@
|
||||
# ==============================================================================
|
||||
# Sleep Stage Classification - Confidence-based Queue Policy Experiment
|
||||
# ==============================================================================
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Data Configuration
|
||||
# ------------------------------------------------------------------------------
|
||||
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
||||
data_workers: 16
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Experiment Settings (to be overridden by CLI)
|
||||
# ------------------------------------------------------------------------------
|
||||
# These will be set via command line arguments
|
||||
user_id: 5 # Target user (5 or 10)
|
||||
shuffle_seed: 42 # Shuffle seed (42, 123, or 456)
|
||||
|
||||
# Queue and ICL settings
|
||||
queue_size: 5
|
||||
num_icl_shots: 5 # Number of ICL examples per agent
|
||||
|
||||
# Self-Consistency settings
|
||||
num_sc_samples: 8 # Number of SC sampling agents
|
||||
|
||||
# Example pool selection: "out" (different users) or "in" (same user)
|
||||
example_pool: "out"
|
||||
|
||||
# Process all samples (no sampling)
|
||||
sample_rate: 1
|
||||
|
||||
# Tracking window size for rolling window accuracy
|
||||
tracking_window: 20
|
||||
|
||||
# Model context window size
|
||||
num_ctx: 15000
|
||||
|
||||
# Temperature for LLM
|
||||
temperature: 0.0
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Queue Policy
|
||||
# ------------------------------------------------------------------------------
|
||||
queue_policy: "confidence"
|
||||
# Options: "confidence", "consistency", "random"
|
||||
# This config is for CONFIDENCE-based queue updates
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Model Configuration (8 agents)
|
||||
# ------------------------------------------------------------------------------
|
||||
models:
|
||||
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Output Configuration
|
||||
# ------------------------------------------------------------------------------
|
||||
log_path: "/mnt/sting/ssum/sleepedf_sc_experiment/confidence"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Sleep Stages
|
||||
# ------------------------------------------------------------------------------
|
||||
stages:
|
||||
- W
|
||||
- N1
|
||||
- N2
|
||||
- N3
|
||||
- REM
|
||||
@@ -1,73 +0,0 @@
|
||||
# ==============================================================================
|
||||
# Sleep Stage Classification - Consistency-based Queue Policy Experiment
|
||||
# ==============================================================================
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Data Configuration
|
||||
# ------------------------------------------------------------------------------
|
||||
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
||||
data_workers: 16
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Experiment Settings (to be overridden by CLI)
|
||||
# ------------------------------------------------------------------------------
|
||||
# These will be set via command line arguments
|
||||
user_id: 5 # Target user (5 or 10)
|
||||
shuffle_seed: 42 # Shuffle seed (42, 123, or 456)
|
||||
|
||||
# Queue and ICL settings
|
||||
queue_size: 5
|
||||
num_icl_shots: 5 # Number of ICL examples per agent
|
||||
|
||||
# Self-Consistency settings
|
||||
num_sc_samples: 8 # Number of SC sampling agents
|
||||
|
||||
# Example pool selection: "out" (different users) or "in" (same user)
|
||||
example_pool: "out"
|
||||
|
||||
# Process all samples (no sampling)
|
||||
sample_rate: 1
|
||||
|
||||
# Tracking window size for rolling window accuracy
|
||||
tracking_window: 20
|
||||
|
||||
# Model context window size
|
||||
num_ctx: 15000
|
||||
|
||||
# Temperature for LLM
|
||||
temperature: 0.0
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Queue Policy
|
||||
# ------------------------------------------------------------------------------
|
||||
queue_policy: "consistency"
|
||||
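# Options: "confidence", "consistency", "random" (NOTE: the queue-random baseline config spells the random policy "queue_random" - confirm which value the runner accepts)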
|
||||
# This config is for CONSISTENCY-based queue updates
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Model Configuration (8 agents)
|
||||
# ------------------------------------------------------------------------------
|
||||
models:
|
||||
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Output Configuration
|
||||
# ------------------------------------------------------------------------------
|
||||
log_path: "/mnt/sting/ssum/sleepedf_sc_experiment/consistency"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Sleep Stages
|
||||
# ------------------------------------------------------------------------------
|
||||
stages:
|
||||
- W
|
||||
- N1
|
||||
- N2
|
||||
- N3
|
||||
- REM
|
||||
@@ -1,86 +0,0 @@
|
||||
# ==============================================================================
|
||||
# Sleep Stage Classification - Queue Random Baseline Experiment
|
||||
# ==============================================================================
|
||||
#
|
||||
# This is an ABLATION STUDY baseline:
|
||||
# - Queue structure is maintained (size=5)
|
||||
# - BUT all 5 elements are refreshed with random samples EVERY step
|
||||
# - No cumulative learning or retention of good examples
|
||||
#
|
||||
# Purpose: Test whether performance gains come from:
|
||||
# 1. Queue structure itself (using 5 ICL examples)
|
||||
# 2. Cumulative learning (retaining high-quality examples over time)
|
||||
#
|
||||
# Expected Result:
|
||||
# - If Queue Random ≈ Confidence/Consistency → Queue structure is the key
|
||||
# - If Queue Random < Confidence/Consistency → Cumulative learning is the key
|
||||
# ==============================================================================
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Data Configuration
|
||||
# ------------------------------------------------------------------------------
|
||||
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
||||
data_workers: 16
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Experiment Settings (to be overridden by CLI)
|
||||
# ------------------------------------------------------------------------------
|
||||
user_id: 5 # Target user (5, 10, or 15)
|
||||
shuffle_seed: 42 # Shuffle seed (42 or 123)
|
||||
|
||||
# Queue settings (same as other policies for fair comparison)
|
||||
queue_size: 5 # Number of ICL example sets
|
||||
num_icl_shots: 5 # Number of ICL examples per agent
|
||||
|
||||
# Self-Consistency settings
|
||||
num_sc_samples: 8 # Number of SC sampling agents (same as other policies)
|
||||
|
||||
# Example pool selection: "out" (different users) or "in" (same user)
|
||||
example_pool: "out"
|
||||
|
||||
# Process all samples (no sampling)
|
||||
sample_rate: 1
|
||||
|
||||
# Tracking window size for rolling window accuracy
|
||||
tracking_window: 20
|
||||
|
||||
# Model context window size
|
||||
num_ctx: 15000
|
||||
|
||||
# Temperature for LLM
|
||||
temperature: 0.0
|
||||
|
||||
# Sleep stages for classification
|
||||
stages:
|
||||
- "W"
|
||||
- "N1"
|
||||
- "N2"
|
||||
- "N3"
|
||||
- "REM"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Queue Policy
|
||||
# ------------------------------------------------------------------------------
|
||||
queue_policy: "queue_random"
|
||||
# This config is for QUEUE RANDOM baseline:
|
||||
# - Queue structure exists (5 slots)
|
||||
# - ALL slots are refreshed with random samples every step
|
||||
# - No retention of good examples
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Output Configuration
|
||||
# ------------------------------------------------------------------------------
|
||||
log_path: "/mnt/sting/ssum/sleepedf_sc_experiment/queue_random"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Model Configuration (8 agents for Self-Consistency)
|
||||
# ------------------------------------------------------------------------------
|
||||
models:
|
||||
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
|
||||
@@ -1,85 +0,0 @@
|
||||
# ==============================================================================
|
||||
# Sleep-EDF Self-Consistency Experiment Configuration
|
||||
# ==============================================================================
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Data Configuration
|
||||
# ------------------------------------------------------------------------------
|
||||
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
||||
data_workers: 16
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Experiment Settings
|
||||
# ------------------------------------------------------------------------------
|
||||
num_seeds: 1
|
||||
num_examples: 1
|
||||
|
||||
# Sample rate: process every Nth sample for faster experiments
|
||||
sample_rate: 10
|
||||
|
||||
# Example pool selection: "out" (different users) or "in" (same user)
|
||||
example_pool: "out"
|
||||
|
||||
# Continuous mode: if True, process samples in order; if False, shuffle
|
||||
continuous: true
|
||||
|
||||
# Queue size for example selection (capacity of the example queue)
|
||||
queue_size: 5
|
||||
|
||||
# Model context window size
|
||||
num_ctx: 15000
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Selection Criteria
|
||||
# Available options:
|
||||
# - out_random: Random selection from different users (baseline)
|
||||
# - in_random: Random selection from same user (personalization baseline)
|
||||
# - out_similar: Chronos-2 embedding similarity-based selection
|
||||
# - out_metadata: Gower distance-based selection (gender, age)
|
||||
# ------------------------------------------------------------------------------
|
||||
selection_criteria: "out_random"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# out_similar Configuration (Chronos-2 Embedding)
|
||||
# Uncomment when using out_similar selection criteria
|
||||
# ------------------------------------------------------------------------------
|
||||
# embedding_path: "/home/ssum/tsllm_personalization_icl/embeddings_full"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# out_metadata Configuration (Gower Distance)
|
||||
# Uncomment when using out_metadata selection criteria
|
||||
# ------------------------------------------------------------------------------
|
||||
# metadata_path: "/home/ssum/tsllm_personalization_icl/preprocess/SC-subjects.xls"
|
||||
# weight_gender: 1.0 # Gender distance weight (same=0, different=1)
|
||||
# weight_age: 1.0 # Age distance weight (normalized: |age1-age2|/range)
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Self-Consistency Settings
|
||||
# ------------------------------------------------------------------------------
|
||||
# Number of sampling iterations for Self-Consistency
|
||||
num_sc_samples: 5
|
||||
temperature: 0.0
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Model Configuration
|
||||
# ------------------------------------------------------------------------------
|
||||
# Multiple Ollama instances provide model diversity even with T=0
|
||||
models:
|
||||
- ollama:url:iu.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11441/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Output Configuration
|
||||
# ------------------------------------------------------------------------------
|
||||
log_path: "/mnt/sting/ssum/sleepedf_sc_result_test"
|
||||
|
||||
# Previous experiment result paths (for reference):
|
||||
# chronos_base: "/mnt/sting/ssum/sleepedf_chronos_base_result"
|
||||
# out_random: "/mnt/sting/ssum/sleepedf_chronos_result_outrandom"
|
||||
# in_random: "/mnt/sting/ssum/sleepedf_chronos_result"
|
||||
147
sc/core/agent.py
147
sc/core/agent.py
@@ -1,147 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
|
||||
|
||||
|
||||
class Agent:
    """Base LLM agent: layered chat memory, file logging and JSON reply parsing.

    Memory layers (each a list of chat messages):
      - ``long_term_memory``: sent with every request; the system message is
        stored here.
      - ``short_term_memory``: exchanges made since the last
        :meth:`update_memory`; that call folds them into long-term memory.
      - ``volatile_memory``: exchanges made with ``volatile=True``; cleared by
        :meth:`update_memory` without being kept.
    """

    # First alternative: a valid JSON escape pair (left untouched).
    # Second alternative: a lone backslash (gets doubled).
    # Consuming pairs first keeps the second backslash of an already-escaped
    # ``\\`` from being mistaken for a lone one.
    _BACKSLASH_RE = re.compile(r'\\(["\\/bfnrtu])|\\')
    # Greedy match from the first "{" to the last "}" across newlines.
    _JSON_OBJECT_RE = re.compile(r"\{.*\}", re.DOTALL)

    def __init__(
        self,
        name,
        model_pool,
        log_path,
    ):
        """Create the agent and its per-agent log directory.

        Args:
            name: Agent name; used as log prefix and as log sub-directory.
            model_pool: Object exposing an awaitable ``invoke(messages)``.
            log_path: Root log directory shared by all agents.
        """
        self.name = name
        self.model_pool = model_pool
        self.log_path = log_path
        self.root_log_path = log_path
        self.agent_log_path = os.path.join(log_path, name)
        os.makedirs(self.agent_log_path, exist_ok=True)

        self.long_term_memory = []
        self.short_term_memory = []
        self.volatile_memory = []

    def log(self, message, local=True):
        """Append ``message`` to the shared log and, if ``local``, to the agent's own log."""
        if isinstance(message, AIMessage):
            message_type = "AI"
        elif isinstance(message, HumanMessage):
            message_type = "HUMAN"
        elif isinstance(message, SystemMessage):
            message_type = "SYSTEM"
        else:
            message_type = "UNKNOWN"
        entry = f"[{self.name}] [{message_type}]\n{message.content.strip()}\n\n\n"

        targets = [os.path.join(self.root_log_path, "log.txt")]
        if local:
            targets.append(os.path.join(self.agent_log_path, "log.txt"))
        for path in targets:
            with open(path, "a", encoding="utf-8") as f:
                f.write(entry)

    def update_memory(self):
        """Commit short-term memory to long-term memory and clear the transient layers."""
        self.long_term_memory.extend(self.short_term_memory)
        self.clean_short_term_memory()
        self.clean_volatile_memory()

    def clean_short_term_memory(self):
        """Discard all short-term messages."""
        self.short_term_memory = []

    def clean_volatile_memory(self):
        """Discard all volatile messages."""
        self.volatile_memory = []

    def clean_long_term_memory(self):
        """Discard all long-term messages (including the system message)."""
        self.long_term_memory = []

    def clean_json_text(self, text):
        """Repair common LLM formatting slips so ``text`` is more likely to parse as JSON.

        Steps, in order: trim, replace typographic quotes with ``'``, double
        lone backslashes, drop trailing commas before ``}`` / ``]``, remove
        non-printable characters (this includes newlines), and merge adjacent
        arrays (``][`` becomes ``,``).
        """
        text = text.strip()
        text = text.replace("‘", "'").replace("’", "'")
        text = text.replace("“", "'").replace("”", "'")
        # Keep valid escape pairs as they are; escape every other backslash.
        text = self._BACKSLASH_RE.sub(
            lambda m: m.group(0) if m.group(1) else "\\\\", text
        )
        text = re.sub(r",\s*}", "}", text)
        text = re.sub(r",\s*]", "]", text)
        text = "".join(ch for ch in text if ch.isprintable())
        text = text.replace("][", ",")
        return text

    def safe_parse_json(self, text):
        """Extract and parse the JSON object embedded in ``text``.

        If no ``{...}`` span is found and the text does not end with ``}``,
        a closing brace is appended once to recover truncated output.

        Returns:
            The parsed object, or ``None`` when ``text`` is empty, contains no
            object, or the object cannot be decoded.
        """
        if not text:
            return None
        text = text.strip()
        match = self._JSON_OBJECT_RE.search(text)
        if match is None and not text.endswith("}"):
            match = self._JSON_OBJECT_RE.search(text + "}")
        if match is None:
            print("[!] JSON parse failed")
            return None
        try:
            return json.loads(self.clean_json_text(match.group(0)))
        except json.JSONDecodeError as e:
            print(f"[!] JSON parse failed: {e}")
            return None

    @staticmethod
    def _has_fields(response, fields):
        """Return True when ``response`` is a non-empty dict containing every name in ``fields``."""
        return (
            bool(response)
            and isinstance(response, dict)
            and all(field in response for field in fields)
        )

    async def validate_response(self, response, fields, volatile=False):
        """Return ``response`` if it has all ``fields``; otherwise re-ask the model once.

        Returns:
            The valid response dict, or ``None`` if the single retry also fails.
        """
        if self._has_fields(response, fields):
            return response

        print("[!] The JSON failed to be parsed. Trying again.")
        content = (
            "Failed to parse the JSON from the previous response. Please try again."
        )
        retried = self.safe_parse_json(await self.invoke(content, volatile=volatile))
        if not self._has_fields(retried, fields):
            print("[!] Retry failed.")
            return None
        return retried

    def get_last_response(self):
        """Parse the newest long-term message if it is an AI message; else return ``None``."""
        if len(self.long_term_memory) >= 2:
            last_msg = self.long_term_memory[-1]
            if isinstance(last_msg, AIMessage):
                return self.safe_parse_json(last_msg.content)
        return None

    def set_system_message(self, content, local=True):
        """Log ``content`` as a system message and append it to long-term memory."""
        system_message = SystemMessage(content=content)
        self.log(system_message, local)
        self.long_term_memory.append(system_message)

    async def invoke(self, content, volatile=False, local=True):
        """Send ``content`` to the model together with the remembered context.

        Args:
            content: Text of the human message.
            volatile: Use (and record into) volatile memory instead of
                short-term memory; volatile exchanges are never written to
                the agent-local log.
            local: Whether to also write to the agent-local log.

        Returns:
            The stripped reply text, or ``None`` if anything in the call
            raised (the error is printed, not propagated).
        """
        transient = self.volatile_memory if volatile else self.short_term_memory
        human_message = HumanMessage(content=content)
        messages = self.long_term_memory.copy()
        messages.extend(transient)
        messages.append(human_message)
        try:
            response = await self.model_pool.invoke(messages)
            transient.extend([human_message, response])
            local_ = not volatile and local
            self.log(human_message, local=local_)
            self.log(response, local=local_)
            return response.content.strip()
        except Exception as e:  # pylint: disable=broad-exception-caught
            # Best effort: callers treat None as "no usable answer".
            print(f"[Error] Error occurred while invoking LLM: {e}")
            return None
|
||||
@@ -1,137 +0,0 @@
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
|
||||
class AgentPool:
    """Runs a set of agents in parallel and aggregates their JSON answers.

    Agents are stored in ``self.agents`` keyed by ``agent.index``.
    """

    def __init__(self, log_path):
        """Create an empty pool that writes ``summary.txt`` under ``log_path``."""
        self.agents = {}
        os.makedirs(log_path, exist_ok=True)
        self.log_path = log_path

    @staticmethod
    def _confidence_of(response):
        """Return the response's ``CONFIDENCE`` as a float.

        LLM output is not guaranteed to be numeric (``"0.8"``, ``null``, ...),
        so anything missing, unparseable or NaN counts as ``0.0`` instead of
        raising during comparison or summation.
        """
        try:
            value = float(response.get("CONFIDENCE", 0))
        except (TypeError, ValueError, OverflowError):
            return 0.0
        # NaN is the only float that differs from itself.
        return 0.0 if value != value else value

    def add_agent(self, agent):
        """Register ``agent`` under its ``index`` (replacing any previous one)."""
        self.agents[agent.index] = agent

    def log_summary(self, message, print_log=True):
        """Append ``message`` to ``summary.txt`` and optionally echo it to stdout."""
        path = os.path.join(self.log_path, "summary.txt")
        with open(path, "a", encoding="utf-8") as f:
            f.write(f"{message}\n")
        if print_log:
            print(message)

    def get_last_responses(self):
        """Collect each agent's last parsed response, skipping agents without one."""
        responses = {}
        for index, agent in self.agents.items():
            response = agent.get_last_response()
            if response:
                responses[index] = response
        return responses

    def vote(self, responses, mode="majority_vote"):
        """
        Vote for the final answer from agent responses.

        Args:
            responses: Dictionary mapping agent indices to response dicts
                Format: {"agent_name": {"ANSWER": "...", "CONFIDENCE": 0.8, ...}, ...}
            mode: Voting mode - "majority_vote", "highest_confidence", or "confidence_vote"

        Returns:
            The winning answer string, or None when there is nothing to vote
            on. Ties go to the answer seen first.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        if not responses:
            return None

        if mode == "highest_confidence":
            best_response = max(responses.values(), key=self._confidence_of)
            return best_response.get("ANSWER")

        if mode == "majority_vote":
            weights = {}
            for response in responses.values():
                answer = response.get("ANSWER")
                if answer:
                    weights[answer] = weights.get(answer, 0) + 1
        elif mode == "confidence_vote":
            weights = {}
            for response in responses.values():
                answer = response.get("ANSWER")
                if answer:
                    weights[answer] = weights.get(answer, 0) + self._confidence_of(response)
        else:
            raise ValueError(f"Invalid mode: {mode}")

        if not weights:
            return None
        # max() returns the first maximal key in insertion order.
        return max(weights, key=weights.get)

    async def run_parallel_interpretation(self):
        """Run ``interpret()`` on every agent concurrently and aggregate the results.

        Returns:
            ``None`` if any agent failed to produce a response; otherwise the
            tuple ``(voted_result, queue_idcs, avg_confidence, consistency,
            responses)`` where ``avg_confidence`` is the mean confidence of
            the agents that gave the voted answer and ``consistency`` is the
            fraction of responses that gave it.
        """
        tasks = [
            asyncio.create_task(agent.interpret()) for agent in self.agents.values()
        ]
        results = await asyncio.gather(*tasks)
        print(results)
        if None in results:
            self.log_summary("[Error] Failed to interpret")
            return None

        for agent in self.agents.values():
            agent.update_memory()

        responses = self.get_last_responses()
        for index, response in responses.items():
            answer = response.get("ANSWER", "UNKNOWN")
            confidence = response.get("CONFIDENCE", 0.0)
            self.log_summary(f"[Interpretation] <{index}> provided answer: {answer} with confidence: {confidence}")
        voted_result = self.vote(responses, mode="majority_vote")
        queue_idcs = self.filter_examples(responses)

        majority_responses = [r for r in responses.values() if r.get("ANSWER") == voted_result]

        # Avg confidence of majority answer
        avg_confidence = (
            sum(self._confidence_of(r) for r in majority_responses) / len(majority_responses)
            if majority_responses
            else 0
        )

        # Consistency = ratio of majority votes
        consistency = len(majority_responses) / len(responses) if responses else 0

        return voted_result, queue_idcs, avg_confidence, consistency, responses

    def filter_examples(self, responses):
        """
        Filter examples based on highest confidence responses.

        Args:
            responses: Dictionary mapping agent indices to response dicts
                Format: {"agent_name": {"ANSWER": "...", "CONFIDENCE": 0.8, "_example_idx": 0}, ...}

        Returns:
            List of queue indices with highest confidence (all tied indices;
            responses with negative confidence are never selected)
        """
        queue_idcs = []
        max_confidence = 0
        for index, response in responses.items():
            confidence = self._confidence_of(response)
            if confidence > max_confidence:
                max_confidence = confidence
                queue_idcs = [index]
            elif confidence == max_confidence:
                queue_idcs.append(index)

        # Debug: print why each case was (or was not) selected
        print("\n[Selection] Confidence Summary:")
        for index, response in sorted(responses.items()):
            conf = response.get("CONFIDENCE", 0)
            ans = response.get("ANSWER", "?")
            marker = " ★ SELECTED" if index in queue_idcs else ""
            print(f" Case #{index}: {ans} (conf: {conf}){marker}")
        print(f"[Selection] Max Confidence: {max_confidence} → Selected Case(s): {queue_idcs}")

        return queue_idcs
|
||||
@@ -1,72 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import datasets
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
|
||||
class DataLoader:
    """Loads one user's test split plus a pool of in-context examples.

    Layout read from ``data_path``:
      - ``info.json``: metadata with the keys ``task``, ``class`` and
        ``feature`` (those are the keys the accessors below read).
      - ``<user>/1``: split used as the example pool.
      - ``<user>/2``: split used as the test set.
    """

    def __init__(
        self,
        data_path,
        user_id,
        example_pool="out",
        continuous=True,
    ):
        """Load metadata, the target user's test split and the example pool.

        Args:
            data_path: Dataset root directory.
            user_id: Target user; compared with directory names as a string.
            example_pool: ``"out"`` uses every user except the target,
                ``"in"`` uses only the target; any other value uses all users.
            continuous: Keep sample order; when False both datasets are
                shuffled with seed 0.

        NOTE: if ``info.json`` or either split of the target user is missing,
        the constructor returns early and leaves the instance without
        ``metadata`` / ``test_dataset`` / ``example_dataset``. This silent
        behaviour is kept for backward compatibility.
        """
        user_key = f"{user_id}"
        info_path = os.path.join(data_path, "info.json")
        test_path = os.path.join(data_path, user_key, "2")
        required = (info_path, os.path.join(data_path, user_key, "1"), test_path)
        if not all(os.path.exists(path) for path in required):
            return

        with open(info_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)
        self.test_dataset = datasets.load_from_disk(test_path)

        # Sorted so the example order does not depend on filesystem listing
        # order; non-directories (e.g. info.json, stray files) are not users.
        entries = sorted(
            os.path.basename(path)
            for path in glob(os.path.join(data_path, "*"))
            if os.path.isdir(path)
        )
        user_datasets = [
            datasets.load_from_disk(os.path.join(data_path, user, "1"))
            for user in self._select_example_users(entries, user_id, example_pool)
        ]
        if user_datasets:
            self.example_dataset = datasets.concatenate_datasets(user_datasets)
        else:
            self.example_dataset = datasets.Dataset.from_list([])

        if not continuous:
            self.test_dataset = self.test_dataset.shuffle(seed=0)
            self.example_dataset = self.example_dataset.shuffle(seed=0)

    @staticmethod
    def _select_example_users(users, user_id, example_pool):
        """Return the entries of ``users`` that belong to the requested example pool.

        ``user_id`` is converted to ``str`` before comparing: directory names
        are strings while the configured id may be an int, and an int never
        equals a string.
        """
        target = f"{user_id}"
        selected = []
        for user in users:
            if user == "info.json":
                continue
            if example_pool == "out" and user == target:
                continue
            if example_pool == "in" and user != target:
                continue
            selected.append(user)
        return selected

    def __len__(self):
        """Number of test samples."""
        return len(self.test_dataset)

    def __getitem__(self, idx):
        """Return the test sample at ``idx``."""
        return self.test_dataset[idx]

    def __iter__(self):
        """Iterate over the test samples in order."""
        yield from self.test_dataset

    def get_examples(self):
        """Return the example-pool dataset."""
        return self.example_dataset

    def get_metadata(self):
        """Return the parsed ``info.json``."""
        return self.metadata

    def get_sensor_info(self):
        """Return the ``feature`` entry of the metadata."""
        return self.metadata["feature"]

    def get_task_info(self):
        """Return a prompt-ready description of the task and its classes."""
        class_lines = "\n".join(
            f" - {k}: {v}" for k, v in self.metadata["class"].items()
        )
        return f"**Task**:\n{self.metadata['task']}\n\n**Classes**:\n{class_lines}"

    def get_classes_info(self):
        """Return the class names (keys of the ``class`` metadata entry)."""
        return list(self.metadata["class"])
|
||||
@@ -1,162 +0,0 @@
|
||||
"""
|
||||
Self-Consistency Agent for Sleep Stage Classification
|
||||
|
||||
This module implements an agent that uses Self-Consistency methodology:
|
||||
- Sample N times with the same prompt (configurable temperature)
|
||||
- Output REASON, CONFIDENCE, ANSWER in JSON format
|
||||
- Aggregate final answer via Majority Voting
|
||||
|
||||
The agent extends the base Agent class and adds Self-Consistency
|
||||
specific functionality including multi-sampling and voting.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
from sc.core.agent import Agent
|
||||
|
||||
|
||||
class JudgeAgent(Agent):
    """Agent that classifies one sample from its features and optional labeled examples.

    On construction it installs a system message describing the task and the
    sensor data; :meth:`interpret` then asks the model for a JSON answer with
    the keys ``REASON``, ``CONFIDENCE`` and ``ANSWER``.
    """

    def __init__(
        self,
        name: str,
        index: int,
        model_pool,
        task_info: str,
        classes_info: List[str],
        sensor_info: str,
        sample: Dict[str, Any],
        examples: List[Dict[str, Any]],
        log_path: str,
    ):
        """Create the agent; its effective name is ``f"{name}_{index}"``.

        Args:
            name: Base agent name.
            index: Agent index; also attached to responses as ``_example_idx``.
            model_pool: Passed through to the base ``Agent``.
            task_info: Task description inserted into the system message.
            classes_info: Allowed answers, shown verbatim in the prompt.
            sensor_info: Sensor/feature description for the system message.
            sample: Sample to classify; its ``"features"`` mapping is read.
            examples: Labeled examples; each needs ``"label"`` and ``"features"``.
            log_path: Root log directory.
        """
        super().__init__(
            name=name + f"_{index}",
            model_pool=model_pool,
            log_path=log_path,
        )

        self.task_info = task_info
        self.classes_info = classes_info
        self.sensor_info = sensor_info
        self.sample = sample
        self.examples = examples
        self.index = index  # Store index for later retrieval
        self._init_system_message()

    def _init_system_message(self) -> None:
        """Build and register the system message from task and sensor info."""
        content = (
            f"You are a {self.name} agent that judges the answers of other agents.\n"
            "You have the following information about the task:\n"
            f"{self.task_info}\n\n"
            "You have the following information about the sensor data:\n"
            f"{self.sensor_info}\n\n"
            "Your goal is to analyze the features and "
            "provide a reasoned answer with your confidence level."
        )
        self.set_system_message(content)

    def _format_feature(self, value: Any) -> str:
        """
        Format a feature value for display.

        Floats with magnitude >= 1e4, or non-zero magnitude < 1e-2, use
        scientific notation; other floats use two decimal places. Non-float
        values are passed through ``str``.

        Args:
            value: Feature value to format

        Returns:
            Formatted string representation
        """
        if not isinstance(value, float):
            return str(value)
        magnitude = abs(value)
        if magnitude >= 1e4 or (magnitude < 1e-2 and value != 0):
            return f"{value:.2e}"
        return f"{value:.2f}"

    def _gen_example_info(self) -> str:
        """
        Generate string representation of ICL examples.

        Returns:
            Formatted string with example features and labels,
            or empty string if no examples provided
        """
        if not self.examples:
            return ""

        parts = [
            "**Examples**\n"
            "Sensor values might not always align with your inherent "
            "knowledge due to differences in data collection or processing. "
            "So, we included a few labeled examples to help your interpretation:\n"
        ]
        for example in self.examples:
            parts.append(f"*Example of {example['label']}*:\n")
            parts.extend(
                f" - {k}: {self._format_feature(v)}\n"
                for k, v in example["features"].items()
            )
            parts.append("\n")

        return "".join(parts).strip()

    def _gen_feature_info(self) -> str:
        """
        Generate string representation of current sample features.

        The example section (if any) comes first, followed by the features of
        the sample being classified.

        Returns:
            Formatted string with example and sample features
        """
        parts = [f"{self.name} features:\n"]

        # Add ICL example information if available
        example_info = self._gen_example_info()
        if example_info:
            parts.append(f"{example_info}\n\n")

        # Add current sample features
        parts.append("**Current sample features**:\n")
        parts.extend(
            f" - {k}: {self._format_feature(v)}\n"
            for k, v in self.sample["features"].items()
        )

        return "".join(parts).strip()

    async def interpret(self) -> Optional[Dict[str, Any]]:
        """Ask the model to classify the sample and return its parsed answer.

        Returns:
            The parsed response dict (keys ``REASON``, ``CONFIDENCE``,
            ``ANSWER`` plus ``_example_idx`` set to this agent's index), or
            ``None`` if no valid JSON was obtained after validation.
        """
        feature_info = self._gen_feature_info()
        prompt = (
            f"You have received sensor features from {self.name} modality:\n"
            f"{feature_info}\n\n"
            f"Please provide your answer for the task among {self.classes_info} "
            "and the reasoning for your answer.\n"
            "Also, please provide your confidence level for the answer as a float between 0.0 and 1.0.\n"
            "Note that the sensor features might be wrong due to the data collection or processing.\n"
            "You can evaluate the quality of the features by checking the examples you have.\n\n"
            "Respond in the following strict JSON format:\n"
            "{\n"
            ' "REASON": "<Your detailed reasoning for the classification>",\n'
            ' "CONFIDENCE": <Your confidence as a float between 0.0 and 1.0>,\n'
            f' "ANSWER": "<Your answer among {self.classes_info}>"\n'
            "}\n\n"
            "Do not include any additional text outside the JSON."
        )

        response = await self.invoke(prompt)
        parsed_response = self.safe_parse_json(response)
        parsed_response = await self.validate_response(
            parsed_response, ["REASON", "CONFIDENCE", "ANSWER"]
        )
        # Add example index to response for queue update
        if parsed_response:
            parsed_response["_example_idx"] = self.index
        return parsed_response
|
||||
@@ -1,180 +0,0 @@
|
||||
"""
|
||||
Majority Voting Module for Self-Consistency
|
||||
|
||||
Aggregates multiple LLM responses to determine the final answer via majority voting.
|
||||
"""
|
||||
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Any, Optional
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
class VotingResult:
    """Outcome of aggregating several responses by majority vote."""

    answer: str  # Winning answer
    reason: str  # Reasoning taken from the most confident supporting response
    avg_confidence: float  # Mean confidence over responses that gave the winning answer
    consistency: float  # Share of valid responses that gave the winning answer
    vote_distribution: Dict[str, int]  # Number of votes per answer
    num_samples: int  # Count of valid responses
    all_responses: List[Dict[str, Any]] = field(default_factory=list)  # The valid responses themselves

    @property
    def is_unanimous(self) -> bool:
        """True when every valid response gave the same answer."""
        return self.consistency == 1.0

    @property
    def majority_count(self) -> int:
        """Number of votes the winning answer received (0 if it has none)."""
        votes = self.vote_distribution
        return votes.get(self.answer, 0)

    def to_dict(self) -> Dict[str, Any]:
        """Return a summary dict (stored fields except ``all_responses``, plus derived values)."""
        stored = (
            "answer",
            "reason",
            "avg_confidence",
            "consistency",
            "vote_distribution",
            "num_samples",
        )
        summary: Dict[str, Any] = {key: getattr(self, key) for key in stored}
        summary["is_unanimous"] = self.is_unanimous
        summary["majority_count"] = self.majority_count
        return summary
|
||||
|
||||
|
||||
class MajorityVoting:
|
||||
def __init__(self, valid_classes: Optional[List[str]] = None):
    """Create a voter, optionally restricted to a set of allowed answers.

    Args:
        valid_classes: Answers accepted when filtering responses; ``None``
            or an empty list disables the class check.
    """
    self.valid_classes = valid_classes
|
||||
|
||||
def aggregate(self, responses: List[Dict[str, Any]]) -> VotingResult:
|
||||
valid_responses = self._filter_valid_responses(responses) # Filter valid responses only
|
||||
|
||||
if not valid_responses:
|
||||
return self._empty_result()
|
||||
|
||||
answers = [r["ANSWER"] for r in valid_responses]
|
||||
vote_counter = Counter(answers)
|
||||
|
||||
majority_answer = self._resolve_majority(vote_counter, valid_responses)
|
||||
majority_count = vote_counter[majority_answer]
|
||||
|
||||
consistency = majority_count / len(valid_responses)
|
||||
|
||||
majority_responses = [
|
||||
r for r in valid_responses if r["ANSWER"] == majority_answer
|
||||
]
|
||||
|
||||
avg_confidence = np.mean([r["CONFIDENCE"] for r in majority_responses])
|
||||
|
||||
best_response = max(majority_responses, key=lambda x: x["CONFIDENCE"])
|
||||
best_reason = best_response.get("REASON", "")
|
||||
|
||||
return VotingResult(
|
||||
answer=majority_answer,
|
||||
reason=best_reason,
|
||||
avg_confidence=float(avg_confidence),
|
||||
consistency=float(consistency),
|
||||
vote_distribution=dict(vote_counter),
|
||||
num_samples=len(valid_responses),
|
||||
all_responses=valid_responses,
|
||||
)
|
||||
|
||||
def _filter_valid_responses(
|
||||
self,
|
||||
responses: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
valid = [] # Valid responses
|
||||
|
||||
for r in responses:
|
||||
if not isinstance(r, dict):
|
||||
continue
|
||||
if "ANSWER" not in r:
|
||||
continue
|
||||
|
||||
# Check valid class
|
||||
if self.valid_classes and r["ANSWER"] not in self.valid_classes:
|
||||
continue
|
||||
|
||||
# Normalize CONFIDENCE
|
||||
conf = r.get("CONFIDENCE", 0.5)
|
||||
if isinstance(conf, str):
|
||||
try:
|
||||
conf = float(conf)
|
||||
except ValueError:
|
||||
conf = 0.5
|
||||
r["CONFIDENCE"] = max(0.0, min(1.0, conf))
|
||||
|
||||
# Set default REASON
|
||||
if "REASON" not in r:
|
||||
r["REASON"] = ""
|
||||
|
||||
valid.append(r)
|
||||
|
||||
return valid
|
||||
|
||||
def _resolve_majority(
|
||||
self,
|
||||
vote_counter: Counter,
|
||||
responses: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""
|
||||
Determine majority answer (use average confidence for tie-breaking).
|
||||
"""
|
||||
max_count = vote_counter.most_common(1)[0][1]
|
||||
tied_answers = [a for a, c in vote_counter.items() if c == max_count]
|
||||
|
||||
if len(tied_answers) == 1:
|
||||
return tied_answers[0]
|
||||
|
||||
# Tie-breaking: select answer with highest average confidence
|
||||
best_answer = None
|
||||
best_avg_conf = -1.0
|
||||
|
||||
for answer in tied_answers:
|
||||
answer_responses = [r for r in responses if r["ANSWER"] == answer]
|
||||
avg_conf = np.mean([r["CONFIDENCE"] for r in answer_responses])
|
||||
|
||||
if avg_conf > best_avg_conf:
|
||||
best_avg_conf = avg_conf
|
||||
best_answer = answer
|
||||
|
||||
return best_answer
|
||||
|
||||
def _empty_result(self) -> VotingResult:
|
||||
"""Return empty result when no valid responses exist."""
|
||||
return VotingResult(
|
||||
answer="UNKNOWN",
|
||||
reason="No valid responses received",
|
||||
avg_confidence=0.0,
|
||||
consistency=0.0,
|
||||
vote_distribution={},
|
||||
num_samples=0,
|
||||
all_responses=[],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def compute_agreement_matrix(
|
||||
responses: List[Dict[str, Any]]
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute pairwise agreement matrix between responses (for analysis).
|
||||
|
||||
Creates an N x N matrix where entry (i, j) is 1.0 if responses i and j
|
||||
have the same ANSWER, and 0.0 otherwise. Useful for analyzing
|
||||
consistency patterns across samples.
|
||||
|
||||
Args:
|
||||
responses: List of response dictionaries with ANSWER keys
|
||||
"""
|
||||
n = len(responses)
|
||||
matrix = np.zeros((n, n))
|
||||
|
||||
for i in range(n):
|
||||
for j in range(n):
|
||||
if responses[i].get("ANSWER") == responses[j].get("ANSWER"):
|
||||
matrix[i, j] = 1.0
|
||||
|
||||
return matrix
|
||||
|
||||
153
sc/core/model.py
153
sc/core/model.py
@@ -1,153 +0,0 @@
|
||||
import os
|
||||
import asyncio
|
||||
import requests
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_together import ChatTogether
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
|
||||
def load_models(models, temperature=0.0, num_ctx=15000):
    """Wrap every model spec in ``models`` and return an initialized pool.

    Args:
        models: Iterable of model specification strings.
        temperature: Sampling temperature passed to each ``Model``.
        num_ctx: Context size passed to each ``Model``.

    Returns:
        An ``AsyncModelPool`` on which ``init_models()`` has been called.
    """
    pool = AsyncModelPool()
    for spec in models:
        wrapped = Model(spec, temperature=temperature, num_ctx=num_ctx)
        pool.add_model(wrapped)
    pool.init_models()
    return pool
|
||||
|
||||
|
||||
class Model:
    """Synchronous chat-model wrapper chosen from a provider-prefixed spec.

    Supported spec forms:
        - ``ollama:url:<host[:port]>/<model>``: remote Ollama, called through
          its HTTP ``/api/chat`` endpoint (``self.model`` is ``None``)
        - ``ollama:<model>``: local Ollama via ``ChatOllama``
        - ``together:<model>``: ``ChatTogether`` (needs ``TOGETHER_API_KEY``)
        - ``openai:<model>``: ``ChatOpenAI`` (needs ``OPENAI_API_KEY``)
        - anything else: passed to LangChain's ``init_chat_model``
    """

    # LangChain message ``type`` -> Ollama chat role. Unlisted types keep the
    # previous default of "assistant".
    _ROLE_BY_MESSAGE_TYPE = {"human": "user", "system": "system"}

    def __init__(self, model, temperature, num_ctx):
        """
        Args:
            model: Model specification string (see class docstring).
            temperature: Sampling temperature.
            num_ctx: Context size for Ollama; used as ``max_tokens`` for Together.

        Raises:
            AssertionError: If the API key for the chosen provider is not set.
            IndexError: If a remote Ollama spec has no ``/<model>`` part.
        """
        if model.startswith("ollama:"):
            model = model.replace("ollama:", "")
            if "url:" in model:
                # Remote Ollama is driven through raw HTTP rather than
                # ChatOllama; invoke() can then forward logprobs options.
                model = model.replace("url:", "")
                # Split on the first "/" only, so model names that contain
                # slashes themselves are kept whole.
                parts = model.split("/", 1)
                base_url = parts[0]
                model_type = parts[1]
                self.model = None
                self.base_url = f"http://{base_url}/api/chat"
                self.model_type = model_type
                self.temperature = temperature
                self.num_ctx = num_ctx
            else:
                self.model = ChatOllama(
                    model=model,
                    temperature=temperature,
                    num_ctx=num_ctx,
                )
        elif model.startswith("together"):
            self._require_env("TOGETHER_API_KEY")
            self.model = ChatTogether(
                model=model.replace("together:", ""),
                temperature=temperature,
                max_tokens=num_ctx,
                max_retries=3,
            )
        elif model.startswith("openai"):
            self._require_env("OPENAI_API_KEY")
            self.model = ChatOpenAI(
                model=model.replace("openai:", ""),
                temperature=temperature,
            )
        else:
            self.model = init_chat_model(
                model=model,
                temperature=temperature,
            )

    @staticmethod
    def _require_env(name):
        """Fail loudly when the environment variable ``name`` is missing."""
        if name not in os.environ:
            print(f"[!] {name} is not set")
            # Raised explicitly: a bare ``assert`` is stripped under ``python -O``.
            raise AssertionError(f"{name} is not set")

    @classmethod
    def _convert_messages(cls, messages):
        """Turn LangChain-style messages into Ollama ``{"role", "content"}`` dicts."""
        return [
            {
                "role": cls._ROLE_BY_MESSAGE_TYPE.get(msg.type, "assistant"),
                "content": msg.content,
            }
            for msg in messages
        ]

    def invoke(self, messages, logprobs=False, top_logprobs=0):
        """Send ``messages`` to the model and return its reply.

        Args:
            messages: Messages exposing ``.type`` and ``.content``.
            logprobs: Remote-Ollama path only; also return the response's
                ``logprobs`` field.
            top_logprobs: Remote-Ollama path only; forwarded in the request.

        Returns:
            The reply message, or ``(message, logprobs)`` on the
            remote-Ollama path when ``logprobs`` is true. On any failure the
            caught exception object is returned instead of being raised.
        """
        try:
            if self.model is not None:
                return self.model.invoke(messages)

            response = requests.post(self.base_url, json={
                "model": self.model_type,
                "messages": self._convert_messages(messages),
                "stream": False,
                "options": {
                    "temperature": self.temperature,
                    "num_ctx": self.num_ctx,
                },
                "logprobs": logprobs,
                "top_logprobs": top_logprobs,
            })
            # Report HTTP failures as such, not as a KeyError on the body below.
            response.raise_for_status()
            response = response.json()
            resp_msg = AIMessage(content=response["message"]["content"])
            if logprobs:
                return resp_msg, response["logprobs"]
            return resp_msg
        except Exception as e:
            # NOTE(review): the exception is returned, not raised, so callers
            # must check the result type; kept for backward compatibility.
            print(f"[Error] Error occurred while invoking LLM: {e}")
            return e
|
||||
|
||||
|
||||
class AsyncModel:
    """Async adapter that runs a blocking model's ``invoke`` in the default executor."""

    def __init__(self, model):
        self.model = model

    async def invoke(self, content, logprobs=False, top_logprobs=0):
        """Call the wrapped model off the event loop thread.

        Returns the model's reply, or a ``(reply, logprobs)`` pair when
        ``logprobs`` is truthy (the wrapped call must then return a pair).
        """
        event_loop = asyncio.get_event_loop()

        if not logprobs:
            # The plain path forwards only the content, no keyword options.
            def plain_call():
                return self.model.invoke(content)

            return await event_loop.run_in_executor(None, plain_call)

        def logprob_call():
            return self.model.invoke(
                content, logprobs=logprobs, top_logprobs=top_logprobs
            )

        reply, token_logprobs = await event_loop.run_in_executor(None, logprob_call)
        return reply, token_logprobs
|
||||
|
||||
|
||||
class AsyncModelPool:
    """Pool that lends each registered model to one caller at a time."""

    def __init__(self):
        self.models = []
        self._available_models = None
        self._model_semaphore = None

    def add_model(self, model):
        """Register ``model``; it becomes usable after ``init_models()``."""
        self.models.append(model)

    def init_models(self):
        """Wrap every registered model and fill the availability queue."""
        print(f"Initializing {len(self.models)} models...")
        idle = asyncio.Queue()
        for registered in self.models:
            idle.put_nowait(AsyncModel(registered))
        self._available_models = idle
        self._model_semaphore = asyncio.Semaphore(len(self.models))

    async def invoke(self, content, logprobs=False, top_logprobs=0):
        """Borrow an idle model, run it on ``content``, and hand it back.

        Raises:
            RuntimeError: If ``init_models()`` has not been called.
        """
        if self._available_models is None:
            raise RuntimeError("Model pool not initialized. Call init_models() first.")

        borrowed = await self._available_models.get()
        try:
            if not logprobs:
                return await borrowed.invoke(content)
            reply, token_logprobs = await borrowed.invoke(
                content, logprobs=logprobs, top_logprobs=top_logprobs
            )
            return reply, token_logprobs
        finally:
            # The model goes back to the queue even when the call fails.
            self._available_models.put_nowait(borrowed)
|
||||
@@ -1,188 +0,0 @@
|
||||
"""
|
||||
Model Utilities for Self-Consistency Experiments
|
||||
|
||||
This module provides model loading and management utilities with
|
||||
configurable temperature support for Self-Consistency experiments.
|
||||
|
||||
- Temperature-configurable model loading
|
||||
- Async model pool for parallel inference
|
||||
- Support for Ollama and other LangChain-compatible models
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Any
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import HumanMessage, BaseMessage
|
||||
|
||||
|
||||
def load_models_with_temperature(
    models: "List[str]",
    temperature: float = 0.0
) -> "AsyncModelPoolSC":
    """Create an initialized async pool with one ``ModelSC`` per spec string.

    Every model in the pool is built with the same ``temperature``.

    Args:
        models: Model specification strings. Supported formats:
            - "ollama:url:host:port/model_name" for remote Ollama
            - "ollama:model_name" for local Ollama
            - Standard LangChain model strings
        temperature: LLM sampling temperature (default: 0.0)

    Returns:
        An ``AsyncModelPoolSC`` on which ``init_models()`` has been called.
    """
    pool = AsyncModelPoolSC()
    for spec in models:
        wrapped = ModelSC(spec, temperature=temperature)
        pool.add_model(wrapped)
    pool.init_models()
    return pool
|
||||
|
||||
|
||||
class ModelSC:
    """
    Self-Consistency model wrapper with temperature support.

    Wraps a LangChain chat model built from a spec string. Ollama specs
    (local and remote) go through ``ChatOllama``; any other string is handed
    to ``init_chat_model``.

    Attributes:
        model: Underlying LangChain chat model
        temperature: Configured sampling temperature
        num_ctx: Context size passed to Ollama models
    """

    def __init__(self, model: str, temperature: float = 0.0, num_ctx: int = 12000):
        """
        Initialize model with specified temperature.

        Args:
            model: Model specification string
                Formats:
                - "ollama:url:host:port/model" - Remote Ollama instance
                  (an explicit "http://" or "https://" before the host is kept)
                - "ollama:model" - Local Ollama instance
                - Other strings - Passed to LangChain init_chat_model
            temperature: Sampling temperature
            num_ctx: Ollama context size (default 12000, the previously
                hard-coded value)
        """
        self.temperature = temperature
        self.num_ctx = num_ctx
        self._model_str = model

        if model.startswith("ollama:"):
            self._init_ollama_model(model, temperature)
        else:
            self._init_langchain_model(model, temperature)

    @staticmethod
    def _parse_remote_spec(spec: str) -> tuple:
        """Split ``[scheme://]host[:port]/model`` into ``(base_url, model_name)``.

        The scheme is removed before looking for the first "/", so the
        slashes in "http://" are not mistaken for the host/model separator.
        Only that first "/" splits, so model names may contain slashes. A
        spec with no "/" falls back to the model name "llama2".
        """
        scheme = "http://"
        for known in ("http://", "https://"):
            if spec.startswith(known):
                scheme, spec = known, spec[len(known):]
                break
        host, separator, model_name = spec.partition("/")
        return scheme + host, (model_name if separator else "llama2")

    def _init_ollama_model(self, model: str, temperature: float) -> None:
        """
        Initialize Ollama model (local or remote).

        Args:
            model: Ollama model string (ollama:url:host:port/model or ollama:model)
            temperature: Sampling temperature
        """
        model = model.replace("ollama:", "")
        # Fall back to the historical default when built without __init__.
        num_ctx = getattr(self, "num_ctx", 12000)

        if "url:" in model:
            base_url, model_type = self._parse_remote_spec(model.replace("url:", ""))
            self.model = ChatOllama(
                model=model_type,
                base_url=base_url,
                temperature=temperature,
                num_ctx=num_ctx,
            )
        else:
            self.model = ChatOllama(
                model=model,
                temperature=temperature,
                num_ctx=num_ctx,
            )

    def _init_langchain_model(self, model: str, temperature: float) -> None:
        """Build the model through LangChain's generic ``init_chat_model``."""
        self.model = init_chat_model(
            model=model,
            temperature=temperature,
        )

    def invoke(self, messages: "List[BaseMessage]") -> "Any":
        """Run the wrapped chat model on ``messages`` and return its response."""
        return self.model.invoke(messages)

    def __repr__(self) -> str:
        """String representation of the model."""
        return f"ModelSC(model={self._model_str}, temperature={self.temperature})"
|
||||
|
||||
|
||||
class AsyncModelSC:
    """Async adapter that runs a ``ModelSC``'s blocking ``invoke`` in the default executor."""

    def __init__(self, model: "ModelSC"):
        self.model = model

    async def invoke(self, messages: "List[BaseMessage]") -> "Any":
        """Invoke the wrapped model off the event loop thread and return its response."""
        def blocking_call():
            return self.model.invoke(messages)

        event_loop = asyncio.get_event_loop()
        return await event_loop.run_in_executor(None, blocking_call)
|
||||
|
||||
|
||||
class AsyncModelPoolSC:
    """Pool that lends each registered ``ModelSC`` to one caller at a time."""

    def __init__(self):
        self.models: "List[ModelSC]" = []
        self._available_models: asyncio.Queue = None
        self._model_semaphore: asyncio.Semaphore = None

    def add_model(self, model: "ModelSC") -> None:
        """Register ``model``; it becomes usable after ``init_models()``."""
        self.models.append(model)

    def init_models(self) -> None:
        """Wrap every registered model and fill the availability queue.

        Raises:
            RuntimeError: If no model has been added.
        """
        if len(self.models) == 0:
            raise RuntimeError("No models added. Call add_model() first.")

        idle = asyncio.Queue()
        for registered in self.models:
            idle.put_nowait(AsyncModelSC(registered))
        self._available_models = idle
        self._model_semaphore = asyncio.Semaphore(len(self.models))

    def warmup(self) -> None:
        """Send one short message to each model, logging failures without raising."""
        total = len(self.models)
        print(f"[ModelPool] Warming up {total} models...")
        for position, registered in enumerate(self.models, start=1):
            try:
                registered.invoke([HumanMessage(content="Hello world!")])
                print(f"[ModelPool] Model {position}/{len(self.models)} warmed up")
            except Exception as e:
                print(f"[ModelPool] Model {position} warmup failed: {e}")
        print("[ModelPool] All models warmed up")

    async def invoke(self, messages: "List[BaseMessage]") -> "Any":
        """Borrow an idle model, run it on ``messages``, and hand it back.

        Raises:
            RuntimeError: If ``init_models()`` has not been called.
        """
        if self._available_models is None:
            raise RuntimeError("Model pool not initialized. Call init_models() first.")

        borrowed = await self._available_models.get()
        try:
            return await borrowed.invoke(messages)
        finally:
            # The model goes back to the queue even when the call fails.
            self._available_models.put_nowait(borrowed)

    @property
    def size(self) -> int:
        """Number of registered models."""
        return len(self.models)

    def __repr__(self) -> str:
        return f"AsyncModelPoolSC(size={self.size})"
|
||||
154
sc/core/queue.py
154
sc/core/queue.py
@@ -1,154 +0,0 @@
|
||||
from collections import deque
|
||||
import random
|
||||
|
||||
class Queue:
    """Fixed-capacity queue of example cases with survival statistics.

    A case is a list holding one randomly chosen entry per class, drawn from
    ``dataset_size[cls]``. The queue records when each case entered, how
    often it was used, and a history entry for every eviction.
    """

    def __init__(self, dataset_size, capacity=5):
        """
        Args:
            dataset_size: Mapping of class label to a non-empty sequence of
                candidate entries for that class.
            capacity: Maximum number of cases held at once.
        """
        self.dataset_size = dataset_size
        self.classes = list(dataset_size.keys())
        initial_cases = [self._sample_case() for _ in range(capacity)]
        self._queue = deque(initial_cases, maxlen=capacity)
        self._input_time = {tuple(case): 0 for case in initial_cases}
        self._current_time = 0
        self._eviction_history = []
        self._usage_count = {tuple(case): 0 for case in initial_cases}

    def _sample_case(self):
        """Draw a fresh case: one random entry per class, in class order."""
        return [random.choice(self.dataset_size[cls]) for cls in self.classes]

    def set_current_time(self, sample_idx: int):
        """Set the clock value used for entry times, ages and eviction times."""
        self._current_time = sample_idx

    def push(self, index):
        """Append ``index`` as a new case unless an equal case is already queued.

        Returns:
            The case evicted from the left end to make room, or ``None``
            when nothing was evicted (including the duplicate case).
        """
        candidate = list(index)
        if any(list(existing) == candidate for existing in self._queue):
            return None
        evicted = None
        if len(self._queue) == self._queue.maxlen:
            evicted = self._queue.popleft()
            self._record_eviction_stats(evicted)
        self._queue.append(candidate)
        self._register_stats(index)
        return evicted

    def pop(self):
        """Remove and return the leftmost case, or ``None`` when empty."""
        return self._queue.popleft() if self._queue else None

    def __iter__(self):
        """Iterate over the queued cases from left to right."""
        yield from self._queue

    def _register_stats(self, case):
        """Start tracking ``case``: entry time is now, usage count is zero."""
        key = tuple(case)
        self._input_time[key] = self._current_time
        self._usage_count[key] = 0

    def _record_eviction_stats(self, case):
        """Append an eviction record for ``case`` and drop its live statistics."""
        key = tuple(case)
        if key in self._input_time:
            self._eviction_history.append({
                "case": case,
                "duration": self._current_time - self._input_time[key],
                "usage": self._usage_count.get(key, 0),
                "evicted_at": self._current_time
            })
            del self._input_time[key]
        self._usage_count.pop(key, None)

    def _update_by_score(self, score_map):
        """Reorder by score and replace the lowest-scored case with a new one.

        Args:
            score_map: Mapping of queue position to score; positions without
                an entry are scored -1.0.

        Returns:
            The evicted case, or ``None`` when ``score_map`` is empty or the
            queue was not full.
        """
        if not score_map:
            return None
        scored = [
            (score_map.get(position, -1.0), item)
            for position, item in enumerate(self._queue)
        ]
        # Stable sort on the score alone: equal scores keep their order.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        self._queue = deque((item for _, item in scored), maxlen=self._queue.maxlen)

        evicted = None
        if len(self._queue) == self._queue.maxlen:
            evicted = self._queue.pop()  # right end holds the lowest score
            self._record_eviction_stats(evicted)
            new_case = self._sample_case()
            self._queue.append(new_case)
            self._register_stats(new_case)
        return evicted

    def update_by_confidence(self, confidence_map):
        """Update the queue using per-position confidence scores."""
        return self._update_by_score(confidence_map)

    def update_by_consistency(self, consistency_map):
        """Update queue based on consistency scores (agreement ratio among agents)."""
        return self._update_by_score(consistency_map)

    def increment_usage(self, queue_indices):
        """Increment usage count for specified queue indices."""
        for idx in queue_indices:
            if 0 <= idx < len(self._queue):
                key = tuple(self._queue[idx])
                self._usage_count[key] = self._usage_count.get(key, 0) + 1

    def get_instance_id(self):
        """Return ``id(self)``."""
        return id(self)

    def get_state_with_stats(self):
        """Return one ``{"case", "age", "usage"}`` dict per queued case."""
        now = self._current_time
        return [
            {
                "case": case,
                "age": now - self._input_time.get(tuple(case), now),
                "usage": self._usage_count.get(tuple(case), 0)
            }
            for case in self._queue
        ]

    def get_survival_stats(self):
        """Return the list of eviction records collected so far."""
        return self._eviction_history

    def get_survival_summary(self):
        """Summarize how long evicted cases survived and how often they were used.

        Returns:
            Dict with ``total_evicted``, ``avg_survival``, ``max_survival``,
            ``min_survival``, ``avg_usage`` and ``max_usage``; every value is
            0 when nothing has been evicted yet.
        """
        if not self._eviction_history:
            return {
                "total_evicted": 0,
                "avg_survival": 0,
                "max_survival": 0,
                "min_survival": 0,
                "avg_usage": 0,
                "max_usage": 0
            }
        durations = [record["duration"] for record in self._eviction_history]
        usages = [record["usage"] for record in self._eviction_history]
        return {
            "total_evicted": len(durations),
            "avg_survival": sum(durations) / len(durations),
            "max_survival": max(durations),
            "min_survival": min(durations),
            "avg_usage": sum(usages) / len(usages),
            "max_usage": max(usages)
        }
|
||||
@@ -1,193 +0,0 @@
|
||||
"""
|
||||
Self-Consistency Agent for Sleep Stage Classification
|
||||
|
||||
This module implements an agent that uses Self-Consistency methodology:
|
||||
- Sample N times with the same prompt (configurable temperature)
|
||||
- Output REASON, CONFIDENCE, ANSWER in JSON format
|
||||
- Aggregate final answer via Majority Voting
|
||||
|
||||
The agent extends the base Agent class and adds Self-Consistency
|
||||
specific functionality including multi-sampling and voting.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
from sc.core.agent import Agent
|
||||
|
||||
|
||||
class SCAgent(Agent):
    """
    Self-Consistency classification agent.

    Builds a prompt from the task description, optional labeled examples and
    the current sample's features, asks the model for a JSON reply with
    REASON, CONFIDENCE and ANSWER, and returns the parsed reply.

    Attributes:
        task_info: Description of the classification task
        classes_info: List of valid class labels
        sensor_info: Information about sensor data
        sample: Current sample being classified
        examples: ICL examples for context
        index: Index passed at construction; attached to each parsed reply
    """

    def __init__(
        self,
        name: str,
        index: int,
        model_pool,
        task_info: str,
        classes_info: List[str],
        sensor_info: str,
        sample: Dict[str, Any],
        examples: List[Dict[str, Any]],
        log_path: str,
    ):
        """
        Initialize Self-Consistency Agent.

        Args:
            name: Agent identifier name (``_<index>`` is appended)
            index: Index of the agent in the dataset
            model_pool: Async model pool for LLM inference
            task_info: Description of the classification task
            classes_info: List of valid class labels
            sensor_info: Information about sensor data format
            sample: Sample to be classified; its "features" mapping is read
            examples: ICL examples, each with "label" and "features"
            log_path: Directory path for saving logs
        """
        super().__init__(
            name=name + f"_{index}",
            model_pool=model_pool,
            log_path=log_path,
        )

        self.task_info = task_info
        self.classes_info = classes_info
        self.sensor_info = sensor_info
        self.sample = sample
        self.examples = examples
        self.index = index  # Store index for later retrieval
        self._init_system_message()

    def _init_system_message(self) -> None:
        """Initialize the system message that defines the agent's role."""
        content = (
            f"You are a {self.name} agent that interprets sensor data to solve a task.\n"
            f"You have the following information about the task:\n"
            f"{self.task_info}\n\n"
            f"You have the following information about the sensor data:\n"
            f"{self.sensor_info}\n\n"
            "Your goal is to analyze the features and "
            "provide a reasoned answer with your confidence level."
        )
        self.set_system_message(content)

    def _format_feature(self, value: Any) -> str:
        """
        Format a feature value for display.

        Floats with magnitude >= 1e4, or non-zero magnitude < 1e-2, use
        scientific notation; other floats use two decimal places. Non-float
        values are passed through ``str``.

        Args:
            value: Feature value to format

        Returns:
            Formatted string representation
        """
        if isinstance(value, float):
            if abs(value) >= 1e4 or (abs(value) < 1e-2 and value != 0):
                return f"{value:.2e}"
            return f"{value:.2f}"
        return str(value)

    def _gen_example_info(self) -> str:
        """
        Generate string representation of ICL examples.

        Returns:
            Formatted string with example features and labels,
            or empty string if no examples provided
        """
        if not self.examples:
            return ""

        # Collect fragments and join once instead of repeated ``+=``.
        parts = [
            "**Examples**\n"
            "Sensor values might not always align with your inherent "
            "knowledge due to differences in data collection or processing. "
            "So, we included a few labeled examples to help your interpretation:\n"
        ]
        for example in self.examples:
            parts.append(f"*Example of {example['label']}*:\n")
            parts.extend(
                f" - {k}: {self._format_feature(v)}\n"
                for k, v in example["features"].items()
            )
            parts.append("\n")

        return "".join(parts).strip()

    def _gen_feature_info(self) -> str:
        """
        Generate string representation of current sample features.

        Combines ICL example information (when present) with the current
        sample's features.

        Returns:
            Formatted string with example and sample features
        """
        parts = [f"{self.name} features:\n"]

        # Add ICL example information if available
        example_info = self._gen_example_info()
        if example_info:
            parts.append(f"{example_info}\n\n")

        # Add current sample features
        parts.append("**Current sample features**:\n")
        parts.extend(
            f" - {k}: {self._format_feature(v)}\n"
            for k, v in self.sample["features"].items()
        )

        return "".join(parts).strip()

    async def interpret(self) -> Optional[Dict[str, Any]]:
        """Ask the model to classify the current sample.

        Returns:
            The parsed and validated reply. When it is truthy it is a
            mapping with an added ``_example_idx`` key holding ``self.index``.
            NOTE(review): the exact value returned on a failed validation
            depends on the base class's ``validate_response``; confirm there.
        """
        feature_info = self._gen_feature_info()
        prompt = (
            f"You have received sensor features from {self.name} modality:\n"
            f"{feature_info}\n\n"
            f"Please provide your answer for the task among {self.classes_info} "
            "and the reasoning for your answer.\n"
            "Also, please provide your confidence level for the answer as a float between 0.0 and 1.0.\n"
            "Note that the sensor features might be wrong due to the data collection or processing.\n"
            "You can evaluate the quality of the features by checking the examples you have.\n\n"
            "Respond in the following strict JSON format:\n"
            "{\n"
            ' "REASON": "<Your detailed reasoning for the classification>",\n'
            ' "CONFIDENCE": <Your confidence as a float between 0.0 and 1.0>,\n'
            f' "ANSWER": "<Your answer among {self.classes_info}>"\n'
            "}\n\n"
            "Do not include any additional text outside the JSON."
        )

        response = await self.invoke(prompt)
        parsed_response = self.safe_parse_json(response)
        parsed_response = await self.validate_response(
            parsed_response, ["REASON", "CONFIDENCE", "ANSWER"]
        )
        # Add example index to response for queue update
        if parsed_response:
            parsed_response["_example_idx"] = self.index
        return parsed_response
|
||||
426
sc/debug_log.py
426
sc/debug_log.py
@@ -1,426 +0,0 @@
|
||||
def _enabled(config=None):
|
||||
if config is None:
|
||||
return True
|
||||
return config.get("debug", True)
|
||||
|
||||
|
||||
def log(message, config=None):
    """Print ``message`` when debug output is enabled for ``config``."""
    if _enabled(config):
        print(message)
|
||||
|
||||
|
||||
def warn_no_examples():
    """Warn that the dataloader has no examples and the task is skipped."""
    warning = "[WARN] No examples found for dataloader. Skipping task."
    print(warning)
|
||||
|
||||
|
||||
def log_queue_state_before(processed_count, user_info, ex_queue, config=None):
    """Print the example queue's contents before a sample is processed.

    Does nothing when debug output is disabled for ``config``.
    """
    if not _enabled(config):
        return
    banner = "#" * 60
    instance_id = ex_queue.get_instance_id()
    print(f"\n{banner}")
    print(
        f"[Sample {processed_count}] {user_info} - Queue State BEFORE Processing "
        f"(Instance ID: {instance_id}):"
    )
    for slot, entry in enumerate(ex_queue.get_state_with_stats()):
        case, age, usage = entry["case"], entry["age"], entry["usage"]
        print(f" [{slot}] {case} (age={age}, used={usage})")
    print(f"{banner}\n")
|
||||
|
||||
|
||||
def warn_no_agents(processed_count):
    """Warn that no agents were added for the given sample."""
    warning = f"[WARN] No agents added for sample {processed_count}. Skipping."
    print(warning)
|
||||
|
||||
|
||||
def warn_interpretation_failed(processed_count):
    """Warn that interpretation failed for the given sample."""
    warning = f"[WARN] Interpretation failed for sample {processed_count}. Skipping."
    print(warning)
|
||||
|
||||
|
||||
def log_tracking(
    processed_count,
    user_info,
    answer,
    ground_truth,
    is_correct,
    avg_confidence,
    consistency,
    cumulative_accuracy,
    cumulative_correct,
    window_accuracy,
    recent_results,
    avg_confidence_so_far,
    config=None,
):
    """Print the per-sample tracking report (answer, accuracy, confidence).

    Does nothing when debug output is disabled for ``config``.
    """
    if not _enabled(config):
        return
    rule = "=" * 60
    verdict = "✓ CORRECT" if is_correct else "✗ WRONG"
    seen_so_far = processed_count + 1
    window_size = len(recent_results)

    print(f"\n{rule}")
    print(f"[TRACKING] Sample {processed_count} | {user_info}")
    print(f" Answer: {answer} | GT: {ground_truth} | {verdict}")
    print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
    print(" ─────────────────────────────────────────────────────")
    print(
        f" Cumulative Accuracy: {cumulative_accuracy:.4f} "
        f"({cumulative_correct}/{seen_so_far})"
    )
    print(f" Window Accuracy (last {window_size}): {window_accuracy:.4f}")
    print(f" Avg Confidence (so far): {avg_confidence_so_far:.4f}")
    print(f"{rule}\n")
|
||||
|
||||
|
||||
def log_queue_state_after(processed_count, ex_queue, config=None):
    """Print the example queue's contents after it has been updated.

    Does nothing when debug output is disabled for ``config``.
    """
    if not _enabled(config):
        return
    print(f"\n[Sample {processed_count}] Queue State AFTER Update:")
    for slot, entry in enumerate(ex_queue.get_state_with_stats()):
        case, age, usage = entry["case"], entry["age"], entry["usage"]
        print(f" [{slot}] {case} (age={age}, used={usage})")
    print()
|
||||
|
||||
|
||||
def warn_no_responses(processed_count):
    """Warn that no responses came back; ``processed_count`` is accepted but not printed."""
    warning = "[WARN] No responses returned, falling back to basic update."
    print(warning)
|
||||
|
||||
|
||||
def log_final_queue_stats(user_info, survival_summary, config=None):
    """Print the final queue survival statistics for one user.

    Does nothing when debug output is disabled for ``config``.

    Args:
        user_info: Label identifying the user in the header line.
        survival_summary: Mapping with ``total_evicted``, ``avg_survival``,
            ``max_survival`` and ``avg_usage``; ``min_survival`` and
            ``max_usage`` are printed when present.
        config: Optional config mapping consulted for the debug flag.
    """
    if not _enabled(config):
        return
    # Not every summary producer supplies these two keys; print a
    # placeholder instead of failing with KeyError at the end of a run.
    min_survival = survival_summary.get("min_survival", "N/A")
    max_usage = survival_summary.get("max_usage", "N/A")

    print(f"\n{'#'*60}")
    print(f"[FINAL] Queue Survival Statistics for {user_info}")
    print(f" Total Evicted Cases: {survival_summary['total_evicted']}")
    print(f" Avg Survival: {survival_summary['avg_survival']:.2f} samples")
    print(f" Max Survival: {survival_summary['max_survival']} samples")
    print(f" Min Survival: {min_survival} samples")
    print(f" Avg Usage Count: {survival_summary['avg_usage']:.2f}")
    print(f" Max Usage Count: {max_usage}")
    print(f"{'#'*60}\n")
|
||||
|
||||
|
||||
def warn_no_user_dirs(data_path):
    """Warn that ``data_path`` contains no user directories."""
    warning = f"[WARN] No user directories found in {data_path}"
    print(warning)
|
||||
|
||||
|
||||
def log_found_users(users, config=None):
    """Print how many users were found, previewing at most the first five.

    Does nothing when debug output is disabled for ``config``.
    """
    if not _enabled(config):
        return
    preview = users[:5]
    ellipsis = "..." if len(users) > 5 else ""
    print(f"[INFO] Found {len(users)} users: {preview}{ellipsis}")
|
||||
|
||||
|
||||
def warn_skip_user_no_test_data(user):
    """Warn that ``user`` is skipped because no test data is available."""
    warning = f"[WARN] Skipping user {user} - no test data available"
    print(warning)
|
||||
|
||||
|
||||
def warn_skip_user_no_example_data(user):
    """Warn that ``user`` is skipped because no example data is available."""
    warning = f"[WARN] Skipping user {user} - no example data available"
    print(warning)
|
||||
|
||||
|
||||
def log_main_loading_config(config_path, config=None):
    """Announce which config file is being loaded (gated by ``_enabled(config)``)."""
    if _enabled(config):
        print(f"[MAIN] Loading config: {config_path}")
|
||||
|
||||
|
||||
def log_main_config(config, config_enabled=True):
    """Print a banner summarising the experiment configuration.

    Args:
        config: Mapping read with ``.get``; missing keys fall back to the
            defaults listed in the table below.
        config_enabled: When falsy, nothing is printed.
    """
    if not config_enabled:
        return

    rule = "=" * 60
    print(rule)
    print("SELF-CONSISTENCY EXPERIMENT CONFIGURATION")
    print(rule)

    # (label, config key, fallback) for the plain "label: value" rows.
    plain_rows = (
        ("Data path", "data_path", "N/A"),
        ("Log path", "log_path", "N/A"),
        ("Num ICL examples", "num_examples", 1),
        ("Num seeds", "num_seeds", 1),
        ("Num SC samples", "num_sc_samples", 5),
        ("Temperature", "temperature", 0.0),
    )
    for label, key, fallback in plain_rows:
        print(f" {label}: {config.get(key, fallback)}")

    # These two rows need extra formatting, so they are emitted separately.
    print(f" Sample rate: 1/{config.get('sample_rate', 10)}")
    print(f" Num models: {len(config.get('models', []))}")
    print(rule)
|
||||
|
||||
|
||||
def log_main_start(config=None):
    """Announce that the experiments are starting (gated by ``_enabled(config)``)."""
    if _enabled(config):
        print("[MAIN] Starting experiments...")
|
||||
|
||||
|
||||
def error_task_failed(result):
    """Report a failed task; ``result`` is formatted into the message (always printed)."""
    print("[ERROR] Task failed with exception: {}".format(result))
|
||||
|
||||
|
||||
def log_total_results(count, config=None):
    """Print the number of collected results (gated by ``_enabled(config)``)."""
    if _enabled(config):
        print(f"[MAIN] Total results collected: {count}")
|
||||
|
||||
|
||||
def log_experiment_results(stats, config=None):
    """Print the aggregate experiment metrics followed by class-wise accuracy.

    ``stats`` is read with ``.get``; missing numeric entries are shown as 0
    and a missing ``class_accuracy`` mapping prints no class rows. Nothing is
    printed when ``_enabled(config)`` is falsy.
    """
    if not _enabled(config):
        return

    rule = "=" * 60
    print("\n" + rule)
    print("EXPERIMENT RESULTS")
    print(rule)
    print(f" Total samples: {stats.get('total_samples', 0)}")
    print(f" Accuracy: {stats.get('accuracy', 0):.4f}")
    print(f" Avg Confidence: {stats.get('avg_confidence', 0):.4f}")
    print(f" Avg Consistency: {stats.get('avg_consistency', 0):.4f}")
    high_consistency_accuracy = stats.get("high_consistency_accuracy", 0)
    print(f" High Consistency (>=0.8) Accuracy: {high_consistency_accuracy:.4f}")
    print(f" High Consistency Samples: {stats.get('high_consistency_samples', 0)}")

    print("\n Class-wise Accuracy:")
    per_class = stats.get("class_accuracy", {})
    for class_name, class_acc in per_class.items():
        print(f" {class_name}: {class_acc:.4f}")
|
||||
|
||||
|
||||
def log_temporal_analysis(temporal, config=None):
    """Print first-half vs. second-half accuracy and confidence.

    Up to four entries of ``quartile_accuracies`` are listed as Q1..Q4 when
    that list is non-empty. Nothing is printed when ``_enabled(config)`` is
    falsy or ``temporal`` is empty.
    """
    if not _enabled(config) or not temporal:
        return

    dashes = "-" * 60
    print("\n" + dashes)
    print(" TEMPORAL ANALYSIS (Caching Effect)")
    print(dashes)
    print(f" First Half Accuracy: {temporal.get('first_half_accuracy', 0):.4f}")
    print(f" Second Half Accuracy: {temporal.get('second_half_accuracy', 0):.4f}")

    improvement = temporal.get("accuracy_improvement", 0)
    # Explicit "+" for non-negative deltas; negative values carry their own sign.
    if improvement >= 0:
        improvement_sign = "+"
    else:
        improvement_sign = ""
    print(f" Improvement: {improvement_sign}{improvement:.4f}")

    quartiles = temporal.get("quartile_accuracies", [])
    if quartiles:
        shown = min(len(quartiles), 4)
        parts = [f"Q{pos + 1}={quartiles[pos]:.4f}" for pos in range(shown)]
        print(" Quartile Accuracies: " + ", ".join(parts))

    print(f"\n First Half Confidence: {temporal.get('first_half_confidence', 0):.4f}")
    print(f" Second Half Confidence: {temporal.get('second_half_confidence', 0):.4f}")

    conf_improvement = temporal.get("confidence_improvement", 0)
    if conf_improvement >= 0:
        conf_sign = "+"
    else:
        conf_sign = ""
    print(f" Confidence Change: {conf_sign}{conf_improvement:.4f}")
|
||||
|
||||
|
||||
def log_queue_stats(queue_stats, config=None):
    """Print queue survival statistics, defaulting missing entries to 0.

    Nothing is printed when ``_enabled(config)`` is falsy or ``queue_stats``
    is empty.
    """
    if not _enabled(config) or not queue_stats:
        return

    dashes = "-" * 60
    read = queue_stats.get
    print("\n" + dashes)
    print(" QUEUE SURVIVAL STATISTICS")
    print(dashes)
    print(f" Total Evicted Cases: {read('total_evicted', 0)}")
    print(f" Avg Survival: {read('avg_survival', 0):.2f} samples")
    print(f" Max Survival: {read('max_survival', 0)} samples")
    print(f" Min Survival: {read('min_survival', 0)} samples")
    print(f" Avg Usage Count: {read('avg_usage', 0):.2f}")
    print(f" Max Usage Count: {read('max_usage', 0)}")
    # The section is closed with "=" (not "-"), matching the original output.
    print("=" * 60)
|
||||
|
||||
|
||||
def log_save_statistics(stats_path, config=None):
    """Report where statistics were saved (gated by ``_enabled(config)``)."""
    if _enabled(config):
        print(f"[SAVE] Statistics saved to: {stats_path}")
|
||||
|
||||
|
||||
def log_save_results(results_path, config=None):
    """Report where results were saved (gated by ``_enabled(config)``)."""
    if _enabled(config):
        print(f"[SAVE] Results saved to: {results_path}")
|
||||
|
||||
|
||||
def log_save_config(config_path, config=None):
    """Report where the config copy was saved (gated by ``_enabled(config)``)."""
    if _enabled(config):
        print(f"[SAVE] Config saved to: {config_path}")
|
||||
|
||||
|
||||
def log_results_saved(log_path, config=None):
    """Report the log path the results were written to (gated by ``_enabled(config)``)."""
    if _enabled(config):
        print(f"[MAIN] Results saved to: {log_path}")
|
||||
|
||||
|
||||
def log_policy_experiment_header(policy_label, user_id, shuffle_seed, total_samples, queue_size, config=None, policy_note=None):
    """Print the per-run header for a queue-policy experiment.

    ``policy_note`` is printed as an extra line only when truthy. Nothing is
    printed when ``_enabled(config)`` is falsy.
    """
    if not _enabled(config):
        return
    rule = "=" * 80
    print("\n" + rule)
    print(f"{policy_label} QUEUE POLICY EXPERIMENT")
    print(f"User: {user_id} | Shuffle Seed: {shuffle_seed}")
    print(f"Total samples: {total_samples} | Queue size: {queue_size}")
    if policy_note:
        print(policy_note)
    print(rule + "\n")
|
||||
|
||||
|
||||
def log_policy_queue_state(policy_label, processed_count, user_id, ex_queue, config=None):
    """Print the queue contents before a sample is processed under a policy.

    Entries come from ``ex_queue.get_state_with_stats()``. Nothing is printed
    when ``_enabled(config)`` is falsy.
    """
    if not _enabled(config):
        return
    banner = "#" * 60
    print("\n" + banner)
    print(f"[Sample {processed_count}] User {user_id} | {policy_label} Policy")
    print("Queue State BEFORE Processing:")
    for position, entry in enumerate(ex_queue.get_state_with_stats()):
        print(f" [{position}] {entry['case']} (age={entry['age']}, used={entry['usage']})")
    print(banner + "\n")
|
||||
|
||||
|
||||
def log_policy_result(policy_label, processed_count, answer, ground_truth, is_correct, avg_confidence, consistency, cumulative_accuracy, cumulative_correct, window_accuracy, recent_results, config=None):
    """Print the outcome of one sample under a queue policy.

    The cumulative-accuracy denominator is shown as ``processed_count + 1``
    and the window size as ``len(recent_results)``. Nothing is printed when
    ``_enabled(config)`` is falsy.
    """
    if not _enabled(config):
        return
    rule = "=" * 60
    if is_correct:
        verdict = "✓ CORRECT"
    else:
        verdict = "✗ WRONG"
    print("\n" + rule)
    print(f"[RESULT] Sample {processed_count} | {policy_label} Policy")
    print(f" Answer: {answer} | GT: {ground_truth} | {verdict}")
    print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
    print(f" Cumulative Accuracy: {cumulative_accuracy:.4f} ({cumulative_correct}/{processed_count + 1})")
    print(f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}")
    print(rule + "\n")
|
||||
|
||||
|
||||
def log_confidence_map(responses, confidence_map, config=None):
    """Print each agent's answer and confidence, ordered by queue index.

    ``confidence_map`` keys are used to index ``responses``; a response
    without an ``"ANSWER"`` entry is shown as ``?``. Nothing is printed when
    ``_enabled(config)`` is falsy.
    """
    if not _enabled(config):
        return
    print("\n[CONFIDENCE MAP] Per-agent confidence scores:")
    for queue_idx, score in sorted(confidence_map.items()):
        agent_answer = responses[queue_idx].get("ANSWER", "?")
        print(f" Queue[{queue_idx}]: answer={agent_answer}, confidence={score:.4f}")
|
||||
|
||||
|
||||
def log_consistency_map(responses, consistency_map, config=None):
    """Print each agent's answer and consistency, ordered by queue index.

    ``consistency_map`` keys are used to index ``responses``; a response
    without an ``"ANSWER"`` entry is shown as ``?``. Nothing is printed when
    ``_enabled(config)`` is falsy.
    """
    if not _enabled(config):
        return
    print("\n[CONSISTENCY MAP] Per-agent consistency scores:")
    for queue_idx, score in sorted(consistency_map.items()):
        agent_answer = responses[queue_idx].get("ANSWER", "?")
        print(f" Queue[{queue_idx}]: answer={agent_answer}, consistency={score:.4f}")
|
||||
|
||||
|
||||
def log_policy_queue_state_after(policy_label, processed_count, ex_queue, config=None):
    """Print the queue contents after a policy-specific update.

    Nothing is printed when ``_enabled(config)`` is falsy.
    """
    if _enabled(config):
        print(f"\n[Sample {processed_count}] Queue State AFTER {policy_label} Update:")
        for position, entry in enumerate(ex_queue.get_state_with_stats()):
            print(f" [{position}] {entry['case']} (age={entry['age']}, used={entry['usage']})")
|
||||
|
||||
|
||||
def log_final_policy_survival(policy_label, user_id, shuffle_seed, survival_summary, config=None):
    """Print the end-of-run queue survival summary for one policy.

    ``survival_summary`` must provide ``total_evicted``, ``avg_survival`` and
    ``avg_usage``. Nothing is printed when ``_enabled(config)`` is falsy.
    """
    if not _enabled(config):
        return
    banner = "#" * 60
    summary = survival_summary
    print("\n" + banner)
    print(f"[FINAL] Queue Survival Statistics - {policy_label} Policy")
    print(f" User: {user_id} | Seed: {shuffle_seed}")
    print(f" Total Evicted: {summary['total_evicted']}")
    print(f" Avg Survival: {summary['avg_survival']:.2f} samples")
    print(f" Avg Usage: {summary['avg_usage']:.2f}")
    print(banner + "\n")
|
||||
|
||||
|
||||
def log_policy_main_header(policy_label, user_id, shuffle_seed, queue_size, num_models, log_path, config=None, policy_note=None):
    """Print the top-level banner describing a queue-policy experiment run.

    ``policy_note`` is printed as an extra line only when truthy. Nothing is
    printed when ``_enabled(config)`` is falsy.
    """
    if not _enabled(config):
        return
    rule = "=" * 80
    print(rule)
    print(f"{policy_label} QUEUE POLICY EXPERIMENT")
    print(rule)
    labelled_values = (
        ("User ID", user_id),
        ("Shuffle Seed", shuffle_seed),
        ("Queue Size", queue_size),
        ("SC Samples (Agents)", num_models),
        ("Log Path", log_path),
    )
    for label, value in labelled_values:
        print(f" {label}: {value}")
    if policy_note:
        print(policy_note)
    print(rule)
|
||||
|
||||
|
||||
def log_policy_loading_data(user_id, shuffle_seed, config=None):
    """Announce loading of shuffled data for a user/seed pair (gated by ``_enabled(config)``)."""
    if _enabled(config):
        print(f"\n[MAIN] Loading shuffled data for user {user_id}, seed {shuffle_seed}...")
|
||||
|
||||
|
||||
def log_policy_loading_models(config=None):
    """Announce that models are being loaded (gated by ``_enabled(config)``)."""
    if _enabled(config):
        print("\n[MAIN] Loading models...")
|
||||
|
||||
|
||||
def log_policy_start(config=None, label="experiment"):
    """Announce the start of ``label`` (gated by ``_enabled(config)``)."""
    if _enabled(config):
        print(f"\n[MAIN] Starting {label}...")
|
||||
|
||||
|
||||
def log_policy_complete_summary(policy_label, user_id, shuffle_seed, stats, stage_accuracy, stage_counts, temporal, config=None, expected_note=None):
    """Print the end-of-experiment summary for one queue policy.

    Covers overall accuracy and macro F1 from ``stats``, per-stage accuracy
    (with sample counts from ``stage_counts``, defaulting to 0), and the
    first/second-half comparison from ``temporal``. ``expected_note`` is
    printed only when truthy. Nothing is printed when ``_enabled(config)`` is
    falsy.
    """
    if not _enabled(config):
        return

    rule = "=" * 80
    print("\n" + rule)
    print(f"EXPERIMENT COMPLETE - {policy_label} POLICY")
    print(rule)
    print(f" User: {user_id} | Seed: {shuffle_seed}")
    print(f" Total Samples: {stats.get('total_samples', 0)}")
    print(f" Overall Accuracy: {stats.get('accuracy', 0):.4f}")
    print(f" Macro F1: {stats.get('macro_f1', 0):.4f}")

    print("\n Per-Stage Accuracy:")
    for stage_name, stage_acc in stage_accuracy.items():
        n_samples = stage_counts.get(stage_name, 0)
        print(f" {stage_name}: {stage_acc:.4f} (n={n_samples})")

    print("\n Temporal Analysis:")
    print(f" First Half: {temporal.get('first_half_accuracy', 0):.4f}")
    print(f" Second Half: {temporal.get('second_half_accuracy', 0):.4f}")
    # NOTE(review): this reads the key 'improvement', whereas
    # log_temporal_analysis reads 'accuracy_improvement' — confirm which key
    # the policy experiments actually populate.
    print(f" Improvement: {temporal.get('improvement', 0):+.4f}")

    if expected_note:
        print(expected_note)
    print(rule)
|
||||
|
||||
|
||||
def log_queue_random_stats(user_id, shuffle_seed, sampler_stats, config=None):
    """Print the final statistics of the queue-random sampler.

    ``sampler_stats`` must provide ``total_steps``, ``total_refreshed`` and
    ``avg_refresh_per_step``. Nothing is printed when ``_enabled(config)`` is
    falsy.
    """
    if not _enabled(config):
        return
    banner = "#" * 60
    print("\n" + banner)
    print("[FINAL] Queue Random Statistics")
    print(f" User: {user_id} | Seed: {shuffle_seed}")
    print(f" Total Steps: {sampler_stats['total_steps']}")
    print(f" Total Refreshed: {sampler_stats['total_refreshed']} example sets")
    print(f" Avg Refresh per Step: {sampler_stats['avg_refresh_per_step']} (always full)")
    print(banner + "\n")
|
||||
|
||||
|
||||
def log_queue_random_sampler_init(example_count, classes, queue_size, config=None):
    """Print the queue-random sampler's initial settings (gated by ``_enabled(config)``)."""
    if _enabled(config):
        print(f"[QueueRandomSampler] Initialized with {example_count} examples")
        print(f" Classes: {classes}")
        print(f" Queue size: {queue_size}")
        print(" Policy: ALL elements refreshed every step")
|
||||
|
||||
|
||||
def log_queue_random_queue_state(processed_count, user_id, queue_sampler, config=None):
    """Print the example indices of each queue slot for the current sample.

    ``queue_sampler`` is iterated once; each item is printed as-is. Nothing
    is printed when ``_enabled(config)`` is falsy.
    """
    if not _enabled(config):
        return
    banner = "#" * 60
    print("\n" + banner)
    print(f"[Sample {processed_count}] User {user_id} | QUEUE RANDOM Policy")
    print("Queue State (ALL FRESH RANDOM samples):")
    for slot, example_indices in enumerate(queue_sampler):
        print(f" [{slot}] Example indices: {example_indices}")
    print(banner + "\n")
|
||||
|
||||
|
||||
def log_queue_random_result(processed_count, answer, ground_truth, is_correct, avg_confidence, consistency, cumulative_accuracy, cumulative_correct, window_accuracy, recent_results, config=None):
    """Print the outcome of one sample under the queue-random policy.

    Same layout as the generic policy result, plus a note that the queue is
    fully refreshed before the next sample. Nothing is printed when
    ``_enabled(config)`` is falsy.
    """
    if not _enabled(config):
        return
    rule = "=" * 60
    if is_correct:
        verdict = "✓ CORRECT"
    else:
        verdict = "✗ WRONG"
    print("\n" + rule)
    print(f"[RESULT] Sample {processed_count} | QUEUE RANDOM Policy")
    print(f" Answer: {answer} | GT: {ground_truth} | {verdict}")
    print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
    print(f" Cumulative Accuracy: {cumulative_accuracy:.4f} ({cumulative_correct}/{processed_count + 1})")
    print(f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}")
    print(" [NOTE] Queue will be FULLY REFRESHED for next sample")
    print(rule + "\n")
|
||||
222
sc/hf_api.py
222
sc/hf_api.py
@@ -1,222 +0,0 @@
|
||||
"""
|
||||
HTTP API for HuggingFace causal LM inference (answer + logits).
|
||||
|
||||
This module exposes a small FastAPI service that:
|
||||
- Loads a local HuggingFace CausalLM once at startup
|
||||
- Accepts chat-style messages via HTTP
|
||||
- Returns generated text
|
||||
- Returns *logits-derived* information in a practical size:
|
||||
- top-k logprobs for each generated token (recommended)
|
||||
- optional prompt logits top-k for the final prompt position
|
||||
|
||||
Why not return full logits?
|
||||
Full logits are extremely large (seq_len x vocab_size) and will quickly
|
||||
overwhelm network/memory. This API defaults to returning top-k logprobs.
|
||||
|
||||
Run:
|
||||
MODEL_DIR=/path/to/model \\
|
||||
uvicorn sc.hf_api:app --host 0.0.0.0 --port 8000
|
||||
|
||||
Example:
|
||||
curl -X POST http://localhost:8000/generate \\
|
||||
-H 'Content-Type: application/json' \\
|
||||
-d '{"messages":[{"role":"user","content":"Hello!"}],"max_new_tokens":64,"top_k":10}'
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
# Model directory is resolved once at import time; the MODEL_DIR environment
# variable overrides the built-in default path.
DEFAULT_MODEL_DIR = os.environ.get(
    "MODEL_DIR",
    "/mnt/sting/hjyoon/projects/llm/huggingface/gptoss20b",
)

# Single application object; the route decorators in this module attach to it.
app = FastAPI(title="HF LLM API", version="0.1.0")
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    # One chat turn as sent by the client.
    # `role` is restricted to the three listed values; `content` is the text.
    role: Literal["system", "user", "assistant"]
    content: str
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
    # Request body for POST /generate.

    # Conversation to render into a prompt.
    messages: List[ChatMessage]

    # Generation length and sampling controls (bounds enforced by pydantic).
    max_new_tokens: int = Field(default=128, ge=1, le=1024)
    temperature: float = Field(default=0.0, ge=0.0, le=2.0)
    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
    do_sample: bool = False

    # Number of candidate tokens reported per position.
    top_k: int = Field(default=20, ge=1, le=200)

    # If True, also returns prompt last-position top-k logits (not full matrix)
    include_prompt_topk: bool = False
|
||||
|
||||
|
||||
class TokenTopK(BaseModel):
    # One candidate token: vocabulary id, decoded text, and log-probability.
    token_id: int
    token: str
    logprob: float
|
||||
|
||||
|
||||
class GeneratedStep(BaseModel):
    # One generated position: the chosen token (id, text, log-probability)
    # plus the top-k candidates at that position.
    token_id: int
    token: str
    logprob: float
    topk: List[TokenTopK]
|
||||
|
||||
|
||||
class PromptTopK(BaseModel):
    # Top-k candidates at a single prompt position (`position` is its index).
    position: int
    topk: List[TokenTopK]
|
||||
|
||||
|
||||
class GenerateResponse(BaseModel):
    # Response body for POST /generate.

    # Rendered prompt string that was fed to the model.
    prompt: str

    # Decoded continuation and the raw ids it was decoded from.
    generated_text: str
    generated_token_ids: List[int]

    # Per-token details, one entry per generated position.
    steps: List[GeneratedStep]

    # Only populated when the request set include_prompt_topk.
    prompt_topk: Optional[PromptTopK] = None
|
||||
|
||||
|
||||
def _get_device() -> str:
|
||||
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def _load_model_and_tokenizer(model_dir: str):
    """Load the tokenizer and causal LM stored in ``model_dir``.

    On CUDA the model is loaded in bfloat16 with ``device_map="auto"``; on CPU
    it is loaded in float32 with no device map. The model is put in eval mode.

    Returns:
        A ``(tokenizer, model, device)`` tuple, where ``device`` is the string
        returned by ``_get_device()``.
    """
    device = _get_device()
    on_gpu = device == "cuda"

    if on_gpu:
        dtype, device_map = torch.bfloat16, "auto"
    else:
        dtype, device_map = torch.float32, None

    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype=dtype,
        device_map=device_map,
    )
    model = model.eval()
    return tokenizer, model, device
|
||||
|
||||
|
||||
@app.on_event("startup")
def _startup_load():
    """Load the model once at application startup and publish it module-wide.

    Sets the module globals ``tokenizer``, ``model`` and ``device``, and
    records the model directory and device on ``app.state``.
    """
    global tokenizer, model, device
    loaded = _load_model_and_tokenizer(DEFAULT_MODEL_DIR)
    tokenizer, model, device = loaded
    # generate() treats the presence of app.state.model_dir as "model is ready".
    app.state.device = device
    app.state.model_dir = DEFAULT_MODEL_DIR
|
||||
|
||||
|
||||
@app.get("/health")
def health() -> Dict[str, Any]:
    """Report liveness plus the loaded model directory and device.

    ``model_dir`` and ``device`` are ``None`` until they have been set on
    ``app.state``.
    """
    state = app.state
    model_dir = getattr(state, "model_dir", None)
    current_device = getattr(state, "device", None)
    return {"ok": True, "model_dir": model_dir, "device": current_device}
|
||||
|
||||
|
||||
def _apply_chat_template(messages: List[ChatMessage]) -> str:
    """Render chat messages into a single prompt string.

    Uses the tokenizer's chat template with a generation prompt appended. If
    that raises for any reason, falls back to ``"role: content"`` lines joined
    by newlines and terminated with ``"assistant:"``.
    """
    # The HF template API expects plain dicts, not pydantic objects.
    msg_dicts = []
    for message in messages:
        msg_dicts.append({"role": message.role, "content": message.content})

    try:
        return tokenizer.apply_chat_template(
            msg_dicts,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # Deliberate best-effort fallback: naive concatenation.
        lines = [f"{entry['role']}: {entry['content']}" for entry in msg_dicts]
        return "\n".join(lines) + "\nassistant:"
|
||||
|
||||
|
||||
def _topk_from_logits(logits_1d: torch.Tensor, top_k: int) -> List[TokenTopK]:
    """Return the ``top_k`` most likely tokens for one position.

    Args:
        logits_1d: Scores for a single position, shape ``(vocab,)``.
        top_k: Number of candidates requested. Clamped to the vocabulary size
            so an oversized request cannot make ``torch.topk`` raise.

    Returns:
        ``TokenTopK`` entries in descending log-probability order.
    """
    vocab_size = int(logits_1d.shape[-1])
    k = max(0, min(int(top_k), vocab_size))

    # Normalise in float32: the model may emit bfloat16 scores, and a
    # log-softmax over a large vocabulary loses precision in half precision.
    logprobs = torch.log_softmax(logits_1d.float(), dim=-1)

    # log_softmax is monotonic, so the top-k of the log-probs selects the same
    # tokens as the top-k of the raw logits; torch.topk returns them sorted in
    # descending order by default.
    top_logprobs, top_ids = torch.topk(logprobs, k=k)

    # One bulk transfer to host instead of one sync per candidate token.
    ids = top_ids.tolist()
    values = top_logprobs.tolist()

    return [
        TokenTopK(token_id=int(tid), token=tokenizer.decode([tid]), logprob=float(lp))
        for tid, lp in zip(ids, values)
    ]
|
||||
|
||||
|
||||
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest) -> GenerateResponse:
    """Generate a continuation and return per-token top-k log-probabilities.

    Args:
        req: Validated request (messages, length and sampling controls, top_k,
            and whether to include prompt top-k).

    Returns:
        The rendered prompt, the decoded continuation, the generated token ids,
        one ``GeneratedStep`` per generated token and, when requested, the
        top-k distribution at the last prompt position.

    Raises:
        HTTPException: 503 if the model has not been loaded yet; 500 if the
            model returned no per-step scores.
    """
    if not hasattr(app.state, "model_dir"):
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    prompt = _apply_chat_template(req.messages)
    inputs = tokenizer(prompt, return_tensors="pt")
    if device == "cuda":
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # Use generate() so we can get per-step scores.
    # out.scores is a tuple with one (batch, vocab) tensor per generated token.
    # NOTE(review): per the HF docs these are the *processed* scores, i.e.
    # after temperature/top-p warping when sampling — confirm this is what
    # downstream confidence estimates expect.
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            # A temperature of 0 means greedy decoding regardless of do_sample.
            do_sample=req.do_sample if req.temperature > 0 else False,
            temperature=req.temperature if req.temperature > 0 else 1.0,
            top_p=req.top_p,
            return_dict_in_generate=True,
            output_scores=True,
        )

    # out.sequences holds prompt + new tokens; strip the prompt prefix.
    seq = out.sequences[0]
    prompt_len = int(inputs["input_ids"].shape[1])
    gen_token_ids = seq[prompt_len:].tolist()
    generated_text = tokenizer.decode(gen_token_ids, skip_special_tokens=True)

    if out.scores is None:
        raise HTTPException(status_code=500, detail="Model did not return scores")

    # Build per-step top-k + chosen-token logprob. zip() pairs each score
    # tensor with its token, so a length mismatch can never feed ``None`` into
    # tokenizer.decode or into the int-typed ``token_id`` field.
    steps: List[GeneratedStep] = []
    for chosen_id, step_logits in zip(gen_token_ids, out.scores):
        step_logits_1d = step_logits[0]  # (vocab,)
        # float32 for a numerically stable log-softmax on bf16 outputs.
        logprobs_1d = torch.log_softmax(step_logits_1d.float(), dim=-1)
        steps.append(
            GeneratedStep(
                token_id=int(chosen_id),
                token=tokenizer.decode([chosen_id]),
                logprob=float(logprobs_1d[chosen_id].item()),
                topk=_topk_from_logits(step_logits_1d, req.top_k),
            )
        )

    prompt_topk: Optional[PromptTopK] = None
    if req.include_prompt_topk:
        with torch.no_grad():
            forward = model(**inputs)
        # forward.logits: (1, seq_len, vocab); only the last position is reported.
        last_pos = int(forward.logits.shape[1] - 1)
        last_logits = forward.logits[0, -1, :]
        prompt_topk = PromptTopK(
            position=last_pos,
            topk=_topk_from_logits(last_logits, req.top_k),
        )

    return GenerateResponse(
        prompt=prompt,
        generated_text=generated_text,
        generated_token_ids=[int(t) for t in gen_token_ids],
        steps=steps,
        prompt_topk=prompt_topk,
    )
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
# Standalone check: load the local model, run one forward pass over a chat
# prompt, and print the shape of the full logits tensor.
model_dir = "/mnt/sting/hjyoon/projects/llm/huggingface/gptoss20b"

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)

# bfloat16 + automatic device placement on GPU; float32 on CPU.
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    device_map="auto" if device == "cuda" else None,
)
model = model.eval()

messages = [
    {"role": "user", "content": "Explain quantum mechanics clearly and concisely."},
]

# Convert chat messages -> a single prompt string using the model's chat template
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

inputs = tokenizer(prompt, return_tensors="pt")
inputs = inputs.to(device)

with torch.no_grad():
    out = model(**inputs)

logits = out.logits  # shape: (batch=1, seq_len, vocab_size)
print("logits shape:", logits.shape)
|
||||
@@ -1 +0,0 @@
|
||||
# SC Preprocess Module
|
||||
@@ -1,409 +0,0 @@
|
||||
"""
|
||||
Data Shuffling Module for Sleep Stage Classification Experiment
|
||||
|
||||
This module provides functionality to shuffle user data for fair comparison
|
||||
across different Queue policies (Confidence, Consistency, Random).
|
||||
|
||||
Features:
|
||||
- Shuffle user data with fixed random seed for reproducibility
|
||||
- Preserve original indices for tracking
|
||||
- Save shuffled data to JSON for reuse
|
||||
- Ensure all 3 experiments use identical shuffled order
|
||||
|
||||
Usage:
|
||||
from sc.preprocess.shuffle_data import shuffle_user_data, load_shuffled_data
|
||||
|
||||
# Shuffle and save
|
||||
shuffled_data = shuffle_user_data(user_id=5, seed=42, data_path="...")
|
||||
|
||||
# Load existing shuffled data
|
||||
shuffled_data = load_shuffled_data(user_id=5, seed=42, output_dir="...")
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
import datasets
|
||||
from glob import glob
|
||||
|
||||
# Add project root (three directory levels above this file) to the import path
_PROJECT_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
|
||||
def shuffle_user_data(
    user_id: int,
    seed: int,
    data_path: str,
    output_dir: str = None,
    save_to_file: bool = True,
) -> List[Dict[str, Any]]:
    """Shuffle one user's test split and optionally persist it as JSON.

    Args:
        user_id: User to process. Integers are zero-padded to two digits
            (5 -> "05"); any other value is used via ``str()``.
        seed: Seed applied to both ``random`` and ``numpy.random``. Note that
            this reseeds the global generators.
        data_path: Root data directory; the test split is read from
            ``<data_path>/<user>/2``.
        output_dir: Where to write the JSON file. Defaults to
            ``<data_path>/shuffled``.
        save_to_file: When True, write ``user<user>_seed<seed>.json``
            containing metadata and the shuffled samples.

    Returns:
        The shuffled samples; each carries ``original_idx``, ``label``,
        ``features`` and ``shuffled_idx``.

    Raises:
        FileNotFoundError: If the user's test split directory does not exist.
    """
    # Seed the global RNGs so the shuffle below is reproducible.
    random.seed(seed)
    np.random.seed(seed)

    user_str = str(user_id) if not isinstance(user_id, int) else f"{user_id:02d}"

    test_path = os.path.join(data_path, user_str, "2")
    if not os.path.exists(test_path):
        raise FileNotFoundError(f"Test data not found: {test_path}")

    test_dataset = datasets.load_from_disk(test_path)

    # Remember each sample's original position before shuffling.
    shuffled_samples = [
        {
            "original_idx": position,
            "label": row["label"],
            "features": row["features"],
        }
        for position, row in enumerate(test_dataset)
    ]
    random.shuffle(shuffled_samples)

    for new_position, row in enumerate(shuffled_samples):
        row["shuffled_idx"] = new_position

    if not save_to_file:
        return shuffled_samples

    target_dir = os.path.join(data_path, "shuffled") if output_dir is None else output_dir
    os.makedirs(target_dir, exist_ok=True)
    output_path = os.path.join(target_dir, f"user{user_str}_seed{seed}.json")

    metadata = {
        "user_id": user_str,
        "seed": seed,
        "total_samples": len(shuffled_samples),
        "created_at": datetime.now().isoformat(),
        "data_path": data_path,
    }

    # Per-label sample counts, in order of first appearance.
    stage_counts = {}
    for row in shuffled_samples:
        stage = row["label"]
        stage_counts[stage] = stage_counts.get(stage, 0) + 1
    metadata["stage_distribution"] = stage_counts

    with open(output_path, "w", encoding="utf-8") as handle:
        json.dump(
            {"metadata": metadata, "samples": shuffled_samples},
            handle,
            indent=2,
            ensure_ascii=False,
        )

    print(f"[SHUFFLE] Saved shuffled data to: {output_path}")
    print(f" User: {user_str}, Seed: {seed}")
    print(f" Total samples: {len(shuffled_samples)}")
    print(f" Stage distribution: {stage_counts}")

    return shuffled_samples
|
||||
|
||||
|
||||
def load_shuffled_data(
    user_id: int,
    seed: int,
    output_dir: str = None,
    data_path: str = None,
) -> List[Dict[str, Any]]:
    """Load a previously shuffled split, creating it when it is missing.

    Args:
        user_id: User to load. Integers are zero-padded to two digits.
        seed: Seed that identifies the shuffled file.
        output_dir: Directory holding the shuffled JSON files. Defaults to
            ``<data_path>/shuffled`` when ``data_path`` is given.
        data_path: Root of the original data; needed to build the file when
            it does not exist yet.

    Returns:
        The list stored under ``"samples"`` in the JSON file, or the freshly
        shuffled samples.

    Raises:
        ValueError: If neither ``output_dir`` nor ``data_path`` is given.
        FileNotFoundError: If the file is missing and ``data_path`` is None.
    """
    user_str = str(user_id) if not isinstance(user_id, int) else f"{user_id:02d}"

    if output_dir is None:
        if data_path is None:
            raise ValueError("Either output_dir or data_path must be provided")
        output_dir = os.path.join(data_path, "shuffled")

    file_path = os.path.join(output_dir, f"user{user_str}_seed{seed}.json")

    if not os.path.exists(file_path):
        # Without the original data there is nothing to rebuild from.
        if data_path is None:
            raise FileNotFoundError(f"Shuffled data not found: {file_path}")

        print("[SHUFFLE] Shuffled data not found, creating new...")
        return shuffle_user_data(
            user_id=user_id,
            seed=seed,
            data_path=data_path,
            output_dir=output_dir,
            save_to_file=True,
        )

    with open(file_path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    print(f"[SHUFFLE] Loaded existing shuffled data: {file_path}")
    print(f" Total samples: {payload['metadata']['total_samples']}")
    return payload["samples"]
|
||||
|
||||
|
||||
def get_shuffled_file_path(
    user_id: int,
    seed: int,
    data_path: str,
) -> str:
    """Return the JSON path used for a user's shuffled split.

    The path is ``<data_path>/shuffled/user<user>_seed<seed>.json``; integer
    user ids are zero-padded to two digits. The file is not required to exist.
    """
    if isinstance(user_id, int):
        user_str = f"{user_id:02d}"
    else:
        user_str = str(user_id)
    file_name = f"user{user_str}_seed{seed}.json"
    return os.path.join(data_path, "shuffled", file_name)
|
||||
|
||||
|
||||
def verify_shuffle_consistency(
    user_id: int,
    seeds: List[int],
    data_path: str,
) -> bool:
    """Check that every seed's shuffled file exists with the same sample count.

    Files are looked up in ``<data_path>/shuffled``. The check stops at the
    first missing file. Only ``metadata.total_samples`` is compared; sample
    contents are not inspected.

    Returns:
        True when all files exist and report one common ``total_samples``
        value; False otherwise (including when ``seeds`` is empty).
    """
    user_str = str(user_id) if not isinstance(user_id, int) else f"{user_id:02d}"
    shuffled_dir = os.path.join(data_path, "shuffled")

    sample_counts = []
    for seed in seeds:
        file_path = os.path.join(shuffled_dir, f"user{user_str}_seed{seed}.json")
        if not os.path.exists(file_path):
            print(f"[VERIFY] Missing: {file_path}")
            return False

        with open(file_path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        sample_counts.append(payload["metadata"]["total_samples"])

    # Exactly one distinct count means every seed file agrees.
    distinct_counts = set(sample_counts)
    if len(distinct_counts) == 1:
        print(f"[VERIFY] User {user_str}: All {len(seeds)} seed files verified ({sample_counts[0]} samples each)")
        return True

    print(f"[VERIFY] Inconsistent sample counts: {sample_counts}")
    return False
|
||||
|
||||
|
||||
class ShuffledDataLoader:
|
||||
"""
|
||||
DataLoader that uses pre-shuffled data for consistent experiment comparison.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_path: str,
|
||||
user_id: int,
|
||||
seed: int,
|
||||
example_pool: str = "out",
|
||||
):
|
||||
"""
|
||||
Initialize shuffled data loader.
|
||||
|
||||
Args:
|
||||
data_path: Path to SleepEDF data
|
||||
user_id: User ID (e.g., 5 or 10)
|
||||
seed: Shuffle seed
|
||||
example_pool: "out" (different users) or "in" (same user)
|
||||
"""
|
||||
self.data_path = data_path
|
||||
self.user_id = user_id
|
||||
self.seed = seed
|
||||
self.example_pool = example_pool
|
||||
|
||||
user_str = f"{user_id:02d}" if isinstance(user_id, int) else str(user_id)
|
||||
self.user_str = user_str
|
||||
|
||||
# Load metadata
|
||||
info_path = os.path.join(data_path, "info.json")
|
||||
if not os.path.exists(info_path):
|
||||
raise FileNotFoundError(f"Info file not found: {info_path}")
|
||||
|
||||
with open(info_path, "r", encoding="utf-8") as f:
|
||||
self.metadata = json.load(f)
|
||||
|
||||
# Load shuffled test data
|
||||
self.shuffled_samples = load_shuffled_data(
|
||||
user_id=user_id,
|
||||
seed=seed,
|
||||
data_path=data_path,
|
||||
)
|
||||
|
||||
# Load example dataset (same as original DataLoader)
|
||||
self.example_dataset = datasets.Dataset.from_list([])
|
||||
users = glob(os.path.join(data_path, "*"))
|
||||
users = [os.path.basename(p) for p in users if os.path.isdir(p)]
|
||||
users = [u for u in users if u not in ["info.json", "shuffled"]]
|
||||
|
||||
for user in users:
|
||||
if example_pool == "out" and user == user_str:
|
||||
continue
|
||||
if example_pool == "in" and user != user_str:
|
||||
continue
|
||||
|
||||
example_path = os.path.join(data_path, user, "1")
|
||||
if os.path.exists(example_path):
|
||||
user_dataset = datasets.load_from_disk(example_path)
|
||||
self.example_dataset = datasets.concatenate_datasets([self.example_dataset, user_dataset])
|
||||
|
||||
# Shuffle example dataset with fixed seed
|
||||
self.example_dataset = self.example_dataset.shuffle(seed=0)
|
||||
|
||||
print(f"[ShuffledDataLoader] User: {user_str}, Seed: {seed}")
|
||||
print(f" Test samples: {len(self.shuffled_samples)}")
|
||||
print(f" Example samples: {len(self.example_dataset)}")
|
||||
|
||||
def __len__(self):
    """Return the number of shuffled test samples held by this loader."""
    samples = self.shuffled_samples
    return len(samples)
|
||||
|
||||
def __getitem__(self, idx):
    """Return the shuffled test sample stored at position ``idx``."""
    samples = self.shuffled_samples
    return samples[idx]
|
||||
|
||||
def __iter__(self):
    """Yield the shuffled test samples one at a time, in their stored order."""
    yield from self.shuffled_samples
|
||||
|
||||
def get_examples(self):
    """Return the example dataset that was assembled when the loader was built."""
    examples = self.example_dataset
    return examples
|
||||
|
||||
def get_metadata(self):
    """Return the metadata object stored on this loader (``self.metadata``)."""
    metadata = self.metadata
    return metadata
|
||||
|
||||
def get_sensor_info(self):
    """Return the ``"feature"`` entry of the loader's metadata.

    Raises:
        KeyError: if the metadata has no ``"feature"`` entry.
    """
    metadata = self.metadata
    return metadata["feature"]
|
||||
|
||||
def get_task_info(self):
    """Build the task description text from the loader's metadata.

    The text has a ``**Task**`` section holding ``metadata["task"]`` followed
    by a ``**Classes**`` section with one `` - <name>: <description>`` line
    per entry of ``metadata["class"]``.

    Returns:
        The assembled string.
    """
    metadata = self.metadata
    # Read "task" before "class" so a missing key is reported in the same
    # order as before.
    header = f"**Task**:\n{metadata['task']}\n\n"
    class_lines = "\n".join(
        f" - {k}: {v}" for k, v in metadata["class"].items()
    )
    return header + f"**Classes**:\n{class_lines}"
|
||||
|
||||
def get_classes_info(self):
    """Return the class names (the keys of ``metadata["class"]``) as a new list."""
    class_table = self.metadata["class"]
    return [name for name in class_table.keys()]
|
||||
|
||||
|
||||
def prepare_all_shuffled_data(
    data_path: str,
    users: List[int] = [5, 10],
    seeds: List[int] = [42, 123, 456],
) -> None:
    """
    Prepare all shuffled data files for the experiment.

    Calls ``shuffle_user_data(..., save_to_file=True)`` once for every
    (user, seed) combination, then calls ``verify_shuffle_consistency`` once
    per user with the full seed list.  Progress is printed to stdout.

    Args:
        data_path: Path to SleepEDF data
        users: List of user IDs
        seeds: List of shuffle seeds
    """
    # NOTE: the list defaults are shared objects; this function itself only
    # reads them.
    rule = "=" * 60

    print(rule)
    print("Preparing Shuffled Data for Experiment")
    print(rule)

    for uid in users:
        for shuffle_seed in seeds:
            print(f"\nProcessing User {uid}, Seed {shuffle_seed}...")
            shuffle_user_data(
                user_id=uid,
                seed=shuffle_seed,
                data_path=data_path,
                save_to_file=True,
            )

    print("\n" + rule)
    print("Verification")
    print(rule)

    for uid in users:
        verify_shuffle_consistency(uid, seeds, data_path)

    print("\n[DONE] All shuffled data prepared.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import fire

    def _as_int_list(value):
        """Coerce a CLI value (comma-separated string, list/tuple, or scalar) to a list of ints."""
        if isinstance(value, str):
            return [int(token.strip()) for token in value.split(",")]
        if isinstance(value, (list, tuple)):
            return [int(item) for item in value]
        return [int(value)]

    def main(
        data_path: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF",
        users="5,10",
        seeds="42,123,456",
    ):
        """
        Prepare shuffled data for experiments.

        Args:
            data_path: Path to SleepEDF data
            users: Comma-separated user IDs or list/tuple
            seeds: Comma-separated shuffle seeds or list/tuple
        """
        # The value may arrive as a string, a list/tuple or a single scalar;
        # normalise users first, then seeds.
        user_list = _as_int_list(users)
        seed_list = _as_int_list(seeds)

        prepare_all_shuffled_data(
            data_path=data_path,
            users=user_list,
            seeds=seed_list,
        )

    fire.Fire(main)
|
||||
@@ -1,463 +0,0 @@
|
||||
"""
|
||||
Confidence-based Queue Policy Experiment Runner
|
||||
|
||||
This module implements the CONFIDENCE-based queue update policy:
|
||||
- Queue is updated based on model confidence scores
|
||||
- Higher confidence examples are retained in the queue
|
||||
- Lower confidence examples are evicted
|
||||
|
||||
Usage:
|
||||
python -m sc.run_confidence --user_id=5 --shuffle_seed=42
|
||||
python -m sc.run_confidence --user_id=10 --shuffle_seed=123
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import yaml
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sklearn.metrics import f1_score, precision_score, recall_score
|
||||
|
||||
from fire import Fire
|
||||
|
||||
# Add project root to path for relative imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
|
||||
from sc.core.scagent import SCAgent
|
||||
from sc.core.model import load_models
|
||||
from sc.core.queue import Queue
|
||||
from sc.core.agent_pool import AgentPool
|
||||
from sc import debug_log
|
||||
|
||||
log = debug_log.log
|
||||
|
||||
|
||||
async def run_confidence_experiment(
    dataloader: ShuffledDataLoader,
    model_pool,
    config: Dict[str, Any],
    user_id: int,
    shuffle_seed: int,
) -> List[Dict[str, Any]]:
    """
    Run confidence-based queue policy experiment.

    For every test sample, one ``SCAgent`` is built per queue entry (using
    that entry's example indices), the agents are run through
    ``AgentPool.run_parallel_interpretation`` and the returned answer is
    compared with the sample's ``"label"``.  The queue is then updated from
    the per-agent ``"CONFIDENCE"`` values.

    Queue Update Policy (as described by the original author; the logic lives
    in ``Queue.update_by_confidence``, which is not part of this function):
    - After each inference, rank queue elements by confidence score
    - Evict the lowest confidence element
    - Add a new random ICL example set

    Samples for which agent creation or interpretation fails are logged and
    skipped; they contribute no entry to the returned list and are not counted
    in ``cumulative_accuracy``.

    Args:
        dataloader: ShuffledDataLoader with pre-shuffled test data
        model_pool: Async model pool for LLM inference
        config: Experiment configuration (reads ``log_path`` and the optional
            ``queue_size`` and ``tracking_window`` keys)
        user_id: User ID being processed
        shuffle_seed: Shuffle seed for reproducibility

    Returns:
        List of result dictionaries, one per evaluated sample.  The last entry
        additionally carries ``queue_survival_stats`` and
        ``queue_survival_details``.  Empty if there are no examples.
    """
    import traceback  # only used on the error paths below

    # Set seeds for reproducibility within experiment
    random.seed(shuffle_seed)
    np.random.seed(shuffle_seed)

    example_dataset = dataloader.get_examples()
    if len(example_dataset) == 0:
        log(f"[ERROR] No examples found for user {user_id}")
        return []

    # Group example indices by label for Queue initialization
    class_indices = {}
    for idx, example in enumerate(example_dataset):
        class_indices.setdefault(example["label"], []).append(idx)

    # Initialize Queue
    queue_size = config.get("queue_size", 5)
    ex_queue = Queue(class_indices, queue_size)

    # Tracking variables
    results = []
    cumulative_correct = 0
    window_size = config.get("tracking_window", 20)
    recent_results = []  # sliding window of 0/1 outcomes, newest last

    debug_log.log_policy_experiment_header(
        "CONFIDENCE-BASED",
        user_id,
        shuffle_seed,
        len(dataloader),
        queue_size,
    )

    for processed_count, sample in enumerate(dataloader):
        ex_queue.set_current_time(processed_count)

        # Log queue state before processing
        debug_log.log_policy_queue_state("CONFIDENCE", processed_count, user_id, ex_queue)

        # Create agent pool: one agent per queue entry
        agent_pool = AgentPool(log_path=config["log_path"])

        try:
            for queue_idx, ex_idcs in enumerate(ex_queue):
                examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]

                agent = SCAgent(
                    name="EEG sensing",
                    index=queue_idx,
                    model_pool=model_pool,
                    task_info=dataloader.get_task_info(),
                    classes_info=dataloader.get_classes_info(),
                    sensor_info=dataloader.get_sensor_info(),
                    sample=sample,
                    examples=examples,
                    log_path=config["log_path"],
                )
                agent_pool.add_agent(agent)
        except Exception as e:
            # Best effort: a sample whose agents cannot be built is skipped.
            log(f"[ERROR] Failed to create agents: {e}")
            traceback.print_exc()
            continue

        if len(agent_pool.agents) == 0:
            log(f"[WARN] No agents created for sample {processed_count}")
            continue

        # Run parallel interpretation
        try:
            interpretation_result = await agent_pool.run_parallel_interpretation()
        except Exception as e:
            log(f"[ERROR] Interpretation failed: {e}")
            traceback.print_exc()
            continue

        if interpretation_result is None:
            log(f"[WARN] Interpretation failed for sample {processed_count}")
            continue

        answer, _queue_idcs, avg_confidence, consistency, responses = interpretation_result

        # Evaluate result
        ground_truth = sample["label"]
        is_correct = answer == ground_truth

        cumulative_correct += 1 if is_correct else 0
        # Divide by the number of samples actually evaluated (this one
        # included).  Using processed_count + 1 would also count skipped
        # samples and disagree with compute_statistics, which uses
        # len(results).
        evaluated_count = len(results) + 1
        cumulative_accuracy = cumulative_correct / evaluated_count

        recent_results.append(1 if is_correct else 0)
        if len(recent_results) > window_size:
            recent_results.pop(0)
        window_accuracy = sum(recent_results) / len(recent_results)

        # Performance logging
        debug_log.log_policy_result(
            "CONFIDENCE",
            processed_count,
            answer,
            ground_truth,
            is_correct,
            avg_confidence,
            consistency,
            cumulative_accuracy,
            cumulative_correct,
            window_accuracy,
            recent_results,
        )

        # CONFIDENCE-BASED Queue Update
        if responses:
            # Map each response key to its reported confidence (0.0 if absent)
            confidence_map = {
                idx: response.get("CONFIDENCE", 0.0)
                for idx, response in responses.items()
            }

            debug_log.log_confidence_map(responses, confidence_map)

            # Update queue by confidence (evict lowest, add new random)
            ex_queue.update_by_confidence(confidence_map)
            ex_queue.increment_usage(list(responses.keys()))

            debug_log.log_policy_queue_state_after("Confidence", processed_count, ex_queue)

        # Store result
        results.append({
            "sample_idx": processed_count,
            "original_idx": sample.get("original_idx", processed_count),
            "shuffled_idx": sample.get("shuffled_idx", processed_count),
            "answer": answer,
            "ground_truth": ground_truth,
            "is_correct": is_correct,
            "confidence": avg_confidence,
            "consistency": consistency,
            "cumulative_accuracy": cumulative_accuracy,
            "window_accuracy": window_accuracy,
            "experiment_type": "confidence",
            "user_id": user_id,
            "shuffle_seed": shuffle_seed,
        })

    # Final statistics
    survival_summary = ex_queue.get_survival_summary()
    debug_log.log_final_policy_survival("CONFIDENCE", user_id, shuffle_seed, survival_summary)

    # Queue survival information is attached to the last result entry.
    if results:
        results[-1]["queue_survival_stats"] = survival_summary
        results[-1]["queue_survival_details"] = ex_queue.get_survival_stats()

    return results
|
||||
|
||||
|
||||
def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict[str, Any]:
    """Compute comprehensive experiment statistics.

    Args:
        results: Per-sample result dictionaries.  Reads the keys
            ``is_correct``, ``confidence``, ``consistency``, ``ground_truth``,
            ``answer``, ``user_id``, ``shuffle_seed`` and the optional
            ``queue_survival_stats``; missing keys fall back to defaults.
        stages: Class labels for which per-stage accuracy is reported.  A
            stage with no samples is reported as ``0.0``.

    Returns:
        A dictionary of summary statistics, or ``{}`` when ``results`` is
        empty.
    """
    if not results:
        return {}

    total = len(results)

    # Overall metrics
    outcomes = [bool(r.get("is_correct", False)) for r in results]
    correct = sum(outcomes)
    accuracy = correct / total

    # Confidence and consistency averages
    avg_confidence = np.mean([r.get("confidence", 0) for r in results])
    avg_consistency = np.mean([r.get("consistency", 0) for r in results])

    # Per-stage accuracy, keyed by ground-truth label
    stage_correct = {}
    stage_total = {}
    for r, ok in zip(results, outcomes):
        gt = r.get("ground_truth", "UNKNOWN")
        stage_total[gt] = stage_total.get(gt, 0) + 1
        if ok:
            stage_correct[gt] = stage_correct.get(gt, 0) + 1

    stage_accuracy = {
        stage: (stage_correct.get(stage, 0) / stage_total[stage]) if stage in stage_total else 0.0
        for stage in stages
    }

    # F1 Score and Macro metrics
    predictions = [r.get("answer", "") for r in results]
    ground_truths = [r.get("ground_truth", "") for r in results]

    try:
        macro_f1 = f1_score(ground_truths, predictions, average='macro', zero_division=0)
        macro_precision = precision_score(ground_truths, predictions, average='macro', zero_division=0)
        macro_recall = recall_score(ground_truths, predictions, average='macro', zero_division=0)
    except Exception:
        # NOTE(review): deliberate best effort - any failure in the metric
        # computation silently zeroes all three macro metrics.
        macro_f1 = macro_precision = macro_recall = 0.0

    # Temporal analysis: first half vs. second half of the run
    mid_point = total // 2
    if mid_point > 0:
        first_half_acc = sum(outcomes[:mid_point]) / mid_point
        second_half_acc = sum(outcomes[mid_point:]) / (total - mid_point)
        improvement = second_half_acc - first_half_acc
    else:
        first_half_acc = second_half_acc = accuracy
        improvement = 0

    # Running count of correct answers; prefix_correct[i] covers results[0..i].
    # One pass here replaces re-counting every prefix for the learning curve.
    prefix_correct = []
    running_correct = 0
    for ok in outcomes:
        running_correct += 1 if ok else 0
        prefix_correct.append(running_correct)

    # Learning curve (every 10 samples, plus the final partial chunk)
    learning_curve = []
    for start in range(0, total, 10):
        end = min(start + 10, total)
        learning_curve.append({
            "sample_idx": end,
            "cumulative_accuracy": prefix_correct[end - 1] / end,
        })

    # Convergence speed: first index whose running accuracy reaches 90% of
    # the final accuracy
    convergence_threshold = 0.9 * accuracy
    convergence_idx = None
    for i, seen_correct in enumerate(prefix_correct):
        if seen_correct / (i + 1) >= convergence_threshold:
            convergence_idx = i
            break

    # Queue statistics: taken from the first result that carries them
    queue_stats = {}
    for r in results:
        if "queue_survival_stats" in r:
            queue_stats = r["queue_survival_stats"]
            break

    return {
        "experiment_type": "confidence",
        "user_id": results[0].get("user_id"),
        "shuffle_seed": results[0].get("shuffle_seed"),
        "total_samples": total,
        "correct": correct,
        "accuracy": accuracy,
        "avg_confidence": float(avg_confidence),
        "avg_consistency": float(avg_consistency),
        "stage_accuracy": stage_accuracy,
        "stage_sample_counts": stage_total,
        "macro_f1": float(macro_f1),
        "macro_precision": float(macro_precision),
        "macro_recall": float(macro_recall),
        "temporal_analysis": {
            "first_half_accuracy": first_half_acc,
            "second_half_accuracy": second_half_acc,
            "improvement": improvement,
        },
        "learning_curve": learning_curve,
        "convergence_idx": convergence_idx,
        "queue_stats": queue_stats,
    }
|
||||
|
||||
|
||||
def save_results(
    results: List[Dict[str, Any]],
    stats: Dict[str, Any],
    config: Dict[str, Any],
    user_id: int,
    shuffle_seed: int,
) -> None:
    """Save experiment results and statistics.

    Writes ``statistics.json``, ``results.json`` and ``config.yaml`` (in that
    order) into ``<config["log_path"]>/user<XX>_seed<seed>/``, creating the
    directory if needed, and logs each path after it is written.
    """
    # Create user/seed specific directory
    output_dir = os.path.join(config["log_path"], f"user{user_id:02d}_seed{shuffle_seed}")
    os.makedirs(output_dir, exist_ok=True)

    def _write_json(path, payload):
        """Dump ``payload`` to ``path`` as indented UTF-8 JSON."""
        with open(path, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2, ensure_ascii=False)

    # Save statistics
    stats_path = os.path.join(output_dir, "statistics.json")
    _write_json(stats_path, stats)
    log(f"[SAVE] Statistics: {stats_path}")

    # Save results
    results_path = os.path.join(output_dir, "results.json")
    _write_json(results_path, results)
    log(f"[SAVE] Results: {results_path}")

    # Save config
    config_path = os.path.join(output_dir, "config.yaml")
    with open(config_path, "w", encoding="utf-8") as handle:
        yaml.dump(config, handle, default_flow_style=False, allow_unicode=True)
    log(f"[SAVE] Config: {config_path}")
|
||||
|
||||
|
||||
def main(
    config_path: str = "sc/config/experiment_confidence.yaml",
    user_id: int = 5,
    shuffle_seed: int = 42,
) -> None:
    """
    Run Confidence-based Queue Policy experiment.

    Loads the YAML config, creates a timestamped log directory, builds the
    data loader and model pool, runs the experiment, then computes, logs and
    saves the statistics.

    Args:
        config_path: Path to YAML configuration file
        user_id: User ID to process (5 or 10)
        shuffle_seed: Shuffle seed for data order (42, 123, or 456)

    Example:
        python -m sc.run_confidence --user_id=5 --shuffle_seed=42
        python -m sc.run_confidence --user_id=10 --shuffle_seed=123
    """
    log(f"[MAIN] Loading config: {config_path}")
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = yaml.load(config_file, Loader=yaml.SafeLoader)

    # CLI arguments take precedence over the values in the YAML file
    config["user_id"] = user_id
    config["shuffle_seed"] = shuffle_seed

    # Every run gets its own timestamped log directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    config["log_path"] = f"{config['log_path']}/user{user_id:02d}_seed{shuffle_seed}_{timestamp}"
    run_log_dir = config["log_path"]
    os.makedirs(run_log_dir, exist_ok=True)

    # Print experiment info
    debug_log.log_policy_main_header(
        "CONFIDENCE-BASED",
        user_id,
        shuffle_seed,
        config.get("queue_size", 5),
        len(config.get("models", [])),
        run_log_dir,
    )

    # Load shuffled data
    debug_log.log_policy_loading_data(user_id, shuffle_seed)
    dataloader = ShuffledDataLoader(
        data_path=config["data_path"],
        user_id=user_id,
        seed=shuffle_seed,
        example_pool=config.get("example_pool", "out"),
    )

    # Load models
    debug_log.log_policy_loading_models()
    model_pool = load_models(
        config["models"],
        temperature=config.get("temperature", 0.0),
        num_ctx=config.get("num_ctx", 15000),
    )

    # Run experiment (the coroutine is driven to completion synchronously)
    debug_log.log_policy_start(label="experiment")
    experiment = run_confidence_experiment(
        dataloader=dataloader,
        model_pool=model_pool,
        config=config,
        user_id=user_id,
        shuffle_seed=shuffle_seed,
    )
    results = asyncio.run(experiment)

    # Compute statistics
    stats = compute_statistics(
        results,
        config.get("stages", ["W", "N1", "N2", "N3", "REM"]),
    )

    # Print final summary; the .get defaults cover an empty stats dict
    debug_log.log_policy_complete_summary(
        "CONFIDENCE",
        user_id,
        shuffle_seed,
        stats,
        stats.get("stage_accuracy", {}),
        stats.get("stage_sample_counts", {}),
        stats.get("temporal_analysis", {}),
    )

    # Save results
    save_results(results, stats, config, user_id, shuffle_seed)
    log(f"\n[MAIN] Results saved to: {config['log_path']}")
|
||||
|
||||
|
||||
# Script entry point: hand `main` to python-fire so its parameters become CLI flags.
if __name__ == "__main__":
    Fire(main)
|
||||
@@ -1,473 +0,0 @@
|
||||
"""
|
||||
Consistency-based Queue Policy Experiment Runner
|
||||
|
||||
This module implements the CONSISTENCY-based queue update policy:
|
||||
- Queue is updated based on SC consensus/agreement scores
|
||||
- Higher consistency (more agents agree) examples are retained
|
||||
- Lower consistency examples are evicted
|
||||
|
||||
Consistency Score = (Number of agents agreeing with majority) / (Total agents)
|
||||
|
||||
Usage:
|
||||
python -m sc.run_consistency --user_id=5 --shuffle_seed=42
|
||||
python -m sc.run_consistency --user_id=10 --shuffle_seed=123
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import yaml
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sklearn.metrics import f1_score, precision_score, recall_score
|
||||
|
||||
from fire import Fire
|
||||
|
||||
# Add project root to path for relative imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
|
||||
from sc.core.scagent import SCAgent
|
||||
from sc.core.model import load_models
|
||||
from sc.core.queue import Queue
|
||||
from sc.core.agent_pool import AgentPool
|
||||
from sc import debug_log
|
||||
|
||||
log = debug_log.log
|
||||
|
||||
|
||||
async def run_consistency_experiment(
    dataloader: ShuffledDataLoader,
    model_pool,
    config: Dict[str, Any],
    user_id: int,
    shuffle_seed: int,
) -> List[Dict[str, Any]]:
    """
    Run consistency-based queue policy experiment.

    For every test sample, one ``SCAgent`` is built per queue entry (using
    that entry's example indices), the agents are run through
    ``AgentPool.run_parallel_interpretation`` and the returned answer is
    compared with the sample's ``"label"``.  The queue is then updated from a
    per-agent consistency score.

    Queue Update Policy:
    - After each inference, calculate per-agent consistency
    - Consistency = fraction of agents (the agent itself included) whose
      ``"ANSWER"`` equals this agent's answer
    - The scores are handed to ``Queue.update_by_confidence``; per the
      original author's description it evicts the lowest-scoring element and
      adds a new random ICL example set (that logic is not part of this
      function)

    Samples for which agent creation or interpretation fails are logged and
    skipped; they contribute no entry to the returned list and are not counted
    in ``cumulative_accuracy``.

    Args:
        dataloader: ShuffledDataLoader with pre-shuffled test data
        model_pool: Async model pool for LLM inference
        config: Experiment configuration (reads ``log_path`` and the optional
            ``queue_size`` and ``tracking_window`` keys)
        user_id: User ID being processed
        shuffle_seed: Shuffle seed for reproducibility

    Returns:
        List of result dictionaries, one per evaluated sample.  The last entry
        additionally carries ``queue_survival_stats`` and
        ``queue_survival_details``.  Empty if there are no examples.
    """
    import traceback  # only used on the error paths below

    # Set seeds for reproducibility within experiment
    random.seed(shuffle_seed)
    np.random.seed(shuffle_seed)

    example_dataset = dataloader.get_examples()
    if len(example_dataset) == 0:
        log(f"[ERROR] No examples found for user {user_id}")
        return []

    # Group example indices by label for Queue initialization
    class_indices = {}
    for idx, example in enumerate(example_dataset):
        class_indices.setdefault(example["label"], []).append(idx)

    # Initialize Queue
    queue_size = config.get("queue_size", 5)
    ex_queue = Queue(class_indices, queue_size)

    # Tracking variables
    results = []
    cumulative_correct = 0
    window_size = config.get("tracking_window", 20)
    recent_results = []  # sliding window of 0/1 outcomes, newest last

    debug_log.log_policy_experiment_header(
        "CONSISTENCY-BASED",
        user_id,
        shuffle_seed,
        len(dataloader),
        queue_size,
    )

    for processed_count, sample in enumerate(dataloader):
        ex_queue.set_current_time(processed_count)

        # Log queue state before processing
        debug_log.log_policy_queue_state("CONSISTENCY", processed_count, user_id, ex_queue)

        # Create agent pool: one agent per queue entry
        agent_pool = AgentPool(log_path=config["log_path"])

        try:
            for queue_idx, ex_idcs in enumerate(ex_queue):
                examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]

                agent = SCAgent(
                    name="EEG sensing",
                    index=queue_idx,
                    model_pool=model_pool,
                    task_info=dataloader.get_task_info(),
                    classes_info=dataloader.get_classes_info(),
                    sensor_info=dataloader.get_sensor_info(),
                    sample=sample,
                    examples=examples,
                    log_path=config["log_path"],
                )
                agent_pool.add_agent(agent)
        except Exception as e:
            # Best effort: a sample whose agents cannot be built is skipped.
            log(f"[ERROR] Failed to create agents: {e}")
            traceback.print_exc()
            continue

        if len(agent_pool.agents) == 0:
            log(f"[WARN] No agents created for sample {processed_count}")
            continue

        # Run parallel interpretation
        try:
            interpretation_result = await agent_pool.run_parallel_interpretation()
        except Exception as e:
            log(f"[ERROR] Interpretation failed: {e}")
            traceback.print_exc()
            continue

        if interpretation_result is None:
            log(f"[WARN] Interpretation failed for sample {processed_count}")
            continue

        answer, _queue_idcs, avg_confidence, consistency, responses = interpretation_result

        # Evaluate result
        ground_truth = sample["label"]
        is_correct = answer == ground_truth

        cumulative_correct += 1 if is_correct else 0
        # Divide by the number of samples actually evaluated (this one
        # included).  Using processed_count + 1 would also count skipped
        # samples and disagree with compute_statistics, which uses
        # len(results).
        evaluated_count = len(results) + 1
        cumulative_accuracy = cumulative_correct / evaluated_count

        recent_results.append(1 if is_correct else 0)
        if len(recent_results) > window_size:
            recent_results.pop(0)
        window_accuracy = sum(recent_results) / len(recent_results)

        # Performance logging
        debug_log.log_policy_result(
            "CONSISTENCY",
            processed_count,
            answer,
            ground_truth,
            is_correct,
            avg_confidence,
            consistency,
            cumulative_accuracy,
            cumulative_correct,
            window_accuracy,
            recent_results,
        )

        # CONSISTENCY-BASED Queue Update
        if responses:
            # Per-agent consistency = share of agents (including self) that
            # gave the same answer.  `responses` is non-empty here, so the
            # denominator is never zero.
            all_answers = [r.get("ANSWER") for r in responses.values()]
            consistency_map = {
                idx: all_answers.count(response.get("ANSWER")) / len(all_answers)
                for idx, response in responses.items()
            }

            debug_log.log_consistency_map(responses, consistency_map)

            # Update queue by consistency (reuses confidence method with consistency scores)
            ex_queue.update_by_confidence(consistency_map)
            ex_queue.increment_usage(list(responses.keys()))

            debug_log.log_policy_queue_state_after("Consistency", processed_count, ex_queue)

        # Store result
        results.append({
            "sample_idx": processed_count,
            "original_idx": sample.get("original_idx", processed_count),
            "shuffled_idx": sample.get("shuffled_idx", processed_count),
            "answer": answer,
            "ground_truth": ground_truth,
            "is_correct": is_correct,
            "confidence": avg_confidence,
            "consistency": consistency,
            "cumulative_accuracy": cumulative_accuracy,
            "window_accuracy": window_accuracy,
            "experiment_type": "consistency",
            "user_id": user_id,
            "shuffle_seed": shuffle_seed,
        })

    # Final statistics
    survival_summary = ex_queue.get_survival_summary()
    debug_log.log_final_policy_survival("CONSISTENCY", user_id, shuffle_seed, survival_summary)

    # Queue survival information is attached to the last result entry.
    if results:
        results[-1]["queue_survival_stats"] = survival_summary
        results[-1]["queue_survival_details"] = ex_queue.get_survival_stats()

    return results
|
||||
|
||||
|
||||
def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict[str, Any]:
    """Compute comprehensive experiment statistics.

    Args:
        results: Per-sample result dictionaries.  Reads the keys
            ``is_correct``, ``confidence``, ``consistency``, ``ground_truth``,
            ``answer``, ``user_id``, ``shuffle_seed`` and the optional
            ``queue_survival_stats``; missing keys fall back to defaults.
        stages: Class labels for which per-stage accuracy is reported.  A
            stage with no samples is reported as ``0.0``.

    Returns:
        A dictionary of summary statistics, or ``{}`` when ``results`` is
        empty.
    """
    if not results:
        return {}

    total = len(results)

    # Overall metrics
    outcomes = [bool(r.get("is_correct", False)) for r in results]
    correct = sum(outcomes)
    accuracy = correct / total

    # Confidence and consistency averages
    avg_confidence = np.mean([r.get("confidence", 0) for r in results])
    avg_consistency = np.mean([r.get("consistency", 0) for r in results])

    # Per-stage accuracy, keyed by ground-truth label
    stage_correct = {}
    stage_total = {}
    for r, ok in zip(results, outcomes):
        gt = r.get("ground_truth", "UNKNOWN")
        stage_total[gt] = stage_total.get(gt, 0) + 1
        if ok:
            stage_correct[gt] = stage_correct.get(gt, 0) + 1

    stage_accuracy = {
        stage: (stage_correct.get(stage, 0) / stage_total[stage]) if stage in stage_total else 0.0
        for stage in stages
    }

    # F1 Score and Macro metrics
    predictions = [r.get("answer", "") for r in results]
    ground_truths = [r.get("ground_truth", "") for r in results]

    try:
        macro_f1 = f1_score(ground_truths, predictions, average='macro', zero_division=0)
        macro_precision = precision_score(ground_truths, predictions, average='macro', zero_division=0)
        macro_recall = recall_score(ground_truths, predictions, average='macro', zero_division=0)
    except Exception:
        # NOTE(review): deliberate best effort - any failure in the metric
        # computation silently zeroes all three macro metrics.
        macro_f1 = macro_precision = macro_recall = 0.0

    # Temporal analysis: first half vs. second half of the run
    mid_point = total // 2
    if mid_point > 0:
        first_half_acc = sum(outcomes[:mid_point]) / mid_point
        second_half_acc = sum(outcomes[mid_point:]) / (total - mid_point)
        improvement = second_half_acc - first_half_acc
    else:
        first_half_acc = second_half_acc = accuracy
        improvement = 0

    # Running count of correct answers; prefix_correct[i] covers results[0..i].
    # One pass here replaces re-counting every prefix for the learning curve.
    prefix_correct = []
    running_correct = 0
    for ok in outcomes:
        running_correct += 1 if ok else 0
        prefix_correct.append(running_correct)

    # Learning curve (every 10 samples, plus the final partial chunk)
    learning_curve = []
    for start in range(0, total, 10):
        end = min(start + 10, total)
        learning_curve.append({
            "sample_idx": end,
            "cumulative_accuracy": prefix_correct[end - 1] / end,
        })

    # Convergence speed: first index whose running accuracy reaches 90% of
    # the final accuracy
    convergence_threshold = 0.9 * accuracy
    convergence_idx = None
    for i, seen_correct in enumerate(prefix_correct):
        if seen_correct / (i + 1) >= convergence_threshold:
            convergence_idx = i
            break

    # Queue statistics: taken from the first result that carries them
    queue_stats = {}
    for r in results:
        if "queue_survival_stats" in r:
            queue_stats = r["queue_survival_stats"]
            break

    return {
        "experiment_type": "consistency",
        "user_id": results[0].get("user_id"),
        "shuffle_seed": results[0].get("shuffle_seed"),
        "total_samples": total,
        "correct": correct,
        "accuracy": accuracy,
        "avg_confidence": float(avg_confidence),
        "avg_consistency": float(avg_consistency),
        "stage_accuracy": stage_accuracy,
        "stage_sample_counts": stage_total,
        "macro_f1": float(macro_f1),
        "macro_precision": float(macro_precision),
        "macro_recall": float(macro_recall),
        "temporal_analysis": {
            "first_half_accuracy": first_half_acc,
            "second_half_accuracy": second_half_acc,
            "improvement": improvement,
        },
        "learning_curve": learning_curve,
        "convergence_idx": convergence_idx,
        "queue_stats": queue_stats,
    }
|
||||
|
||||
|
||||
def save_results(
    results: List[Dict[str, Any]],
    stats: Dict[str, Any],
    config: Dict[str, Any],
    user_id: int,
    shuffle_seed: int,
) -> None:
    """Save experiment results and statistics.

    Writes ``statistics.json``, ``results.json`` and ``config.yaml`` (in that
    order) into ``<config["log_path"]>/user<XX>_seed<seed>/``, creating the
    directory if needed, and logs each path after it is written.
    """
    # Create user/seed specific directory
    output_dir = os.path.join(config["log_path"], f"user{user_id:02d}_seed{shuffle_seed}")
    os.makedirs(output_dir, exist_ok=True)

    def _write_json(path, payload):
        """Dump ``payload`` to ``path`` as indented UTF-8 JSON."""
        with open(path, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2, ensure_ascii=False)

    # Save statistics
    stats_path = os.path.join(output_dir, "statistics.json")
    _write_json(stats_path, stats)
    log(f"[SAVE] Statistics: {stats_path}")

    # Save results
    results_path = os.path.join(output_dir, "results.json")
    _write_json(results_path, results)
    log(f"[SAVE] Results: {results_path}")

    # Save config
    config_path = os.path.join(output_dir, "config.yaml")
    with open(config_path, "w", encoding="utf-8") as handle:
        yaml.dump(config, handle, default_flow_style=False, allow_unicode=True)
    log(f"[SAVE] Config: {config_path}")
|
||||
|
||||
|
||||
def main(
    config_path: str = "sc/config/experiment_consistency.yaml",
    user_id: int = 5,
    shuffle_seed: int = 42,
) -> None:
    """
    Entry point of the Consistency-based Queue Policy experiment.

    Loads the YAML configuration, overrides user/seed with the CLI values,
    runs the experiment, prints a summary and saves all outputs.

    Args:
        config_path: Path to YAML configuration file
        user_id: User ID to process (5 or 10)
        shuffle_seed: Shuffle seed for data order (42, 123, or 456)

    Example:
        python -m sc.run_consistency --user_id=5 --shuffle_seed=42
        python -m sc.run_consistency --user_id=10 --shuffle_seed=123
    """
    log(f"[MAIN] Loading config: {config_path}")
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)

    # CLI arguments take precedence over the values stored in the file
    config.update(user_id=user_id, shuffle_seed=shuffle_seed)

    # Every invocation writes into its own timestamped directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = f"{config['log_path']}/user{user_id:02d}_seed{shuffle_seed}_{timestamp}"
    config["log_path"] = run_dir
    os.makedirs(run_dir, exist_ok=True)

    # Experiment banner
    n_models = len(config.get("models", []))
    debug_log.log_policy_main_header(
        "CONSISTENCY-BASED",
        user_id,
        shuffle_seed,
        config.get("queue_size", 5),
        n_models,
        run_dir,
    )

    # Shuffled data for the requested user/seed
    debug_log.log_policy_loading_data(user_id, shuffle_seed)
    loader_kwargs = {
        "data_path": config["data_path"],
        "user_id": user_id,
        "seed": shuffle_seed,
        "example_pool": config.get("example_pool", "out"),
    }
    dataloader = ShuffledDataLoader(**loader_kwargs)

    # Model pool
    debug_log.log_policy_loading_models()
    model_pool = load_models(
        config["models"],
        temperature=config.get("temperature", 0.0),
        num_ctx=config.get("num_ctx", 15000),
    )

    # Drive the async experiment to completion
    debug_log.log_policy_start(label="experiment")
    experiment = run_consistency_experiment(
        dataloader=dataloader,
        model_pool=model_pool,
        config=config,
        user_id=user_id,
        shuffle_seed=shuffle_seed,
    )
    results = asyncio.run(experiment)

    # Aggregate metrics over the configured stage labels
    stage_names = config.get("stages", ["W", "N1", "N2", "N3", "REM"])
    stats = compute_statistics(results, stage_names)

    # Final summary
    debug_log.log_policy_complete_summary(
        "CONSISTENCY",
        user_id,
        shuffle_seed,
        stats,
        stats.get("stage_accuracy", {}),
        stats.get("stage_sample_counts", {}),
        stats.get("temporal_analysis", {}),
    )

    # Persist everything
    save_results(results, stats, config, user_id, shuffle_seed)
    log(f"\n[MAIN] Results saved to: {config['log_path']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(main)
|
||||
467
sc/run_sc.py
467
sc/run_sc.py
@@ -1,467 +0,0 @@
|
||||
"""
|
||||
Self-Consistency Experiment Runner for Sleep Stage Classification
|
||||
|
||||
This module implements Self-Consistency methodology for sleep stage classification:
|
||||
- Sample N times with the same prompt
|
||||
- Use majority voting for final answer
|
||||
- Support confidence-based tie-breaking
|
||||
|
||||
Usage:
|
||||
python -m sc.run_sc sc/config/sleepedf_sc.yaml
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import yaml
|
||||
import json
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
import time
|
||||
|
||||
from glob import glob
|
||||
from fire import Fire
|
||||
|
||||
# Add project root to path for relative imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sc.core.data_loader import DataLoader
|
||||
from sc.core.scagent import SCAgent
|
||||
from sc.core.model import load_models
|
||||
from sc.core.queue import Queue
|
||||
from sc.core.agent_pool import AgentPool
|
||||
from sc import debug_log
|
||||
|
||||
|
||||
async def run_single_task(
    seed: int,
    dataloader: DataLoader,
    model_pool,
    config: Dict[str, Any],
    user_id: str = None,
) -> List[Dict[str, Any]]:
    """
    Run the queue-driven classification loop for one user/seed pair.

    For every sample taken from ``dataloader`` one SCAgent is created per
    example set currently held in the queue, the agents are run through
    ``AgentPool.run_parallel_interpretation`` and the queue is then updated
    from the per-agent ``CONFIDENCE`` values.

    Args:
        seed: Value passed to ``np.random.seed`` before anything else runs
        dataloader: Source of test samples, example pool and task metadata
        model_pool: Model pool handed to every SCAgent
        config: Experiment configuration; reads ``queue_size``, ``log_path``
            and the optional ``sample_rate`` / ``tracking_window`` keys
        user_id: Optional user identifier, used for log labels only

    Returns:
        One result dictionary per successfully interpreted sample. The last
        dictionary additionally carries ``queue_survival_stats`` and
        ``queue_survival_details``. An empty list is returned when the
        example pool is empty.
    """
    np.random.seed(seed)

    example_dataset = dataloader.get_examples()
    if len(example_dataset) == 0:
        debug_log.warn_no_examples()
        return []

    # Build class_indices: Dict[str, List[int]] for Queue initialization
    class_indices = {}
    for idx, example in enumerate(example_dataset):
        class_indices.setdefault(example["label"], []).append(idx)

    ex_queue = Queue(class_indices, config["queue_size"])
    sample_rate = config.get("sample_rate", 1)
    window_size = config.get("tracking_window", 10)  # Recent window size for accuracy tracking

    # Bound before the loop: the final queue summary below reads it even when
    # the dataloader yields no sample at all (previously an UnboundLocalError).
    user_info = f"User: {user_id}" if user_id else "Unknown User"

    results = []
    cumulative_correct = 0
    recent_results = []  # Windowed recent results for accuracy calculation
    confidence_history = []  # Confidence history for tracking
    processed_count = 0  # Counts successfully interpreted samples only

    for i, sample in enumerate(dataloader):
        # Sub-sample the stream: keep every sample_rate-th item
        if sample_rate > 1 and i % sample_rate != 0:
            continue

        ex_queue.set_current_time(processed_count)  # Track current sample index for survival time
        debug_log.log_queue_state_before(processed_count, user_info, ex_queue, config)

        # One agent per example set held in the queue
        agent_pool = AgentPool(log_path=config["log_path"])
        for queue_idx, ex_idcs in enumerate(ex_queue):
            examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]
            task_info = dataloader.get_task_info()
            sensor_info = dataloader.get_sensor_info()
            classes_info = dataloader.get_classes_info()
            agent = SCAgent(
                name="EEG sensing",
                index=queue_idx,
                model_pool=model_pool,
                task_info=task_info,
                classes_info=classes_info,
                sensor_info=sensor_info,
                sample=sample,
                examples=examples,
                log_path=config["log_path"],
            )
            agent_pool.add_agent(agent)

        # Check if any agents were added
        if len(agent_pool.agents) == 0:
            debug_log.warn_no_agents(processed_count)
            continue

        interpretation_result = await agent_pool.run_parallel_interpretation()

        # Handle case where interpretation failed
        if interpretation_result is None:
            debug_log.warn_interpretation_failed(processed_count)
            continue

        answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result

        ground_truth = sample["label"]
        is_correct = answer == ground_truth

        cumulative_correct += 1 if is_correct else 0
        cumulative_accuracy = cumulative_correct / (processed_count + 1)

        # Sliding window over the most recent outcomes
        recent_results.append(1 if is_correct else 0)
        if len(recent_results) > window_size:
            recent_results.pop(0)
        window_accuracy = sum(recent_results) / len(recent_results)

        # Confidence history
        confidence_history.append(avg_confidence)
        avg_confidence_so_far = sum(confidence_history) / len(confidence_history)

        debug_log.log_tracking(
            processed_count,
            user_info,
            answer,
            ground_truth,
            is_correct,
            avg_confidence,
            consistency,
            cumulative_accuracy,
            cumulative_correct,
            window_accuracy,
            recent_results,
            avg_confidence_so_far,
            config,
        )

        # Update queue based on Confidence (Priority Queue)
        if responses:
            # Create confidence map {queue_idx: confidence}
            confidence_map = {idx: r.get("CONFIDENCE", 0) for idx, r in responses.items()}
            ex_queue.update_by_confidence(confidence_map)
            ex_queue.increment_usage(list(responses.keys()))

            debug_log.log_queue_state_after(processed_count, ex_queue, config)
        elif queue_idcs:
            debug_log.warn_no_responses(processed_count)

        result = {
            "sample_idx": processed_count,
            "answer": answer,
            "ground_truth": ground_truth,
            "is_correct": is_correct,
            "confidence": avg_confidence,
            "consistency": consistency,
            "cumulative_accuracy": cumulative_accuracy,
            "window_accuracy": window_accuracy,
            "avg_confidence_so_far": avg_confidence_so_far,
        }
        results.append(result)
        processed_count += 1

    # Final Queue survival statistics
    survival_summary = ex_queue.get_survival_summary()
    debug_log.log_final_queue_stats(user_info, survival_summary, config)

    # Add Queue statistics to results (in the last result)
    if results:
        results[-1]["queue_survival_stats"] = survival_summary
        results[-1]["queue_survival_details"] = ex_queue.get_survival_stats()

    return results
|
||||
|
||||
|
||||
async def run_parallel(
    config: Dict[str, Any],
    model_pool,
) -> List[Any]:
    """
    Schedule one ``run_single_task`` per selected user and seed and await them all.

    Args:
        config: Experiment configuration. Reads ``data_path`` plus the optional
            keys ``num_seeds`` (default 1), ``example_pool`` (default "out"),
            ``continuous`` (default True) and ``max_users`` (default 1; set it
            to ``None`` to process every user directory).
        model_pool: Model pool forwarded to every task

    Returns:
        One entry per scheduled task, in scheduling order. Each entry is either
        the task's list of result dictionaries or the exception it raised
        (``asyncio.gather`` is called with ``return_exceptions=True``).
        An empty list is returned when ``data_path`` has no sub-directories.
    """
    data_path = config["data_path"]
    user_paths = glob(os.path.join(data_path, "*"))
    # Keep directories only (files such as info.json are not users). The names
    # are sorted because glob gives no ordering guarantee, which previously made
    # the "first user" selected below depend on the filesystem.
    users = sorted(
        os.path.basename(p)
        for p in user_paths
        if os.path.isdir(p) and os.path.basename(p) != "info.json"
    )
    if not users:
        debug_log.warn_no_user_dirs(data_path)
        return []
    debug_log.log_found_users(users, config)
    seeds = range(config.get("num_seeds", 1))

    # User selection: defaults to a single user, as the original testing limit did
    max_users = config.get("max_users", 1)
    selected_users = users if max_users is None else users[:max_users]

    tasks = []
    for user in selected_users:
        for seed in seeds:
            np.random.seed(seed)
            dataloader = DataLoader(
                data_path=data_path,
                user_id=user,
                example_pool=config.get("example_pool", "out"),
                continuous=config.get("continuous", True),
            )
            # Check if dataloader was properly initialized
            if not hasattr(dataloader, 'test_dataset') or len(dataloader) == 0:
                debug_log.warn_skip_user_no_test_data(user)
                continue
            if len(dataloader.get_examples()) == 0:
                debug_log.warn_skip_user_no_example_data(user)
                continue

            task = asyncio.create_task(
                run_single_task(
                    seed,
                    dataloader,
                    model_pool,
                    config,
                    user_id=user,
                )
            )
            tasks.append(task)

    # Failures are returned in place instead of cancelling the sibling tasks
    results = await asyncio.gather(*tasks, return_exceptions=True)
    return results
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Statistics and Results
|
||||
# =============================================================================
|
||||
|
||||
def compute_statistics(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Aggregate per-sample results into summary metrics.

    Args:
        results: Result dictionaries; the keys ``is_correct``, ``confidence``,
            ``consistency``, ``ground_truth`` and ``queue_survival_stats`` are
            read, each with a fallback when absent

    Returns:
        Empty dict for empty input. Otherwise a dictionary with overall and
        per-class accuracy, mean confidence/consistency, accuracy on samples
        whose consistency is at least 0.8, a first-half/second-half/quartile
        breakdown, and the first ``queue_survival_stats`` entry found.
    """
    if not results:
        return {}

    def _hit_rate(rows):
        # Share of rows flagged as correct; callers guarantee rows is non-empty
        return sum(1 for row in rows if row.get("is_correct", False)) / len(rows)

    total = len(results)
    correct = sum(1 for row in results if row.get("is_correct", False))
    accuracy = correct / total

    # Mean confidence / consistency over all rows
    confidences = [row.get("confidence", 0) for row in results]
    consistencies = [row.get("consistency", 0) for row in results]
    avg_confidence = np.mean(confidences)
    avg_consistency = np.mean(consistencies)

    # Per-class tallies stored as [hits, seen], keyed by ground-truth label
    tallies = {}
    for row in results:
        tally = tallies.setdefault(row.get("ground_truth", "UNKNOWN"), [0, 0])
        tally[1] += 1
        if row.get("is_correct", False):
            tally[0] += 1
    class_accuracy = {label: hits / seen for label, (hits, seen) in tallies.items()}

    # Accuracy restricted to highly consistent predictions
    consistent_rows = [row for row in results if row.get("consistency", 0) >= 0.8]
    high_consistency_accuracy = _hit_rate(consistent_rows) if consistent_rows else 0

    # First half versus second half
    mid_point = total // 2
    if mid_point > 0:
        first_half_acc = _hit_rate(results[:mid_point])
        second_half_acc = _hit_rate(results[mid_point:])
        improvement = second_half_acc - first_half_acc
    else:
        first_half_acc = second_half_acc = improvement = 0

    # Quartiles: three equally sized chunks, the fourth absorbs the remainder.
    # With fewer than four rows the chunk size is the whole list, so only the
    # first chunk is non-empty.
    chunk = total // 4 if total >= 4 else total
    bounds = [(q * chunk, (q + 1) * chunk) for q in range(3)] + [(3 * chunk, total)]
    quartile_accs = [
        _hit_rate(results[lo:hi]) for lo, hi in bounds if results[lo:hi]
    ]

    # Confidence drift between the halves
    if total > 1:
        first_half_conf = np.mean(confidences[:mid_point])
        second_half_conf = np.mean(confidences[mid_point:])
        conf_improvement = second_half_conf - first_half_conf
    else:
        first_half_conf = second_half_conf = avg_confidence
        conf_improvement = 0

    # First row carrying queue survival statistics wins
    queue_survival_stats = next(
        (row["queue_survival_stats"] for row in results if "queue_survival_stats" in row),
        {},
    )

    temporal_analysis = {
        "first_half_accuracy": first_half_acc,
        "second_half_accuracy": second_half_acc,
        "accuracy_improvement": improvement,
        "quartile_accuracies": quartile_accs,
        "first_half_confidence": float(first_half_conf),
        "second_half_confidence": float(second_half_conf),
        "confidence_improvement": float(conf_improvement),
    }
    return {
        "total_samples": total,
        "correct": correct,
        "accuracy": accuracy,
        "avg_confidence": float(avg_confidence),
        "avg_consistency": float(avg_consistency),
        "class_accuracy": class_accuracy,
        "high_consistency_accuracy": high_consistency_accuracy,
        "high_consistency_samples": len(consistent_rows),
        "temporal_analysis": temporal_analysis,
        "queue_stats": queue_survival_stats,
    }
|
||||
|
||||
|
||||
def save_results(
    results: List[Dict[str, Any]],
    stats: Dict[str, Any],
    config: Dict[str, Any]
) -> None:
    """
    Write statistics, per-sample results and the configuration to disk.

    All files go directly into ``config["log_path"]``, which is created when
    it does not exist yet.

    Args:
        results: Result dictionaries; the ``all_responses`` key is dropped
            from every entry before writing ``all_results.json``
        stats: Statistics dictionary written to ``statistics.json``
        config: Experiment configuration written to ``config.yaml``
    """
    out_dir = config["log_path"]
    os.makedirs(out_dir, exist_ok=True)

    def _write_json(file_name: str, payload) -> str:
        # Dump one payload into the output directory and return its path
        destination = os.path.join(out_dir, file_name)
        with open(destination, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2, ensure_ascii=False)
        return destination

    # Statistics
    stats_path = _write_json("statistics.json", stats)
    debug_log.log_save_statistics(stats_path, config)

    # All results, without the "all_responses" entries
    trimmed = [
        {key: value for key, value in entry.items() if key != "all_responses"}
        for entry in results
    ]
    results_path = _write_json("all_results.json", trimmed)
    debug_log.log_save_results(results_path, config)

    # Configuration snapshot for reproducibility
    config_path = os.path.join(out_dir, "config.yaml")
    with open(config_path, "w", encoding="utf-8") as handle:
        yaml.dump(config, handle, default_flow_style=False, allow_unicode=True)
    debug_log.log_save_config(config_path, config)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CLI Commands
|
||||
# =============================================================================
|
||||
|
||||
def main(config_path: str) -> None:
    """
    Run the Self-Consistency experiment described by a YAML configuration.

    Loads the configuration, appends a timestamp to ``log_path``, loads the
    models, runs all tasks, flattens their results, computes and logs the
    statistics and finally saves everything.

    Args:
        config_path: Path to YAML configuration file

    Example:
        python -m sc.run_sc sc/config/sleepedf_sc.yaml
    """
    debug_log.log_main_loading_config(config_path)
    # Context manager so the config file handle is closed right after parsing
    # (the previous bare open() left it to the garbage collector).
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)

    # Add timestamp to log path for unique experiment runs
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if "log_path" in config:
        config["log_path"] = f"{config['log_path']}_{timestamp}"

    # Print experiment configuration
    debug_log.log_main_config(config, config.get("debug", True))

    model_pool = load_models(
        config["models"],
        temperature=config.get("temperature", 0.0),
        num_ctx=config.get("num_ctx", 15000),
    )

    # Run experiments
    debug_log.log_main_start(config)
    all_results = asyncio.run(run_parallel(config, model_pool))

    # Flatten results: run_parallel returns one entry per user/seed task.
    # Each entry is either a list of results or the exception the task raised.
    flattened_results = []
    for result in all_results:
        if isinstance(result, Exception):
            debug_log.error_task_failed(result)
            continue
        if isinstance(result, list):
            flattened_results.extend(result)
        else:
            flattened_results.append(result)

    debug_log.log_total_results(len(flattened_results), config)

    # Compute and display statistics
    stats = compute_statistics(flattened_results)
    debug_log.log_experiment_results(stats, config)

    # Time-based analysis output
    temporal = stats.get("temporal_analysis", {})
    debug_log.log_temporal_analysis(temporal, config)

    queue_stats = stats.get("queue_stats", {})
    debug_log.log_queue_stats(queue_stats, config)

    # Save results
    save_results(flattened_results, stats, config)
    debug_log.log_results_saved(config["log_path"], config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(main)
|
||||
@@ -1,557 +0,0 @@
|
||||
"""
|
||||
Queue Random Baseline Experiment Runner
|
||||
|
||||
This module implements the QUEUE RANDOM BASELINE policy:
|
||||
- Queue structure is maintained (size=5)
|
||||
- But every step, ALL 5 queue elements are replaced with fresh random samples
|
||||
- This tests whether improvements come from queue structure itself
|
||||
or from the cumulative learning effect
|
||||
|
||||
Purpose:
|
||||
- Ablation study to distinguish:
|
||||
1. Benefits from using 5 ICL examples (queue structure)
|
||||
2. Benefits from retaining good examples over time (cumulative learning)
|
||||
|
||||
Key Difference from Other Policies:
|
||||
- Confidence/Consistency: Queue updated by evicting lowest scoring examples
|
||||
- Pure Random (no queue): Samples fresh examples each time, no structure
|
||||
- Queue Random (this): Queue structure exists, but fully refreshed each step
|
||||
|
||||
Usage:
|
||||
python -m sc.run_sc_queue_random --user_id=5 --shuffle_seed=42
|
||||
python -m sc.run_sc_queue_random --user_id=10 --shuffle_seed=123
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import yaml
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sklearn.metrics import f1_score, precision_score, recall_score
|
||||
|
||||
from fire import Fire
|
||||
|
||||
# Add project root to path for relative imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
|
||||
from sc.core.scagent import SCAgent
|
||||
from sc.core.model import load_models
|
||||
from sc.core.agent_pool import AgentPool
|
||||
from sc import debug_log
|
||||
|
||||
log = debug_log.log
|
||||
|
||||
|
||||
class QueueRandomSampler:
    """
    Queue-based Random Sampler - maintains queue structure but refreshes all elements each step.

    This sampler:
    - Holds ``queue_size`` example sets, each with one example index per class
    - Replaces ALL of them with fresh random samples on every ``refresh_all``
    - Keeps no memory of earlier sets between steps
    - Can be iterated and measured with ``len`` like a plain sequence of sets
    """

    def __init__(self, example_dataset, queue_size: int = 5, seed: Optional[int] = None):
        """
        Initialize Queue Random sampler.

        Args:
            example_dataset: Dataset containing all available examples; every
                item must expose a "label" entry
            queue_size: Size of the queue (number of ICL example sets)
            seed: Base seed used by ``refresh_all``; ``None`` disables reseeding
        """
        self.example_dataset = example_dataset
        self.queue_size = queue_size
        self.base_seed = seed

        # Group example indices by label for per-class sampling
        self.class_indices = {}
        for idx, example in enumerate(example_dataset):
            self.class_indices.setdefault(example["label"], []).append(idx)

        self.classes = list(self.class_indices.keys())
        self._step_counter = 0
        # Running total of example sets drawn by refresh_all
        self._total_refreshed = 0

        # Initial fill; drawn from the module-level RNG in whatever state it is in
        self._queue = self._sample_all_new()

        debug_log.log_queue_random_sampler_init(
            len(example_dataset),
            self.classes,
            queue_size,
        )

    def _sample_one_set(self) -> List[int]:
        """Sample one ICL example set (one example index per non-empty class)."""
        return [
            random.choice(self.class_indices[cls])
            for cls in self.classes
            if self.class_indices[cls]
        ]

    def _sample_all_new(self) -> List[List[int]]:
        """Sample completely new queue contents."""
        return [self._sample_one_set() for _ in range(self.queue_size)]

    def refresh_all(self) -> None:
        """Replace ALL queue elements with fresh random samples."""
        # Reseed per step: reproducible for a given base seed, different between steps.
        # NOTE: this reseeds the module-level `random` generator, not a private one.
        if self.base_seed is not None:
            random.seed(self.base_seed + self._step_counter * 1000)

        self._queue = self._sample_all_new()
        self._total_refreshed += self.queue_size
        self._step_counter += 1

    def __iter__(self):
        """Iterate over the example sets currently in the queue."""
        return iter(self._queue)

    def __len__(self):
        """Return the number of example sets currently in the queue."""
        return len(self._queue)

    def get_queue_state(self) -> List[List[int]]:
        """
        Return a snapshot of the current queue.

        Both the outer list and every example set are copied, so the caller
        can modify the snapshot without altering the sampler's live queue
        (a shallow copy would share the inner lists).
        """
        return [list(example_set) for example_set in self._queue]

    def get_statistics(self) -> Dict[str, Any]:
        """Return counters describing how often the queue has been refreshed."""
        return {
            "total_steps": self._step_counter,
            "total_refreshed": self._total_refreshed,
            "queue_size": self.queue_size,
            "avg_refresh_per_step": self.queue_size,  # Always full refresh
            "policy": "queue_random",
        }
|
||||
|
||||
|
||||
async def run_queue_random_experiment(
    dataloader: ShuffledDataLoader,
    model_pool,
    config: Dict[str, Any],
    user_id: int,
    shuffle_seed: int,
) -> List[Dict[str, Any]]:
    """
    Run Queue Random baseline experiment.

    Queue Policy:
    - Every step, ALL queue elements are replaced with fresh random samples
    - Nothing is fed back into the queue after a prediction

    Args:
        dataloader: ShuffledDataLoader providing test samples, the example
            pool and the task/class/sensor descriptions
        model_pool: Model pool handed to every SCAgent
        config: Experiment configuration; reads ``log_path`` and the optional
            ``queue_size`` (default 5) / ``tracking_window`` (default 20) keys
        user_id: User ID being processed (logging and result tagging)
        shuffle_seed: Seed for ``random``, ``numpy.random`` and the sampler

    Returns:
        One result dictionary per sample that was interpreted successfully;
        samples whose agents or interpretation fail are logged and skipped.
        The last dictionary also carries ``queue_random_stats``. An empty
        list is returned when the example pool is empty.
    """
    import traceback

    # Set seeds for reproducibility
    random.seed(shuffle_seed)
    np.random.seed(shuffle_seed)

    example_dataset = dataloader.get_examples()
    if len(example_dataset) == 0:
        log(f"[ERROR] No examples found for user {user_id}")
        return []

    # Initialize Queue Random Sampler (instead of regular Queue)
    queue_size = config.get("queue_size", 5)
    queue_sampler = QueueRandomSampler(
        example_dataset=example_dataset,
        queue_size=queue_size,
        seed=shuffle_seed,
    )

    # Tracking variables
    results = []
    cumulative_correct = 0
    window_size = config.get("tracking_window", 20)
    recent_results = []

    debug_log.log_policy_experiment_header(
        "QUEUE RANDOM BASELINE",
        user_id,
        shuffle_seed,
        len(dataloader),
        queue_size,
        policy_note=f"Policy: ALL {queue_size} queue elements refreshed EVERY step",
    )

    for processed_count, sample in enumerate(dataloader):
        # CRITICAL: Refresh ALL queue elements before each inference
        queue_sampler.refresh_all()

        # Log queue state (all new random samples)
        debug_log.log_queue_random_queue_state(processed_count, user_id, queue_sampler)

        # Create agent pool: one agent per example set in the queue
        agent_pool = AgentPool(log_path=config["log_path"])

        try:
            for queue_idx, ex_idcs in enumerate(queue_sampler):
                examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]

                agent = SCAgent(
                    name="EEG sensing",
                    index=queue_idx,
                    model_pool=model_pool,
                    task_info=dataloader.get_task_info(),
                    classes_info=dataloader.get_classes_info(),
                    sensor_info=dataloader.get_sensor_info(),
                    sample=sample,
                    examples=examples,
                    log_path=config["log_path"],
                )
                agent_pool.add_agent(agent)
        except Exception as e:
            # Best effort: a broken sample is reported and skipped, the run goes on
            log(f"[ERROR] Failed to create agents: {e}")
            traceback.print_exc()
            continue

        if len(agent_pool.agents) == 0:
            log(f"[WARN] No agents created for sample {processed_count}")
            continue

        # Run parallel interpretation
        try:
            interpretation_result = await agent_pool.run_parallel_interpretation()
        except Exception as e:
            log(f"[ERROR] Interpretation failed: {e}")
            traceback.print_exc()
            continue

        if interpretation_result is None:
            log(f"[WARN] Interpretation failed for sample {processed_count}")
            continue

        # Only answer, confidence and consistency are needed for this baseline
        answer, _queue_idcs, avg_confidence, consistency, _responses = interpretation_result

        # Evaluate result
        ground_truth = sample["label"]
        is_correct = answer == ground_truth

        cumulative_correct += 1 if is_correct else 0
        # Divide by the number of evaluated samples, not by the stream index:
        # samples skipped above are not in `results`, and counting them here made
        # this value disagree with the accuracy computed over `results`.
        cumulative_accuracy = cumulative_correct / (len(results) + 1)

        recent_results.append(1 if is_correct else 0)
        if len(recent_results) > window_size:
            recent_results.pop(0)
        window_accuracy = sum(recent_results) / len(recent_results)

        # Performance logging
        debug_log.log_queue_random_result(
            processed_count,
            answer,
            ground_truth,
            is_correct,
            avg_confidence,
            consistency,
            cumulative_accuracy,
            cumulative_correct,
            window_accuracy,
            recent_results,
        )

        # NO QUEUE UPDATE based on scores - just fresh random next time
        # (This is the key difference from Confidence/Consistency policies)

        # Store result
        result = {
            "sample_idx": processed_count,
            "original_idx": sample.get("original_idx", processed_count),
            "shuffled_idx": sample.get("shuffled_idx", processed_count),
            "answer": answer,
            "ground_truth": ground_truth,
            "is_correct": is_correct,
            "confidence": avg_confidence,
            "consistency": consistency,
            "cumulative_accuracy": cumulative_accuracy,
            "window_accuracy": window_accuracy,
            "experiment_type": "queue_random",
            "user_id": user_id,
            "shuffle_seed": shuffle_seed,
        }
        results.append(result)

    # Final statistics
    sampler_stats = queue_sampler.get_statistics()
    debug_log.log_queue_random_stats(user_id, shuffle_seed, sampler_stats)

    if results:
        results[-1]["queue_random_stats"] = sampler_stats

    return results
|
||||
|
||||
|
||||
def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict[str, Any]:
    """Compute summary statistics for a Queue Random experiment run.

    Args:
        results: Per-sample result dictionaries in processing order. Keys read
            here are "is_correct", "ground_truth", "answer", "confidence",
            "consistency", "user_id", "shuffle_seed" and "queue_random_stats";
            missing keys fall back to neutral defaults.
        stages: Class labels to report per-stage accuracy for. A label that
            never occurs as a ground truth is reported with accuracy 0.0.

    Returns:
        An empty dict when ``results`` is empty. Otherwise a dict holding the
        overall counts and accuracy, mean confidence/consistency, per-stage
        accuracy and sample counts, macro F1/precision/recall, first-half vs.
        second-half accuracy, a cumulative learning curve sampled every 10
        results, the first index at which the running accuracy reaches the
        convergence threshold (or None), and the queue statistics attached to
        the results.
    """
    if not results:
        return {}

    total = len(results)

    # hits_upto[k] is the number of correct answers among the first k results.
    # Every cumulative metric below is read from this single prefix pass
    # instead of re-slicing and re-counting the result list.
    hits_upto = [0]
    for r in results:
        hits_upto.append(hits_upto[-1] + (1 if r.get("is_correct", False) else 0))

    # Overall metrics
    correct = hits_upto[-1]
    accuracy = correct / total

    # Confidence and consistency averages
    avg_confidence = np.mean([r.get("confidence", 0) for r in results])
    avg_consistency = np.mean([r.get("consistency", 0) for r in results])

    # Per-stage accuracy, keyed by ground-truth label
    stage_correct = {}
    stage_total = {}
    for r in results:
        gt = r.get("ground_truth", "UNKNOWN")
        stage_total[gt] = stage_total.get(gt, 0) + 1
        if r.get("is_correct", False):
            stage_correct[gt] = stage_correct.get(gt, 0) + 1

    stage_accuracy = {}
    for stage in stages:
        if stage in stage_total:
            stage_accuracy[stage] = stage_correct.get(stage, 0) / stage_total[stage]
        else:
            stage_accuracy[stage] = 0.0

    # F1 Score and Macro metrics
    predictions = [r.get("answer", "") for r in results]
    ground_truths = [r.get("ground_truth", "") for r in results]

    try:
        macro_f1 = f1_score(ground_truths, predictions, average='macro', zero_division=0)
        macro_precision = precision_score(ground_truths, predictions, average='macro', zero_division=0)
        macro_recall = recall_score(ground_truths, predictions, average='macro', zero_division=0)
    except Exception:
        # Deliberately best-effort: a failure in the metric library must not
        # discard the rest of the statistics.
        macro_f1 = macro_precision = macro_recall = 0.0

    # Temporal analysis: accuracy of the first half vs. the second half
    mid_point = total // 2
    if mid_point > 0:
        first_half_acc = hits_upto[mid_point] / mid_point
        second_half_acc = (correct - hits_upto[mid_point]) / (total - mid_point)
        improvement = second_half_acc - first_half_acc
    else:
        # A single result cannot be split into halves.
        first_half_acc = second_half_acc = accuracy
        improvement = 0

    # Learning curve: cumulative accuracy after every 10 samples (and at the end)
    learning_curve = []
    for start in range(0, total, 10):
        end = min(start + 10, total)
        learning_curve.append({
            "sample_idx": end,
            "cumulative_accuracy": hits_upto[end] / end,
        })

    # Convergence speed: first index whose running accuracy reaches 90% of the
    # final accuracy. With zero final accuracy the threshold falls back to 0.5
    # so that an all-wrong run is not reported as converged at index 0.
    convergence_threshold = 0.9 * accuracy if accuracy > 0 else 0.5
    convergence_idx = None
    for i in range(total):
        if hits_upto[i + 1] / (i + 1) >= convergence_threshold:
            convergence_idx = i
            break

    # Queue random statistics: taken from the first result that carries them
    queue_stats = {}
    for r in results:
        if "queue_random_stats" in r:
            queue_stats = r["queue_random_stats"]
            break

    return {
        "experiment_type": "queue_random",
        "user_id": results[0].get("user_id"),
        "shuffle_seed": results[0].get("shuffle_seed"),
        "total_samples": total,
        "correct": correct,
        "accuracy": accuracy,
        "avg_confidence": float(avg_confidence),
        "avg_consistency": float(avg_consistency),
        "stage_accuracy": stage_accuracy,
        "stage_sample_counts": stage_total,
        "macro_f1": float(macro_f1),
        "macro_precision": float(macro_precision),
        "macro_recall": float(macro_recall),
        "temporal_analysis": {
            "first_half_accuracy": first_half_acc,
            "second_half_accuracy": second_half_acc,
            "improvement": improvement,
        },
        "learning_curve": learning_curve,
        "convergence_idx": convergence_idx,
        "queue_stats": queue_stats,
    }
|
||||
|
||||
|
||||
def save_results(
    results: List[Dict[str, Any]],
    stats: Dict[str, Any],
    config: Dict[str, Any],
    user_id: int,
    shuffle_seed: int,
) -> None:
    """Write the statistics, per-sample results and config of one run to disk.

    Three files are created in ``<config["log_path"]>/user<NN>_seed<seed>/``:
    ``statistics.json``, ``results.json`` and ``config.yaml``. The directory is
    created if it does not exist, and each written path is logged.

    Args:
        results: Per-sample result dictionaries (dumped as JSON).
        stats: Aggregated statistics (dumped as JSON).
        config: Experiment configuration; supplies "log_path" and is itself
            dumped as YAML.
        user_id: User ID used in the directory name (zero-padded to 2 digits).
        shuffle_seed: Shuffle seed used in the directory name.
    """
    output_dir = os.path.join(config["log_path"], f"user{user_id:02d}_seed{shuffle_seed}")
    os.makedirs(output_dir, exist_ok=True)

    def _write_json(filename, payload):
        # Dump one JSON artifact into the run directory and return its path.
        target = os.path.join(output_dir, filename)
        with open(target, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2, ensure_ascii=False)
        return target

    stats_path = _write_json("statistics.json", stats)
    log(f"[SAVE] Statistics: {stats_path}")

    results_path = _write_json("results.json", results)
    log(f"[SAVE] Results: {results_path}")

    # The effective configuration is stored next to the outputs.
    config_path = os.path.join(output_dir, "config.yaml")
    with open(config_path, "w", encoding="utf-8") as handle:
        yaml.dump(config, handle, default_flow_style=False, allow_unicode=True)
    log(f"[SAVE] Config: {config_path}")
|
||||
|
||||
|
||||
def main(
    config_path: str = "sc/config/experiment_queue_random.yaml",
    user_id: int = 5,
    shuffle_seed: int = 42,
) -> None:
    """
    Run the Queue Random baseline experiment end to end.

    Loads the YAML configuration, overrides its user/seed entries with the
    given arguments, creates a timestamped log directory, loads the shuffled
    data and the models, runs the experiment, then prints and saves the
    resulting statistics.

    The baseline keeps the queue structure but refreshes all of its elements
    every step, so it is expected to show no improvement over time.

    Args:
        config_path: Path to YAML configuration file
        user_id: User ID to process
        shuffle_seed: Shuffle seed for data order

    Example:
        python -m sc.run_sc_queue_random --user_id=5 --shuffle_seed=42
        python -m sc.run_sc_queue_random --user_id=15 --shuffle_seed=123
    """
    log(f"[MAIN] Loading config: {config_path}")
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    # Command-line values win over whatever the YAML file specifies.
    config.update({"user_id": user_id, "shuffle_seed": shuffle_seed})

    # Give every run its own timestamped log directory.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    config["log_path"] = f"{config['log_path']}/user{user_id:02d}_seed{shuffle_seed}_{timestamp}"
    os.makedirs(config["log_path"], exist_ok=True)

    queue_size = config.get("queue_size", 5)
    model_count = len(config.get("models", []))
    debug_log.log_policy_main_header(
        "QUEUE RANDOM BASELINE",
        user_id,
        shuffle_seed,
        queue_size,
        model_count,
        config["log_path"],
        policy_note=" Policy: Queue Random (full refresh every step)",
    )

    debug_log.log_policy_loading_data(user_id, shuffle_seed)
    dataloader = ShuffledDataLoader(
        data_path=config["data_path"],
        user_id=user_id,
        seed=shuffle_seed,
        example_pool=config.get("example_pool", "out"),
    )

    debug_log.log_policy_loading_models()
    model_pool = load_models(
        config["models"],
        temperature=config.get("temperature", 0.0),
        num_ctx=config.get("num_ctx", 15000),
    )

    debug_log.log_policy_start(label="Queue Random experiment")
    experiment = run_queue_random_experiment(
        dataloader=dataloader,
        model_pool=model_pool,
        config=config,
        user_id=user_id,
        shuffle_seed=shuffle_seed,
    )
    results = asyncio.run(experiment)

    stage_labels = config.get("stages", ["W", "N1", "N2", "N3", "REM"])
    stats = compute_statistics(results, stage_labels)

    # stats is empty when no sample produced a result, hence the .get defaults.
    debug_log.log_policy_complete_summary(
        "QUEUE RANDOM BASELINE",
        user_id,
        shuffle_seed,
        stats,
        stats.get("stage_accuracy", {}),
        stats.get("stage_sample_counts", {}),
        stats.get("temporal_analysis", {}),
        expected_note="\n [EXPECTED] Improvement should be ~0 (no cumulative learning)",
    )

    save_results(results, stats, config, user_id, shuffle_seed)
    log(f"\n[MAIN] Results saved to: {config['log_path']}")


if __name__ == "__main__":
    Fire(main)
|
||||
473
sc/run_usc.py
473
sc/run_usc.py
@@ -1,473 +0,0 @@
|
||||
"""
|
||||
Consistency-based Queue Policy Experiment Runner
|
||||
|
||||
This module implements the CONSISTENCY-based queue update policy:
|
||||
- Queue is updated based on SC consensus/agreement scores
|
||||
- Higher consistency (more agents agree) examples are retained
|
||||
- Lower consistency examples are evicted
|
||||
|
||||
Consistency Score = (Number of agents agreeing with majority) / (Total agents)
|
||||
|
||||
Usage:
|
||||
python -m sc.run_usc --user_id=5 --shuffle_seed=42
|
||||
python -m sc.run_usc --user_id=10 --shuffle_seed=123
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import yaml
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sklearn.metrics import f1_score, precision_score, recall_score
|
||||
|
||||
from fire import Fire
|
||||
|
||||
# Add project root to path for relative imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
|
||||
from sc.core.scagent import SCAgent
|
||||
from sc.core.model import load_models
|
||||
from sc.core.queue import Queue
|
||||
from sc.core.agent_pool import AgentPool
|
||||
from sc import debug_log
|
||||
|
||||
log = debug_log.log
|
||||
|
||||
|
||||
async def run_consistency_experiment(
    dataloader: ShuffledDataLoader,
    model_pool,
    config: Dict[str, Any],
    user_id: int,
    shuffle_seed: int,
) -> List[Dict[str, Any]]:
    """
    Run the consistency-based queue policy experiment.

    For every test sample one SCAgent is built per queue entry (each with that
    entry's ICL examples), all agents are run in parallel, and the combined
    answer is scored against the sample's label. Afterwards each agent gets a
    consistency score -- the fraction of agents (itself included) that gave the
    same answer -- and the score map is handed to
    ``Queue.update_by_confidence`` to update the queue.
    NOTE(review): which entries the queue keeps or evicts is decided inside
    sc.core.queue and is not visible here.

    Samples whose agent creation or interpretation fails are logged and
    skipped; they produce no result entry and do not count towards the
    cumulative accuracy.

    Args:
        dataloader: ShuffledDataLoader with pre-shuffled test data
        model_pool: Async model pool for LLM inference
        config: Experiment configuration ("log_path" is required;
            "queue_size" defaults to 5 and "tracking_window" to 20)
        user_id: User ID being processed
        shuffle_seed: Shuffle seed, also used to seed ``random`` and NumPy

    Returns:
        List of result dictionaries, one per successfully evaluated sample.
        The last entry additionally carries the queue survival statistics.
        The list is empty when no examples are available.
    """
    import traceback  # local import: only needed on the error paths below

    # Set seeds for reproducibility within experiment
    random.seed(shuffle_seed)
    np.random.seed(shuffle_seed)

    example_dataset = dataloader.get_examples()
    if len(example_dataset) == 0:
        log(f"[ERROR] No examples found for user {user_id}")
        return []

    # Group example indices by label for Queue initialization
    class_indices = {}
    for idx, example in enumerate(example_dataset):
        class_indices.setdefault(example["label"], []).append(idx)

    queue_size = config.get("queue_size", 5)
    ex_queue = Queue(class_indices, queue_size)

    # Tracking variables
    results = []
    cumulative_correct = 0
    window_size = config.get("tracking_window", 20)
    recent_results = []

    debug_log.log_policy_experiment_header(
        "CONSISTENCY-BASED",
        user_id,
        shuffle_seed,
        len(dataloader),
        queue_size,
    )

    for processed_count, sample in enumerate(dataloader):
        ex_queue.set_current_time(processed_count)

        # Log queue state before processing
        debug_log.log_policy_queue_state("CONSISTENCY", processed_count, user_id, ex_queue)

        # One agent per queue entry, each seeing that entry's examples
        agent_pool = AgentPool(log_path=config["log_path"])
        try:
            for queue_idx, ex_idcs in enumerate(ex_queue):
                examples = [example_dataset[ex_idx] for ex_idx in ex_idcs]
                agent = SCAgent(
                    name="EEG sensing",
                    index=queue_idx,
                    model_pool=model_pool,
                    task_info=dataloader.get_task_info(),
                    classes_info=dataloader.get_classes_info(),
                    sensor_info=dataloader.get_sensor_info(),
                    sample=sample,
                    examples=examples,
                    log_path=config["log_path"],
                )
                agent_pool.add_agent(agent)
        except Exception as e:
            log(f"[ERROR] Failed to create agents: {e}")
            traceback.print_exc()
            continue

        if len(agent_pool.agents) == 0:
            log(f"[WARN] No agents created for sample {processed_count}")
            continue

        # Run parallel interpretation
        try:
            interpretation_result = await agent_pool.run_parallel_interpretation()
        except Exception as e:
            log(f"[ERROR] Interpretation failed: {e}")
            traceback.print_exc()
            continue

        if interpretation_result is None:
            log(f"[WARN] Interpretation failed for sample {processed_count}")
            continue

        answer, _queue_idcs, avg_confidence, consistency, responses = interpretation_result

        # Evaluate result
        ground_truth = sample["label"]
        is_correct = answer == ground_truth

        cumulative_correct += 1 if is_correct else 0
        # Divide by the number of evaluated samples, not the loop index:
        # skipped samples have no result entry, and the final accuracy is
        # computed over the result list as well.
        cumulative_accuracy = cumulative_correct / (len(results) + 1)

        # Sliding window over the most recent outcomes
        recent_results.append(1 if is_correct else 0)
        if len(recent_results) > window_size:
            recent_results.pop(0)
        window_accuracy = sum(recent_results) / len(recent_results)

        # Performance logging
        debug_log.log_policy_result(
            "CONSISTENCY",
            processed_count,
            answer,
            ground_truth,
            is_correct,
            avg_confidence,
            consistency,
            cumulative_accuracy,
            cumulative_correct,
            window_accuracy,
            recent_results,
        )

        # CONSISTENCY-BASED Queue Update
        if responses:
            all_answers = [r.get("ANSWER") for r in responses.values()]

            # Per-agent consistency = share of agents (including itself)
            # that returned the same answer.
            consistency_map = {}
            for idx, response in responses.items():
                agreement_count = all_answers.count(response.get("ANSWER"))
                consistency_map[idx] = agreement_count / len(all_answers)

            debug_log.log_consistency_map(responses, consistency_map)

            # Update queue by consistency (reuses confidence method with consistency scores)
            ex_queue.update_by_confidence(consistency_map)
            ex_queue.increment_usage(list(responses.keys()))

            debug_log.log_policy_queue_state_after("Consistency", processed_count, ex_queue)

        # Store result
        results.append({
            "sample_idx": processed_count,
            "original_idx": sample.get("original_idx", processed_count),
            "shuffled_idx": sample.get("shuffled_idx", processed_count),
            "answer": answer,
            "ground_truth": ground_truth,
            "is_correct": is_correct,
            "confidence": avg_confidence,
            "consistency": consistency,
            "cumulative_accuracy": cumulative_accuracy,
            "window_accuracy": window_accuracy,
            "experiment_type": "usc",
            "user_id": user_id,
            "shuffle_seed": shuffle_seed,
        })

    # Final statistics
    survival_summary = ex_queue.get_survival_summary()
    debug_log.log_final_policy_survival("usc", user_id, shuffle_seed, survival_summary)

    # Queue-level statistics ride along on the last result entry.
    if results:
        results[-1]["queue_survival_stats"] = survival_summary
        results[-1]["queue_survival_details"] = ex_queue.get_survival_stats()

    return results
|
||||
|
||||
|
||||
def compute_statistics(results: List[Dict[str, Any]], stages: List[str]) -> Dict[str, Any]:
    """Compute summary statistics for a consistency-policy experiment run.

    Args:
        results: Per-sample result dictionaries in processing order. Keys read
            here are "is_correct", "ground_truth", "answer", "confidence",
            "consistency", "user_id", "shuffle_seed" and
            "queue_survival_stats"; missing keys fall back to neutral defaults.
        stages: Class labels to report per-stage accuracy for. A label that
            never occurs as a ground truth is reported with accuracy 0.0.

    Returns:
        An empty dict when ``results`` is empty. Otherwise a dict holding the
        overall counts and accuracy, mean confidence/consistency, per-stage
        accuracy and sample counts, macro F1/precision/recall, first-half vs.
        second-half accuracy, a cumulative learning curve sampled every 10
        results, the first index at which the running accuracy reaches the
        convergence threshold (or None), and the queue survival statistics
        attached to the results.
    """
    if not results:
        return {}

    total = len(results)

    # hits_upto[k] is the number of correct answers among the first k results.
    # Every cumulative metric below is read from this single prefix pass
    # instead of re-slicing and re-counting the result list.
    hits_upto = [0]
    for r in results:
        hits_upto.append(hits_upto[-1] + (1 if r.get("is_correct", False) else 0))

    # Overall metrics
    correct = hits_upto[-1]
    accuracy = correct / total

    # Confidence and consistency averages
    avg_confidence = np.mean([r.get("confidence", 0) for r in results])
    avg_consistency = np.mean([r.get("consistency", 0) for r in results])

    # Per-stage accuracy, keyed by ground-truth label
    stage_correct = {}
    stage_total = {}
    for r in results:
        gt = r.get("ground_truth", "UNKNOWN")
        stage_total[gt] = stage_total.get(gt, 0) + 1
        if r.get("is_correct", False):
            stage_correct[gt] = stage_correct.get(gt, 0) + 1

    stage_accuracy = {}
    for stage in stages:
        if stage in stage_total:
            stage_accuracy[stage] = stage_correct.get(stage, 0) / stage_total[stage]
        else:
            stage_accuracy[stage] = 0.0

    # F1 Score and Macro metrics
    predictions = [r.get("answer", "") for r in results]
    ground_truths = [r.get("ground_truth", "") for r in results]

    try:
        macro_f1 = f1_score(ground_truths, predictions, average='macro', zero_division=0)
        macro_precision = precision_score(ground_truths, predictions, average='macro', zero_division=0)
        macro_recall = recall_score(ground_truths, predictions, average='macro', zero_division=0)
    except Exception:
        # Deliberately best-effort: a failure in the metric library must not
        # discard the rest of the statistics.
        macro_f1 = macro_precision = macro_recall = 0.0

    # Temporal analysis: accuracy of the first half vs. the second half
    mid_point = total // 2
    if mid_point > 0:
        first_half_acc = hits_upto[mid_point] / mid_point
        second_half_acc = (correct - hits_upto[mid_point]) / (total - mid_point)
        improvement = second_half_acc - first_half_acc
    else:
        # A single result cannot be split into halves.
        first_half_acc = second_half_acc = accuracy
        improvement = 0

    # Learning curve: cumulative accuracy after every 10 samples (and at the end)
    learning_curve = []
    for start in range(0, total, 10):
        end = min(start + 10, total)
        learning_curve.append({
            "sample_idx": end,
            "cumulative_accuracy": hits_upto[end] / end,
        })

    # Convergence speed: first index whose running accuracy reaches 90% of the
    # final accuracy. A zero final accuracy would make the threshold 0 and
    # report convergence at index 0 for a run with no correct answer, so the
    # threshold falls back to 0.5 in that case.
    convergence_threshold = 0.9 * accuracy if accuracy > 0 else 0.5
    convergence_idx = None
    for i in range(total):
        if hits_upto[i + 1] / (i + 1) >= convergence_threshold:
            convergence_idx = i
            break

    # Queue statistics: taken from the first result that carries them
    queue_stats = {}
    for r in results:
        if "queue_survival_stats" in r:
            queue_stats = r["queue_survival_stats"]
            break

    return {
        "experiment_type": "consistency",
        "user_id": results[0].get("user_id"),
        "shuffle_seed": results[0].get("shuffle_seed"),
        "total_samples": total,
        "correct": correct,
        "accuracy": accuracy,
        "avg_confidence": float(avg_confidence),
        "avg_consistency": float(avg_consistency),
        "stage_accuracy": stage_accuracy,
        "stage_sample_counts": stage_total,
        "macro_f1": float(macro_f1),
        "macro_precision": float(macro_precision),
        "macro_recall": float(macro_recall),
        "temporal_analysis": {
            "first_half_accuracy": first_half_acc,
            "second_half_accuracy": second_half_acc,
            "improvement": improvement,
        },
        "learning_curve": learning_curve,
        "convergence_idx": convergence_idx,
        "queue_stats": queue_stats,
    }
|
||||
|
||||
|
||||
def save_results(
    results: List[Dict[str, Any]],
    stats: Dict[str, Any],
    config: Dict[str, Any],
    user_id: int,
    shuffle_seed: int,
) -> None:
    """Write the statistics, per-sample results and config of one run to disk.

    Three files are created in ``<config["log_path"]>/user<NN>_seed<seed>/``:
    ``statistics.json``, ``results.json`` and ``config.yaml``. The directory is
    created if it does not exist, and each written path is logged.

    Args:
        results: Per-sample result dictionaries (dumped as JSON).
        stats: Aggregated statistics (dumped as JSON).
        config: Experiment configuration; supplies "log_path" and is itself
            dumped as YAML.
        user_id: User ID used in the directory name (zero-padded to 2 digits).
        shuffle_seed: Shuffle seed used in the directory name.
    """
    output_dir = os.path.join(config["log_path"], f"user{user_id:02d}_seed{shuffle_seed}")
    os.makedirs(output_dir, exist_ok=True)

    def _write_json(filename, payload):
        # Dump one JSON artifact into the run directory and return its path.
        target = os.path.join(output_dir, filename)
        with open(target, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2, ensure_ascii=False)
        return target

    stats_path = _write_json("statistics.json", stats)
    log(f"[SAVE] Statistics: {stats_path}")

    results_path = _write_json("results.json", results)
    log(f"[SAVE] Results: {results_path}")

    # The effective configuration is stored next to the outputs.
    config_path = os.path.join(output_dir, "config.yaml")
    with open(config_path, "w", encoding="utf-8") as handle:
        yaml.dump(config, handle, default_flow_style=False, allow_unicode=True)
    log(f"[SAVE] Config: {config_path}")
|
||||
|
||||
|
||||
def main(
    config_path: str = "sc/config/experiment_consistency.yaml",
    user_id: int = 5,
    shuffle_seed: int = 42,
) -> None:
    """
    Run the Consistency-based Queue Policy experiment end to end.

    Loads the YAML configuration, overrides its user/seed entries with the
    given arguments, creates a timestamped log directory, loads the shuffled
    data and the models, runs the experiment, then prints and saves the
    resulting statistics.

    Args:
        config_path: Path to YAML configuration file
        user_id: User ID to process
        shuffle_seed: Shuffle seed for data order

    Example:
        python -m sc.run_usc --user_id=5 --shuffle_seed=42
        python -m sc.run_usc --user_id=10 --shuffle_seed=123
    """
    log(f"[MAIN] Loading config: {config_path}")
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    # Command-line values win over whatever the YAML file specifies.
    config.update({"user_id": user_id, "shuffle_seed": shuffle_seed})

    # Give every run its own timestamped log directory.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    config["log_path"] = f"{config['log_path']}/user{user_id:02d}_seed{shuffle_seed}_{timestamp}"
    os.makedirs(config["log_path"], exist_ok=True)

    queue_size = config.get("queue_size", 5)
    model_count = len(config.get("models", []))
    debug_log.log_policy_main_header(
        "CONSISTENCY-BASED",
        user_id,
        shuffle_seed,
        queue_size,
        model_count,
        config["log_path"],
    )

    debug_log.log_policy_loading_data(user_id, shuffle_seed)
    dataloader = ShuffledDataLoader(
        data_path=config["data_path"],
        user_id=user_id,
        seed=shuffle_seed,
        example_pool=config.get("example_pool", "out"),
    )

    debug_log.log_policy_loading_models()
    model_pool = load_models(
        config["models"],
        temperature=config.get("temperature", 0.0),
        num_ctx=config.get("num_ctx", 15000),
    )

    debug_log.log_policy_start(label="experiment")
    experiment = run_consistency_experiment(
        dataloader=dataloader,
        model_pool=model_pool,
        config=config,
        user_id=user_id,
        shuffle_seed=shuffle_seed,
    )
    results = asyncio.run(experiment)

    stage_labels = config.get("stages", ["W", "N1", "N2", "N3", "REM"])
    stats = compute_statistics(results, stage_labels)

    # stats is empty when no sample produced a result, hence the .get defaults.
    debug_log.log_policy_complete_summary(
        "CONSISTENCY",
        user_id,
        shuffle_seed,
        stats,
        stats.get("stage_accuracy", {}),
        stats.get("stage_sample_counts", {}),
        stats.get("temporal_analysis", {}),
    )

    save_results(results, stats, config, user_id, shuffle_seed)
    log(f"\n[MAIN] Results saved to: {config['log_path']}")


if __name__ == "__main__":
    Fire(main)
|
||||
@@ -1,124 +0,0 @@
|
||||
"""
|
||||
Call transformers serve (chat) and our HF API (logprobs) from Python with requests.
|
||||
|
||||
Summary:
|
||||
--------
|
||||
- transformers serve (transformers chat localhost:8000 --model-name-or-path ...):
|
||||
- Exposes OpenAI-compatible endpoints: /v1/chat/completions, /v1/responses, /v1/models.
|
||||
- It does NOT return logits or logprobs; the response chunks have "logprobs": null.
|
||||
- You can still use it for chat from Python via requests (see chat_with_transformers_serve).
|
||||
|
||||
- For answer + logprobs (top-k per token):
|
||||
- Use our custom API in sc/hf_api.py:
|
||||
MODEL_DIR=/path/to/model uvicorn sc.hf_api:app --host 0.0.0.0 --port 8000
|
||||
- Then call POST /generate with requests (see get_logprobs_via_hf_api).
|
||||
|
||||
Usage:
|
||||
------
|
||||
# Chat only (transformers serve on 8000):
|
||||
python -c "
|
||||
from sc.transformers_serve_client_example import chat_with_transformers_serve
|
||||
print(chat_with_transformers_serve('Hello!'))
|
||||
"
|
||||
|
||||
# Answer + logprobs (sc/hf_api on 8000):
|
||||
python -c "
|
||||
from sc.transformers_serve_client_example import get_logprobs_via_hf_api
|
||||
r = get_logprobs_via_hf_api([{'role':'user','content':'Hello!'}])
|
||||
print('text:', r['generated_text'])
|
||||
print('first step top-k:', r['steps'][0]['topk'])
|
||||
"
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
# Default base URL for transformers serve (OpenAI-compatible)
|
||||
TRANSFORMERS_SERVE_URL = "http://localhost:8000/v1"
|
||||
# Default base URL for our sc/hf_api (generate + logprobs)
|
||||
HF_API_URL = "http://localhost:8000"
|
||||
|
||||
|
||||
def _join_stream_deltas(lines) -> str:
    """Concatenate the ``choices[0].delta.content`` pieces of an SSE chat stream."""
    pieces = []
    for raw in lines:
        line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        if not line.startswith("data:"):
            continue  # blank separators and non-data SSE fields
        body = line[len("data:"):].strip()
        if body == "[DONE]":
            break
        choices = json.loads(body).get("choices") or []
        if not choices:
            continue
        text = (choices[0].get("delta") or {}).get("content")
        if text:
            pieces.append(text)
    return "".join(pieces)


def chat_with_transformers_serve(
    user_message: str,
    *,
    base_url: str = TRANSFORMERS_SERVE_URL,
    model: str = "openai/gpt-oss-20b",
    max_tokens: int = 256,
    stream: bool = False,
) -> str:
    """
    Send a chat message to a server running `transformers serve`.

    Returns the assistant reply text. No logits/logprobs (server does not
    provide them).

    Args:
        user_message: Content of the single user message to send.
        base_url: Base URL of the OpenAI-compatible API (trailing slashes are
            ignored); "/chat/completions" is appended.
        model: Model name placed in the request payload.
        max_tokens: Value of the "max_tokens" request field.
        stream: When True the server is asked to stream and the reply is
            assembled from the streamed chunks; when False the reply is read
            from a single JSON body.
            NOTE(review): the streaming path assumes OpenAI-style server-sent
            events ("data: {...}" lines ending with "data: [DONE]") -- confirm
            against the server in use.

    Returns:
        The reply text of the first choice, or "" when the response has no
        choices or no content.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    url = f"{base_url.rstrip('/')}/chat/completions"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": user_message}],
        "max_tokens": max_tokens,
        "stream": stream,
    }
    # stream=stream stops requests from buffering an SSE body up front; the
    # context manager releases the connection once the reply has been read.
    with requests.post(url, json=payload, timeout=60, stream=stream) as resp:
        resp.raise_for_status()
        if stream:
            # A streamed body is not a single JSON document, so resp.json()
            # would fail here.
            return _join_stream_deltas(resp.iter_lines())
        data = resp.json()

    # Non-stream: choices[0].message.content
    choices = data.get("choices") or []
    if not choices:
        return ""
    msg = choices[0].get("message") or {}
    return msg.get("content", "") or ""
|
||||
|
||||
|
||||
def get_logprobs_via_hf_api(
    messages: List[Dict[str, str]],
    *,
    base_url: str = HF_API_URL,
    max_new_tokens: int = 64,
    top_k: int = 10,
) -> Dict[str, Any]:
    """
    POST a chat to the sc/hf_api ``/generate`` endpoint and return its JSON reply.

    The request always asks for greedy decoding (temperature 0.0, sampling
    off). According to the module notes the reply carries the generated text
    and per-token top-k logprobs.

    Args:
        messages: Chat messages to send, as role/content dictionaries.
        base_url: Base URL of the API (trailing slashes are ignored).
        max_new_tokens: Value of the "max_new_tokens" request field.
        top_k: Value of the "top_k" request field.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    request_body = {
        "messages": messages,
        "max_new_tokens": max_new_tokens,
        "top_k": top_k,
        "temperature": 0.0,
        "do_sample": False,
    }
    response = requests.post(
        f"{base_url.rstrip('/')}/generate",
        json=request_body,
        timeout=120,
    )
    response.raise_for_status()
    return response.json()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    # Tiny CLI: <chat|logprobs> [message]
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python -m sc.transformers_serve_client_example chat|logprobs [message]")
        print(" chat -> call transformers serve /v1/chat/completions (no logprobs)")
        print(" logprobs -> call sc/hf_api /generate (returns top-k logprobs)")
        sys.exit(0)

    cmd = cli_args[0].lower()
    message = (cli_args[1] if len(cli_args) > 1 else "Hello, how are you?").strip()

    if cmd not in ("chat", "logprobs"):
        print("Unknown command. Use 'chat' or 'logprobs'.")
        sys.exit(1)

    if cmd == "chat":
        print("Reply:", chat_with_transformers_serve(message))
    else:
        out = get_logprobs_via_hf_api([{"role": "user", "content": message}])
        print("Generated:", out.get("generated_text", ""))
        print("Steps (first 3):")
        # Show only the first three generation steps and five candidates each.
        for step in out.get("steps", [])[:3]:
            top_tokens = [t.get("token") for t in step.get("topk", [])[:5]]
            print(" token:", repr(step.get("token")), "logprob:", step.get("logprob"), "topk:", top_tokens)
|
||||
Reference in New Issue
Block a user