260119:chronos_2_based_embedding_running

This commit is contained in:
ssum21
2026-01-19 19:56:35 +09:00
parent 50e06cca6a
commit b0e656038b
4 changed files with 373 additions and 20 deletions

View File

@@ -37,8 +37,9 @@ Date: 2026-01-09
""" """
import os import os
import sys
from glob import glob from glob import glob
from typing import Dict, Any, List, Tuple from typing import Dict, Any, List, Tuple, Optional
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@@ -48,6 +49,10 @@ from chronos import BaseChronosPipeline, Chronos2Pipeline
from fire import Fire from fire import Fire
from sklearn.manifold import TSNE from sklearn.manifold import TSNE
# Add parent directories to path for importing metadata loader
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../.."))
from preprocess.sleepedf_metadata import SleepEDFMetadata, DEFAULT_METADATA_PATH
# pipeline: Chronos2Pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map="cuda") # pipeline: Chronos2Pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map="cuda")
# ============================================================================= # =============================================================================
@@ -234,7 +239,8 @@ class Chronos_2_Embedder:
def extract_embeddings( def extract_embeddings(
self, self,
data_root: str, data_root: str,
batch_size: int = 32 batch_size: int = 32,
metadata_path: str = DEFAULT_METADATA_PATH,
) -> Dataset: ) -> Dataset:
""" """
Extract embeddings from all sessions under the data root directory. Extract embeddings from all sessions under the data root directory.
@@ -247,20 +253,34 @@ class Chronos_2_Embedder:
batch_size: Number of samples to process together. batch_size: Number of samples to process together.
Larger = faster but more memory. Larger = faster but more memory.
32 is a good balance for most GPUs. 32 is a good balance for most GPUs.
metadata_path: Path to SC-subjects.xls for gender/age info
Returns: Returns:
HuggingFace Dataset with columns: HuggingFace Dataset with columns:
- user_id, session_id, idx, label (metadata) - user_id, session_id, idx, label (metadata)
- gender, age (demographic metadata)
- embedding (vector; dim depends on variate_fusion: 1024 for concat, 512 for mean) - embedding (vector; dim depends on variate_fusion: 1024 for concat, 512 for mean)
""" """
session_paths = self.discover_session_paths(data_root) session_paths = self.discover_session_paths(data_root)
print(f"[INFO] Discovered {len(session_paths)} sessions") print(f"[INFO] Discovered {len(session_paths)} sessions")
# Load metadata for gender/age information
try:
metadata = SleepEDFMetadata(metadata_path)
has_metadata = True
except FileNotFoundError:
print(f"[WARNING] Metadata file not found: {metadata_path}")
print(f"[WARNING] Gender/age information will not be available")
metadata = None
has_metadata = False
all_embeddings = [] all_embeddings = []
all_user_ids = [] all_user_ids = []
all_session_ids = [] all_session_ids = []
all_idxs = [] all_idxs = []
all_labels = [] all_labels = []
all_genders = []
all_ages = []
for user_id, session_id, session_path in session_paths: for user_id, session_id, session_path in session_paths:
# Load HuggingFace dataset from disk # Load HuggingFace dataset from disk
@@ -282,11 +302,25 @@ class Chronos_2_Embedder:
# Collect embeddings and metadata # Collect embeddings and metadata
for i in range(embeddings.shape[0]): for i in range(embeddings.shape[0]):
user_id_str = str(batch["user_id"][i])
all_embeddings.append(embeddings[i].tolist()) all_embeddings.append(embeddings[i].tolist())
all_user_ids.append(str(batch["user_id"][i])) all_user_ids.append(user_id_str)
all_session_ids.append(str(batch["session_id"][i])) all_session_ids.append(str(batch["session_id"][i]))
all_idxs.append(int(batch["idx"][i])) all_idxs.append(int(batch["idx"][i]))
all_labels.append(str(batch["label"][i])) all_labels.append(str(batch["label"][i]))
# Add demographic metadata
if has_metadata:
info = metadata.get_info(user_id_str)
if info:
all_genders.append(info['gender'])
all_ages.append(info['age'])
else:
all_genders.append('Unknown')
all_ages.append(-1)
else:
all_genders.append('Unknown')
all_ages.append(-1)
# Create HuggingFace Dataset # Create HuggingFace Dataset
result_dataset = Dataset.from_dict({ result_dataset = Dataset.from_dict({
@@ -294,10 +328,17 @@ class Chronos_2_Embedder:
"session_id": all_session_ids, "session_id": all_session_ids,
"idx": all_idxs, "idx": all_idxs,
"label": all_labels, "label": all_labels,
"gender": all_genders,
"age": all_ages,
"embedding": all_embeddings, "embedding": all_embeddings,
}) })
print(f"[INFO] Total samples: {len(result_dataset)}") print(f"[INFO] Total samples: {len(result_dataset)}")
if has_metadata:
gender_counts = {}
for g in all_genders:
gender_counts[g] = gender_counts.get(g, 0) + 1
print(f"[INFO] Gender distribution in extracted data: {gender_counts}")
return result_dataset return result_dataset
def save_embeddings( def save_embeddings(
@@ -460,6 +501,8 @@ class CLI:
users: str = None, users: str = None,
num_users: int = 0, num_users: int = 0,
labels: str = None, labels: str = None,
gender: str = None,
metadata_path: str = DEFAULT_METADATA_PATH,
) -> None: ) -> None:
""" """
Visualize embeddings with t-SNE. Visualize embeddings with t-SNE.
@@ -470,7 +513,9 @@ class CLI:
perplexity: t-SNE perplexity parameter (default: 30.0) perplexity: t-SNE perplexity parameter (default: 30.0)
users: Comma-separated user IDs to include (e.g., '00,01,02') users: Comma-separated user IDs to include (e.g., '00,01,02')
num_users: Include only first N users, 0 = all (default: 0) num_users: Include only first N users, 0 = all (default: 0)
labels: Comma-separated sleep stage labels to include (e.g., '0,1,2') labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
gender: Filter by gender ('M', 'F', or None for all)
metadata_path: Path to metadata file for gender lookup (if not in dataset)
""" """
os.makedirs(out_dir, exist_ok=True) os.makedirs(out_dir, exist_ok=True)
@@ -489,6 +534,11 @@ class CLI:
dataset = dataset.filter(lambda x: x["user_id"] in selected_users) dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
print(f"[INFO] Selected first {num_users} users: {selected_users}") print(f"[INFO] Selected first {num_users} users: {selected_users}")
# Filter by gender
if gender:
dataset = dataset.filter(lambda x: x.get("gender", "Unknown") == gender)
print(f"[INFO] Filtered to gender: {gender}")
# Filter by sleep stage labels # Filter by sleep stage labels
if labels: if labels:
label_list = [l.strip() for l in labels.split(",")] label_list = [l.strip() for l in labels.split(",")]
@@ -497,6 +547,13 @@ class CLI:
print(f"[INFO] Total samples: {len(dataset)}") print(f"[INFO] Total samples: {len(dataset)}")
# Print gender distribution
if "gender" in dataset.column_names:
gender_counts = {}
for g in dataset["gender"]:
gender_counts[g] = gender_counts.get(g, 0) + 1
print(f"[INFO] Gender distribution: {gender_counts}")
# Extract embeddings as numpy array for t-SNE # Extract embeddings as numpy array for t-SNE
embeddings = np.array(dataset["embedding"]) embeddings = np.array(dataset["embedding"])

View File

@@ -1,5 +1,13 @@
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF" data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
num_seeds: 10 num_seeds: 10
num_examples: 1
# Selection criteria: out_random | in_random | out_similar | in_similar
selection_criteria: "out_random"
# Embedding path (required only for the *_similar criteria; unused for *_random)
# embedding_path: "/home/ssum/tsllm_personalization_icl/embeddings_20users"
models: models:
- ollama:url:joy.kaist.ac.kr:11437/gpt-oss:20b - ollama:url:joy.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:joy.kaist.ac.kr:11438/gpt-oss:20b - ollama:url:joy.kaist.ac.kr:11438/gpt-oss:20b
@@ -17,4 +25,8 @@ models:
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b - ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b - ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b - ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
log_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/logs/SleepEDF_out_random"
# chronos_base : "/mnt/sting/ssum/sleepedf_chronos_base_result"
# out_random : "/mnt/sting/ssum/sleepedf_chronos_result_outrandom"
# in_random : "/mnt/sting/ssum/sleepedf_chronos_result"
log_path: "/mnt/sting/ssum/sleepedf_chronos_result"

View File

@@ -3,17 +3,35 @@ import json
import datasets import datasets
import numpy as np import numpy as np
from glob import glob from glob import glob
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from .embedding_index import EmbeddingIndex
class DataLoader: class DataLoader:
def __init__(self, data_path, user_id, selection_criteria="out_random", num_examples=1): def __init__(
self,
data_path,
user_id,
selection_criteria="out_random",
num_examples=1,
embedding_index: Optional["EmbeddingIndex"] = None,
):
self.is_valid = False self.is_valid = False
self.embedding_index = embedding_index
self.data_path = data_path
if not os.path.exists(os.path.join(data_path, "info.json")): if not os.path.exists(os.path.join(data_path, "info.json")):
return return
if not os.path.exists(os.path.join(data_path, f"{user_id}", "1")): if not os.path.exists(os.path.join(data_path, f"{user_id}", "1")):
return return
if not os.path.exists(os.path.join(data_path, f"{user_id}", "2")): if not os.path.exists(os.path.join(data_path, f"{user_id}", "2")):
return return
if selection_criteria in ["out_similar", "in_similar"] and embedding_index is None:
print(f"[WARNING] {selection_criteria} requires embedding_index, falling back to random")
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
self.metadata = json.load(open(os.path.join(data_path, "info.json"), "r", encoding="utf-8")) self.metadata = json.load(open(os.path.join(data_path, "info.json"), "r", encoding="utf-8"))
self.test_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user_id}", "2")) self.test_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user_id}", "2"))
@@ -34,9 +52,19 @@ class DataLoader:
self.num_examples = num_examples self.num_examples = num_examples
self.classes = sorted(list(self.metadata["class"].keys())) self.classes = sorted(list(self.metadata["class"].keys()))
self.selected_examples = self.sample_examples()
if self.selected_examples is None: # Build lookup index for fast example retrieval: (user_id, idx) -> dataset_index
return self._example_lookup = {}
for i, example in enumerate(self.example_dataset):
key = (str(example["user_id"]), int(example["idx"]))
self._example_lookup[key] = i
if selection_criteria in ["out_similar", "in_similar"]:
self.selected_examples = None
else:
self.selected_examples = self.sample_examples()
if self.selected_examples is None:
return
self.is_valid = True self.is_valid = True
@@ -44,11 +72,95 @@ class DataLoader:
return len(self.test_dataset) return len(self.test_dataset)
def __getitem__(self, idx):
    """Return ``(sample, examples)`` for the test sample at position ``idx``.

    Args:
        idx: Index into ``self.test_dataset``.

    Returns:
        Tuple of the test sample and the in-context examples chosen for it.
    """
    sample = self.test_dataset[idx]
    return sample, self._select_examples_for(sample)

def __iter__(self):
    """Yield ``(sample, examples)`` for every sample of ``self.test_dataset``, in order."""
    for sample in self.test_dataset:
        yield sample, self._select_examples_for(sample)

def _select_examples_for(self, sample):
    """Pick the in-context examples that accompany ``sample``.

    For the similarity-based criteria the examples depend on the sample, so
    they are computed per call via ``sample_similar_examples``. For every
    other criterion the examples sampled once up front
    (``self.selected_examples``) are reused for all samples.

    Kept in one place so ``__getitem__`` and ``__iter__`` cannot drift apart.
    """
    if self.selection_criteria in ("out_similar", "in_similar"):
        return self.sample_similar_examples(sample)
    return self.selected_examples
def sample_similar_examples(self, sample):
    """
    Select in-context examples by embedding similarity to ``sample``.

    The query embedding for ``sample`` is looked up in
    ``self.embedding_index`` by ``(user_id, session_id, idx)``. The index is
    then asked for the ``self.num_examples`` most similar entries for each
    class in ``self.classes``, and every hit is resolved back to a row of the
    example dataset through ``_find_example_by_metadata``. Hits that cannot
    be resolved are dropped silently.

    Falls back to ``self.sample_examples()`` (random selection) when:
      * no embedding index is attached,
      * the index has no embedding for ``sample``, or
      * none of the returned hits resolves to an example row.

    Args:
        sample: The test sample to find similar examples for. Must provide
            ``user_id``, ``session_id`` and ``idx``.

    Returns:
        HuggingFace Dataset containing the resolved similar examples, or the
        result of ``self.sample_examples()`` on any fallback path.
    """
    if self.embedding_index is None:
        return self.sample_examples()

    query_embedding = self.embedding_index.get_embedding_by_key(
        user_id=str(sample["user_id"]),
        session_id=str(sample["session_id"]),
        idx=int(sample["idx"]),
    )
    if query_embedding is None:
        print(f"[WARNING] No embedding found for sample {sample['user_id']}/{sample['session_id']}/{sample['idx']}")
        return self.sample_examples()

    # "out_similar" searches everyone except this user; any other value
    # (i.e. "in_similar") searches this user's own data only.
    own_user = str(self.user_id)
    if self.selection_criteria == "out_similar":
        exclude_user, include_user = own_user, None
    else:
        exclude_user, include_user = None, own_user

    similar_per_class = self.embedding_index.find_similar_per_class(
        query_embedding=query_embedding,
        classes=self.classes,
        k_per_class=self.num_examples,
        exclude_user=exclude_user,
        include_user=include_user,
        # NOTE(review): presumably session "1" is the example pool (the test
        # set is loaded from session "2") - confirm against EmbeddingIndex.
        filter_session="1",
    )

    # Each hit is a (global_idx, similarity, metadata) triple; only the
    # metadata is needed to locate the row in the example dataset.
    example_list = []
    for similar_samples in similar_per_class.values():
        for _global_idx, _similarity, metadata in similar_samples:
            example = self._find_example_by_metadata(metadata)
            if example is not None:
                example_list.append(example)

    if not example_list:
        print("[WARNING] No similar examples found, falling back to random")
        return self.sample_examples()

    return datasets.Dataset.from_list(example_list)
def _find_example_by_metadata(self, metadata):
    """Resolve an embedding-index hit to a row of ``self.example_dataset``.

    The row position is found with one dictionary lookup in
    ``self._example_lookup``, keyed by ``(str(user_id), int(idx))``. Note
    that ``session_id`` is not part of the key, so it is not consulted here.

    Args:
        metadata: Mapping describing the hit; must provide ``user_id`` and
            ``idx``.

    Returns:
        A plain dict with the ``user_id``, ``session_id``, ``idx``, ``label``
        and ``features`` of the matching row plus its ``data`` (``{}`` when
        the row has none), or ``None`` when the key is not in the lookup.
    """
    key = (str(metadata["user_id"]), int(metadata["idx"]))
    # Single lookup instead of "in" followed by indexing; stored values are
    # integer row positions, so None unambiguously means "not found".
    dataset_idx = self._example_lookup.get(key)
    if dataset_idx is None:
        return None

    example = self.example_dataset[dataset_idx]
    return {
        "user_id": example["user_id"],
        "session_id": example["session_id"],
        "idx": example["idx"],
        "label": example["label"],
        "features": example["features"],
        "data": example.get("data", {}),
    }
def sample_examples(self): def sample_examples(self):
example_dataset = datasets.Dataset.from_list([]) example_dataset = datasets.Dataset.from_list([])

190
run.py
View File

@@ -5,7 +5,7 @@ import yaml
import json import json
import numpy as np import numpy as np
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any from typing import List, Dict, Any, Optional
import time import time
from glob import glob from glob import glob
@@ -14,11 +14,47 @@ from fire import Fire
from core.model import load_models from core.model import load_models
from core.data_loader import DataLoader from core.data_loader import DataLoader
from core.sensing_agent import SensingAgent from core.sensing_agent import SensingAgent
from core.embedding_index import EmbeddingIndex, create_embedding_index
_EMBEDDING_INDEX: Optional[EmbeddingIndex] = None
# Path the cached index was built from, so a request for a different path is
# detected instead of silently returning the stale index.
_EMBEDDING_INDEX_PATH: Optional[str] = None


def init_embedding_index(embedding_path: Optional[str]) -> Optional[EmbeddingIndex]:
    """Initialize global embedding index for similarity-based selection.

    The index is built once per path and cached at module level. Calling
    again with the same path returns the cached object; calling with a
    different path rebuilds it. Passing ``None`` clears the cache.

    Args:
        embedding_path: Location handed to ``create_embedding_index``, or
            ``None`` to drop any cached index.

    Returns:
        The (possibly cached) index, or ``None`` when ``embedding_path`` is
        ``None``.
    """
    global _EMBEDDING_INDEX, _EMBEDDING_INDEX_PATH

    if embedding_path is None:
        _EMBEDDING_INDEX = None
        _EMBEDDING_INDEX_PATH = None
        return None

    # Rebuild when nothing is cached yet or the cache belongs to another path.
    if _EMBEDDING_INDEX is None or _EMBEDDING_INDEX_PATH != embedding_path:
        _EMBEDDING_INDEX = create_embedding_index(embedding_path)
        _EMBEDDING_INDEX_PATH = embedding_path

    return _EMBEDDING_INDEX
def load_user_data_sync(data_path, user, seed, log_path_base): def load_user_data_sync(
print(f"[DATA LOADING] Starting data loading for user: {user}, seed: {seed}") data_path: str,
data_loader = DataLoader(data_path, user, selection_criteria="out_random", num_examples=1) user: str,
seed: int,
log_path_base: str,
selection_criteria: str = "out_random",
num_examples: int = 1,
embedding_index: Optional[EmbeddingIndex] = None,
) -> List[Dict[str, Any]]:
"""Load data for a single user with specified selection criteria."""
print(f"[DATA LOADING] Starting: user={user}, seed={seed}, criteria={selection_criteria}")
# Set random seed for reproducibility
np.random.seed(seed)
data_loader = DataLoader(
data_path,
user,
selection_criteria=selection_criteria,
num_examples=num_examples,
embedding_index=embedding_index,
)
if not data_loader.is_valid: if not data_loader.is_valid:
print(f"[DATA LOADING] Skipping invalid user: {user}") print(f"[DATA LOADING] Skipping invalid user: {user}")
return [] return []
@@ -52,13 +88,36 @@ def load_user_data_sync(data_path, user, seed, log_path_base):
return tasks return tasks
async def load_data_parallel(config): async def load_data_parallel(config: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Load data for all users in parallel with configurable selection criteria."""
print("[DATA LOADING] Starting parallel data loading...") print("[DATA LOADING] Starting parallel data loading...")
# Get configuration parameters
selection_criteria = config.get("selection_criteria", "out_random")
embedding_path = config.get("embedding_path", None)
num_examples = config.get("num_examples", 1)
# Initialize embedding index if needed for similarity-based selection
embedding_index = None
if selection_criteria in ["out_similar", "in_similar"]:
if embedding_path is None:
print(f"[WARNING] {selection_criteria} requires embedding_path in config")
print("[WARNING] Falling back to random selection")
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
else:
embedding_index = init_embedding_index(embedding_path)
if embedding_index is None:
print(f"[WARNING] Failed to load embeddings, falling back to random")
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
user_paths = glob(os.path.join(config["data_path"], "*")) user_paths = glob(os.path.join(config["data_path"], "*"))
user_paths = [path for path in user_paths if os.path.isdir(path)] user_paths = [path for path in user_paths if os.path.isdir(path)]
users = [path.split("/")[-1] for path in user_paths] users = [path.split("/")[-1] for path in user_paths]
print(f"[DATA LOADING] Found {len(users)} users: {users}") print(f"[DATA LOADING] Found {len(users)} users: {users}")
print(f"[DATA LOADING] Selection criteria: {selection_criteria}")
print(f"[DATA LOADING] Num examples per class: {num_examples}")
max_workers = config.get("data_workers", 96) max_workers = config.get("data_workers", 96)
print(f"[DATA LOADING] Using {max_workers} workers for data loading") print(f"[DATA LOADING] Using {max_workers} workers for data loading")
@@ -74,7 +133,10 @@ async def load_data_parallel(config):
config["data_path"], config["data_path"],
user, user,
seed, seed,
config["log_path"] config["log_path"],
selection_criteria,
num_examples,
embedding_index,
) )
futures.append(future) futures.append(future)
@@ -89,10 +151,14 @@ async def load_data_parallel(config):
else: else:
all_tasks.extend(result) all_tasks.extend(result)
print(f"[DATA LOADING] Total tasks: {len(all_tasks)}")
return all_tasks return all_tasks
async def run_parallel(kwargs_list, model_pool, config): async def run_parallel(kwargs_list: List[Dict[str, Any]], model_pool, config: Dict[str, Any]) -> None:
"""Run classification tasks in parallel using the model pool."""
print(f"[EXECUTION] Starting {len(kwargs_list)} classification tasks...")
tasks = [] tasks = []
for kwargs in kwargs_list: for kwargs in kwargs_list:
agent = SensingAgent( agent = SensingAgent(
@@ -108,15 +174,121 @@ async def run_parallel(kwargs_list, model_pool, config):
task = asyncio.create_task(agent.solve(kwargs["sample"], kwargs["examples"], kwargs["ground_truth"])) task = asyncio.create_task(agent.solve(kwargs["sample"], kwargs["examples"], kwargs["ground_truth"]))
tasks.append(task) tasks.append(task)
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
print("[EXECUTION] All tasks completed")
def run(config_path: str) -> None:
    """
    Main entry point for running experiments.

    Loads the YAML configuration, prints a summary of it, builds the model
    pool, prepares all tasks, warms the models up and runs every task. Returns
    early (after logging an error) when data loading yields no tasks.

    Args:
        config_path: Path to YAML configuration file
    """
    print(f"[MAIN] Loading config from: {config_path}")
    # Context manager so the config file handle is closed deterministically.
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = yaml.load(config_file, Loader=yaml.SafeLoader)

    print("=" * 60)
    print("EXPERIMENT CONFIGURATION")
    print("=" * 60)
    print(f" Data path: {config.get('data_path', 'N/A')}")
    print(f" Log path: {config.get('log_path', 'N/A')}")
    print(f" Selection criteria: {config.get('selection_criteria', 'out_random')}")
    print(f" Num examples: {config.get('num_examples', 1)}")
    print(f" Num seeds: {config.get('num_seeds', 1)}")
    print(f" Embedding path: {config.get('embedding_path', 'N/A')}")
    print(f" Num models: {len(config.get('models', []))}")
    print("=" * 60)

    model_pool = load_models(config["models"])
    kwargs_list = asyncio.run(load_data_parallel(config))

    if not kwargs_list:
        print("[ERROR] No valid tasks to run. Check data paths and configuration.")
        return

    # Warmup models
    print("[MAIN] Warming up models...")
    model_pool.warmup()

    # Run experiments
    print("[MAIN] Starting experiments...")
    start_time = time.time()
    asyncio.run(run_parallel(kwargs_list, model_pool, config))
    elapsed = time.time() - start_time

    print(f"[MAIN] Experiment completed in {elapsed:.2f} seconds")
    print(f"[MAIN] Results saved to: {config['log_path']}")
def run_comparison(
    base_config_path: str,
    criteria_list: str = "out_random,in_random,out_similar,in_similar",
    embedding_path: str = None,
) -> None:
    """
    Run experiments comparing multiple selection criteria.

    For every criterion a copy of the base config is made with
    ``selection_criteria`` overridden and ``log_path`` suffixed with the
    criterion (unless it already ends with it), then the full
    load / warmup / run cycle is executed.

    Args:
        base_config_path: Path to base YAML config file
        criteria_list: Comma-separated list of selection criteria to compare.
            A list/tuple of names is accepted as well, because the CLI layer
            may hand over an already-split sequence (NOTE(review): python-fire
            appears to parse an unquoted ``a,b`` into a tuple - confirm).
            Blank entries are ignored.
        embedding_path: Path to embeddings (required for *_similar criteria
            unless the base config already provides ``embedding_path``)

    Example:
        python run.py compare config/sleepedf.yaml \\
            --criteria_list="out_random,out_similar" \\
            --embedding_path="./embeddings_full"
    """
    # Context manager so the config file handle is closed deterministically.
    with open(base_config_path, "r", encoding="utf-8") as config_file:
        base_config = yaml.load(config_file, Loader=yaml.SafeLoader)

    if isinstance(criteria_list, (list, tuple)):
        raw_criteria = [str(c) for c in criteria_list]
    else:
        raw_criteria = str(criteria_list).split(",")
    # Drop blanks (e.g. from a trailing comma) so no run starts with an empty
    # criterion.
    criteria = [c.strip() for c in raw_criteria if c.strip()]

    print("=" * 60)
    print("COMPARISON EXPERIMENT")
    print("=" * 60)
    print(f" Selection criteria to compare: {criteria}")
    print(f" Embedding path: {embedding_path}")
    print("=" * 60)

    for criterion in criteria:
        print(f"\n{'='*60}")
        print(f"Running experiment: {criterion}")
        print(f"{'='*60}")

        # Shallow copy is enough: only top-level keys are overridden below.
        config = base_config.copy()
        config["selection_criteria"] = criterion

        base_log_path = config["log_path"]
        if not base_log_path.endswith(criterion):
            config["log_path"] = f"{base_log_path}_{criterion}"

        if criterion in ("out_similar", "in_similar"):
            if embedding_path:
                config["embedding_path"] = embedding_path
            elif not config.get("embedding_path"):
                # Neither the CLI argument nor the base config names an
                # embedding location, so this criterion cannot run.
                print(f"[WARNING] {criterion} requires --embedding_path argument")
                continue

        model_pool = load_models(config["models"])
        kwargs_list = asyncio.run(load_data_parallel(config))

        if not kwargs_list:
            print(f"[WARNING] No tasks for {criterion}, skipping")
            continue

        model_pool.warmup()
        asyncio.run(run_parallel(kwargs_list, model_pool, config))

        print(f"[DONE] Results for {criterion} saved to: {config['log_path']}")

    print("\n" + "=" * 60)
    print("ALL COMPARISON EXPERIMENTS COMPLETED")
    print("=" * 60)
if __name__ == "__main__":
    # Expose both entry points as Fire sub-commands:
    #   python run.py run <config_path>
    #   python run.py compare <base_config_path> [--criteria_list=...] [--embedding_path=...]
    commands = {
        "run": run,
        "compare": run_comparison,
    }
    Fire(commands)