260119:chronos_2_based_embedding_running
This commit is contained in:
@@ -37,8 +37,9 @@ Date: 2026-01-09
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from typing import Dict, Any, List, Tuple
|
from typing import Dict, Any, List, Tuple, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@@ -48,6 +49,10 @@ from chronos import BaseChronosPipeline, Chronos2Pipeline
|
|||||||
from fire import Fire
|
from fire import Fire
|
||||||
from sklearn.manifold import TSNE
|
from sklearn.manifold import TSNE
|
||||||
|
|
||||||
|
# Add parent directories to path for importing metadata loader
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../.."))
|
||||||
|
from preprocess.sleepedf_metadata import SleepEDFMetadata, DEFAULT_METADATA_PATH
|
||||||
|
|
||||||
# pipeline: Chronos2Pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map="cuda")
|
# pipeline: Chronos2Pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map="cuda")
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -234,7 +239,8 @@ class Chronos_2_Embedder:
|
|||||||
def extract_embeddings(
|
def extract_embeddings(
|
||||||
self,
|
self,
|
||||||
data_root: str,
|
data_root: str,
|
||||||
batch_size: int = 32
|
batch_size: int = 32,
|
||||||
|
metadata_path: str = DEFAULT_METADATA_PATH,
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
"""
|
"""
|
||||||
Extract embeddings from all sessions under the data root directory.
|
Extract embeddings from all sessions under the data root directory.
|
||||||
@@ -247,20 +253,34 @@ class Chronos_2_Embedder:
|
|||||||
batch_size: Number of samples to process together.
|
batch_size: Number of samples to process together.
|
||||||
Larger = faster but more memory.
|
Larger = faster but more memory.
|
||||||
32 is a good balance for most GPUs.
|
32 is a good balance for most GPUs.
|
||||||
|
metadata_path: Path to SC-subjects.xls for gender/age info
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
HuggingFace Dataset with columns:
|
HuggingFace Dataset with columns:
|
||||||
- user_id, session_id, idx, label (metadata)
|
- user_id, session_id, idx, label (metadata)
|
||||||
|
- gender, age (demographic metadata)
|
||||||
- embedding (vector; dim depends on variate_fusion: 1024 for concat, 512 for mean)
|
- embedding (vector; dim depends on variate_fusion: 1024 for concat, 512 for mean)
|
||||||
"""
|
"""
|
||||||
session_paths = self.discover_session_paths(data_root)
|
session_paths = self.discover_session_paths(data_root)
|
||||||
print(f"[INFO] Discovered {len(session_paths)} sessions")
|
print(f"[INFO] Discovered {len(session_paths)} sessions")
|
||||||
|
|
||||||
|
# Load metadata for gender/age information
|
||||||
|
try:
|
||||||
|
metadata = SleepEDFMetadata(metadata_path)
|
||||||
|
has_metadata = True
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"[WARNING] Metadata file not found: {metadata_path}")
|
||||||
|
print(f"[WARNING] Gender/age information will not be available")
|
||||||
|
metadata = None
|
||||||
|
has_metadata = False
|
||||||
|
|
||||||
all_embeddings = []
|
all_embeddings = []
|
||||||
all_user_ids = []
|
all_user_ids = []
|
||||||
all_session_ids = []
|
all_session_ids = []
|
||||||
all_idxs = []
|
all_idxs = []
|
||||||
all_labels = []
|
all_labels = []
|
||||||
|
all_genders = []
|
||||||
|
all_ages = []
|
||||||
|
|
||||||
for user_id, session_id, session_path in session_paths:
|
for user_id, session_id, session_path in session_paths:
|
||||||
# Load HuggingFace dataset from disk
|
# Load HuggingFace dataset from disk
|
||||||
@@ -282,11 +302,25 @@ class Chronos_2_Embedder:
|
|||||||
|
|
||||||
# Collect embeddings and metadata
|
# Collect embeddings and metadata
|
||||||
for i in range(embeddings.shape[0]):
|
for i in range(embeddings.shape[0]):
|
||||||
|
user_id_str = str(batch["user_id"][i])
|
||||||
all_embeddings.append(embeddings[i].tolist())
|
all_embeddings.append(embeddings[i].tolist())
|
||||||
all_user_ids.append(str(batch["user_id"][i]))
|
all_user_ids.append(user_id_str)
|
||||||
all_session_ids.append(str(batch["session_id"][i]))
|
all_session_ids.append(str(batch["session_id"][i]))
|
||||||
all_idxs.append(int(batch["idx"][i]))
|
all_idxs.append(int(batch["idx"][i]))
|
||||||
all_labels.append(str(batch["label"][i]))
|
all_labels.append(str(batch["label"][i]))
|
||||||
|
|
||||||
|
# Add demographic metadata
|
||||||
|
if has_metadata:
|
||||||
|
info = metadata.get_info(user_id_str)
|
||||||
|
if info:
|
||||||
|
all_genders.append(info['gender'])
|
||||||
|
all_ages.append(info['age'])
|
||||||
|
else:
|
||||||
|
all_genders.append('Unknown')
|
||||||
|
all_ages.append(-1)
|
||||||
|
else:
|
||||||
|
all_genders.append('Unknown')
|
||||||
|
all_ages.append(-1)
|
||||||
|
|
||||||
# Create HuggingFace Dataset
|
# Create HuggingFace Dataset
|
||||||
result_dataset = Dataset.from_dict({
|
result_dataset = Dataset.from_dict({
|
||||||
@@ -294,10 +328,17 @@ class Chronos_2_Embedder:
|
|||||||
"session_id": all_session_ids,
|
"session_id": all_session_ids,
|
||||||
"idx": all_idxs,
|
"idx": all_idxs,
|
||||||
"label": all_labels,
|
"label": all_labels,
|
||||||
|
"gender": all_genders,
|
||||||
|
"age": all_ages,
|
||||||
"embedding": all_embeddings,
|
"embedding": all_embeddings,
|
||||||
})
|
})
|
||||||
|
|
||||||
print(f"[INFO] Total samples: {len(result_dataset)}")
|
print(f"[INFO] Total samples: {len(result_dataset)}")
|
||||||
|
if has_metadata:
|
||||||
|
gender_counts = {}
|
||||||
|
for g in all_genders:
|
||||||
|
gender_counts[g] = gender_counts.get(g, 0) + 1
|
||||||
|
print(f"[INFO] Gender distribution in extracted data: {gender_counts}")
|
||||||
return result_dataset
|
return result_dataset
|
||||||
|
|
||||||
def save_embeddings(
|
def save_embeddings(
|
||||||
@@ -460,6 +501,8 @@ class CLI:
|
|||||||
users: str = None,
|
users: str = None,
|
||||||
num_users: int = 0,
|
num_users: int = 0,
|
||||||
labels: str = None,
|
labels: str = None,
|
||||||
|
gender: str = None,
|
||||||
|
metadata_path: str = DEFAULT_METADATA_PATH,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Visualize embeddings with t-SNE.
|
Visualize embeddings with t-SNE.
|
||||||
@@ -470,7 +513,9 @@ class CLI:
|
|||||||
perplexity: t-SNE perplexity parameter (default: 30.0)
|
perplexity: t-SNE perplexity parameter (default: 30.0)
|
||||||
users: Comma-separated user IDs to include (e.g., '00,01,02')
|
users: Comma-separated user IDs to include (e.g., '00,01,02')
|
||||||
num_users: Include only first N users, 0 = all (default: 0)
|
num_users: Include only first N users, 0 = all (default: 0)
|
||||||
labels: Comma-separated sleep stage labels to include (e.g., '0,1,2')
|
labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
|
||||||
|
gender: Filter by gender ('M', 'F', or None for all)
|
||||||
|
metadata_path: Path to metadata file for gender lookup (if not in dataset)
|
||||||
"""
|
"""
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
|
||||||
@@ -489,6 +534,11 @@ class CLI:
|
|||||||
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
|
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
|
||||||
print(f"[INFO] Selected first {num_users} users: {selected_users}")
|
print(f"[INFO] Selected first {num_users} users: {selected_users}")
|
||||||
|
|
||||||
|
# Filter by gender
|
||||||
|
if gender:
|
||||||
|
dataset = dataset.filter(lambda x: x.get("gender", "Unknown") == gender)
|
||||||
|
print(f"[INFO] Filtered to gender: {gender}")
|
||||||
|
|
||||||
# Filter by sleep stage labels
|
# Filter by sleep stage labels
|
||||||
if labels:
|
if labels:
|
||||||
label_list = [l.strip() for l in labels.split(",")]
|
label_list = [l.strip() for l in labels.split(",")]
|
||||||
@@ -497,6 +547,13 @@ class CLI:
|
|||||||
|
|
||||||
print(f"[INFO] Total samples: {len(dataset)}")
|
print(f"[INFO] Total samples: {len(dataset)}")
|
||||||
|
|
||||||
|
# Print gender distribution
|
||||||
|
if "gender" in dataset.column_names:
|
||||||
|
gender_counts = {}
|
||||||
|
for g in dataset["gender"]:
|
||||||
|
gender_counts[g] = gender_counts.get(g, 0) + 1
|
||||||
|
print(f"[INFO] Gender distribution: {gender_counts}")
|
||||||
|
|
||||||
# Extract embeddings as numpy array for t-SNE
|
# Extract embeddings as numpy array for t-SNE
|
||||||
embeddings = np.array(dataset["embedding"])
|
embeddings = np.array(dataset["embedding"])
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,13 @@
|
|||||||
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
||||||
num_seeds: 10
|
num_seeds: 10
|
||||||
|
num_examples: 1
|
||||||
|
|
||||||
|
# Selection criteria: out_random | in_random | out_similar
|
||||||
|
selection_criteria: "out_random"
|
||||||
|
|
||||||
|
# Embedding path (not required for random criteria)
|
||||||
|
# embedding_path: "/home/ssum/tsllm_personalization_icl/embeddings_20users"
|
||||||
|
|
||||||
models:
|
models:
|
||||||
- ollama:url:joy.kaist.ac.kr:11437/gpt-oss:20b
|
- ollama:url:joy.kaist.ac.kr:11437/gpt-oss:20b
|
||||||
- ollama:url:joy.kaist.ac.kr:11438/gpt-oss:20b
|
- ollama:url:joy.kaist.ac.kr:11438/gpt-oss:20b
|
||||||
@@ -17,4 +25,8 @@ models:
|
|||||||
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
|
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
|
||||||
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
|
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
|
||||||
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
|
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
|
||||||
log_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/logs/SleepEDF_out_random"
|
|
||||||
|
# chronos_base : "/mnt/sting/ssum/sleepedf_chronos_base_result"
|
||||||
|
# out_random : "/mnt/sting/ssum/sleepedf_chronos_result_outrandom"
|
||||||
|
# in_random : "/mnt/sting/ssum/sleepedf_chronos_result"
|
||||||
|
log_path: "/mnt/sting/ssum/sleepedf_chronos_result"
|
||||||
@@ -3,17 +3,35 @@ import json
|
|||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
from typing import Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .embedding_index import EmbeddingIndex
|
||||||
|
|
||||||
|
|
||||||
class DataLoader:
|
class DataLoader:
|
||||||
def __init__(self, data_path, user_id, selection_criteria="out_random", num_examples=1):
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_path,
|
||||||
|
user_id,
|
||||||
|
selection_criteria="out_random",
|
||||||
|
num_examples=1,
|
||||||
|
embedding_index: Optional["EmbeddingIndex"] = None,
|
||||||
|
):
|
||||||
self.is_valid = False
|
self.is_valid = False
|
||||||
|
self.embedding_index = embedding_index
|
||||||
|
self.data_path = data_path
|
||||||
|
|
||||||
if not os.path.exists(os.path.join(data_path, "info.json")):
|
if not os.path.exists(os.path.join(data_path, "info.json")):
|
||||||
return
|
return
|
||||||
if not os.path.exists(os.path.join(data_path, f"{user_id}", "1")):
|
if not os.path.exists(os.path.join(data_path, f"{user_id}", "1")):
|
||||||
return
|
return
|
||||||
if not os.path.exists(os.path.join(data_path, f"{user_id}", "2")):
|
if not os.path.exists(os.path.join(data_path, f"{user_id}", "2")):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if selection_criteria in ["out_similar", "in_similar"] and embedding_index is None:
|
||||||
|
print(f"[WARNING] {selection_criteria} requires embedding_index, falling back to random")
|
||||||
|
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
|
||||||
|
|
||||||
self.metadata = json.load(open(os.path.join(data_path, "info.json"), "r", encoding="utf-8"))
|
self.metadata = json.load(open(os.path.join(data_path, "info.json"), "r", encoding="utf-8"))
|
||||||
self.test_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user_id}", "2"))
|
self.test_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user_id}", "2"))
|
||||||
@@ -34,9 +52,19 @@ class DataLoader:
|
|||||||
self.num_examples = num_examples
|
self.num_examples = num_examples
|
||||||
|
|
||||||
self.classes = sorted(list(self.metadata["class"].keys()))
|
self.classes = sorted(list(self.metadata["class"].keys()))
|
||||||
self.selected_examples = self.sample_examples()
|
|
||||||
if self.selected_examples is None:
|
# Build lookup index for fast example retrieval: (user_id, idx) -> dataset_index
|
||||||
return
|
self._example_lookup = {}
|
||||||
|
for i, example in enumerate(self.example_dataset):
|
||||||
|
key = (str(example["user_id"]), int(example["idx"]))
|
||||||
|
self._example_lookup[key] = i
|
||||||
|
|
||||||
|
if selection_criteria in ["out_similar", "in_similar"]:
|
||||||
|
self.selected_examples = None
|
||||||
|
else:
|
||||||
|
self.selected_examples = self.sample_examples()
|
||||||
|
if self.selected_examples is None:
|
||||||
|
return
|
||||||
|
|
||||||
self.is_valid = True
|
self.is_valid = True
|
||||||
|
|
||||||
@@ -44,11 +72,95 @@ class DataLoader:
|
|||||||
return len(self.test_dataset)
|
return len(self.test_dataset)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
return self.test_dataset[idx], self.selected_examples
|
sample = self.test_dataset[idx]
|
||||||
|
if self.selection_criteria in ["out_similar", "in_similar"]:
|
||||||
|
examples = self.sample_similar_examples(sample)
|
||||||
|
else:
|
||||||
|
examples = self.selected_examples
|
||||||
|
return sample, examples
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
for sample in self.test_dataset:
|
for sample in self.test_dataset:
|
||||||
yield sample, self.selected_examples
|
if self.selection_criteria in ["out_similar", "in_similar"]:
|
||||||
|
examples = self.sample_similar_examples(sample)
|
||||||
|
else:
|
||||||
|
examples = self.selected_examples
|
||||||
|
yield sample, examples
|
||||||
|
|
||||||
|
def sample_similar_examples(self, sample):
|
||||||
|
"""
|
||||||
|
Sample examples based on embedding similarity to the given sample.
|
||||||
|
|
||||||
|
For each class, finds the most similar example from the example dataset
|
||||||
|
using Chronos-2 embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample: The test sample to find similar examples for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HuggingFace Dataset containing similar examples (one per class)
|
||||||
|
"""
|
||||||
|
if self.embedding_index is None:
|
||||||
|
return self.sample_examples()
|
||||||
|
|
||||||
|
query_embedding = self.embedding_index.get_embedding_by_key(
|
||||||
|
user_id=str(sample["user_id"]),
|
||||||
|
session_id=str(sample["session_id"]),
|
||||||
|
idx=int(sample["idx"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
if query_embedding is None:
|
||||||
|
print(f"[WARNING] No embedding found for sample {sample['user_id']}/{sample['session_id']}/{sample['idx']}")
|
||||||
|
return self.sample_examples()
|
||||||
|
|
||||||
|
if self.selection_criteria == "out_similar":
|
||||||
|
exclude_user = str(self.user_id)
|
||||||
|
include_user = None
|
||||||
|
else:
|
||||||
|
exclude_user = None
|
||||||
|
include_user = str(self.user_id)
|
||||||
|
|
||||||
|
similar_per_class = self.embedding_index.find_similar_per_class(
|
||||||
|
query_embedding=query_embedding,
|
||||||
|
classes=self.classes,
|
||||||
|
k_per_class=self.num_examples,
|
||||||
|
exclude_user=exclude_user,
|
||||||
|
include_user=include_user,
|
||||||
|
filter_session="1",
|
||||||
|
)
|
||||||
|
|
||||||
|
example_list = []
|
||||||
|
for cls, similar_samples in similar_per_class.items():
|
||||||
|
for global_idx, similarity, metadata in similar_samples:
|
||||||
|
example = self._find_example_by_metadata(metadata)
|
||||||
|
if example is not None:
|
||||||
|
example_list.append(example)
|
||||||
|
|
||||||
|
if len(example_list) == 0:
|
||||||
|
print(f"[WARNING] No similar examples found, falling back to random")
|
||||||
|
return self.sample_examples()
|
||||||
|
|
||||||
|
return datasets.Dataset.from_list(example_list)
|
||||||
|
|
||||||
|
def _find_example_by_metadata(self, metadata):
|
||||||
|
"""Find an example in the example_dataset by its metadata using O(1) lookup."""
|
||||||
|
user_id = str(metadata["user_id"])
|
||||||
|
idx = int(metadata["idx"])
|
||||||
|
|
||||||
|
key = (user_id, idx)
|
||||||
|
if key not in self._example_lookup:
|
||||||
|
return None
|
||||||
|
|
||||||
|
dataset_idx = self._example_lookup[key]
|
||||||
|
example = self.example_dataset[dataset_idx]
|
||||||
|
return {
|
||||||
|
"user_id": example["user_id"],
|
||||||
|
"session_id": example["session_id"],
|
||||||
|
"idx": example["idx"],
|
||||||
|
"label": example["label"],
|
||||||
|
"features": example["features"],
|
||||||
|
"data": example.get("data", {}),
|
||||||
|
}
|
||||||
|
|
||||||
def sample_examples(self):
|
def sample_examples(self):
|
||||||
example_dataset = datasets.Dataset.from_list([])
|
example_dataset = datasets.Dataset.from_list([])
|
||||||
|
|||||||
190
run.py
190
run.py
@@ -5,7 +5,7 @@ import yaml
|
|||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any, Optional
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from glob import glob
|
from glob import glob
|
||||||
@@ -14,11 +14,47 @@ from fire import Fire
|
|||||||
from core.model import load_models
|
from core.model import load_models
|
||||||
from core.data_loader import DataLoader
|
from core.data_loader import DataLoader
|
||||||
from core.sensing_agent import SensingAgent
|
from core.sensing_agent import SensingAgent
|
||||||
|
from core.embedding_index import EmbeddingIndex, create_embedding_index
|
||||||
|
|
||||||
|
_EMBEDDING_INDEX: Optional[EmbeddingIndex] = None
|
||||||
|
|
||||||
|
def init_embedding_index(embedding_path: Optional[str]) -> Optional[EmbeddingIndex]:
|
||||||
|
"""Initialize global embedding index for similarity-based selection."""
|
||||||
|
global _EMBEDDING_INDEX
|
||||||
|
|
||||||
|
if embedding_path is None:
|
||||||
|
_EMBEDDING_INDEX = None
|
||||||
|
return None
|
||||||
|
|
||||||
|
if _EMBEDDING_INDEX is not None:
|
||||||
|
return _EMBEDDING_INDEX
|
||||||
|
|
||||||
|
_EMBEDDING_INDEX = create_embedding_index(embedding_path)
|
||||||
|
return _EMBEDDING_INDEX
|
||||||
|
|
||||||
|
|
||||||
def load_user_data_sync(data_path, user, seed, log_path_base):
|
def load_user_data_sync(
|
||||||
print(f"[DATA LOADING] Starting data loading for user: {user}, seed: {seed}")
|
data_path: str,
|
||||||
data_loader = DataLoader(data_path, user, selection_criteria="out_random", num_examples=1)
|
user: str,
|
||||||
|
seed: int,
|
||||||
|
log_path_base: str,
|
||||||
|
selection_criteria: str = "out_random",
|
||||||
|
num_examples: int = 1,
|
||||||
|
embedding_index: Optional[EmbeddingIndex] = None,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Load data for a single user with specified selection criteria."""
|
||||||
|
print(f"[DATA LOADING] Starting: user={user}, seed={seed}, criteria={selection_criteria}")
|
||||||
|
|
||||||
|
# Set random seed for reproducibility
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
data_loader = DataLoader(
|
||||||
|
data_path,
|
||||||
|
user,
|
||||||
|
selection_criteria=selection_criteria,
|
||||||
|
num_examples=num_examples,
|
||||||
|
embedding_index=embedding_index,
|
||||||
|
)
|
||||||
if not data_loader.is_valid:
|
if not data_loader.is_valid:
|
||||||
print(f"[DATA LOADING] Skipping invalid user: {user}")
|
print(f"[DATA LOADING] Skipping invalid user: {user}")
|
||||||
return []
|
return []
|
||||||
@@ -52,13 +88,36 @@ def load_user_data_sync(data_path, user, seed, log_path_base):
|
|||||||
return tasks
|
return tasks
|
||||||
|
|
||||||
|
|
||||||
async def load_data_parallel(config):
|
async def load_data_parallel(config: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
|
"""Load data for all users in parallel with configurable selection criteria."""
|
||||||
print("[DATA LOADING] Starting parallel data loading...")
|
print("[DATA LOADING] Starting parallel data loading...")
|
||||||
|
|
||||||
|
# Get configuration parameters
|
||||||
|
selection_criteria = config.get("selection_criteria", "out_random")
|
||||||
|
embedding_path = config.get("embedding_path", None)
|
||||||
|
num_examples = config.get("num_examples", 1)
|
||||||
|
|
||||||
|
# Initialize embedding index if needed for similarity-based selection
|
||||||
|
embedding_index = None
|
||||||
|
if selection_criteria in ["out_similar", "in_similar"]:
|
||||||
|
if embedding_path is None:
|
||||||
|
print(f"[WARNING] {selection_criteria} requires embedding_path in config")
|
||||||
|
print("[WARNING] Falling back to random selection")
|
||||||
|
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
|
||||||
|
else:
|
||||||
|
embedding_index = init_embedding_index(embedding_path)
|
||||||
|
if embedding_index is None:
|
||||||
|
print(f"[WARNING] Failed to load embeddings, falling back to random")
|
||||||
|
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
|
||||||
|
|
||||||
user_paths = glob(os.path.join(config["data_path"], "*"))
|
user_paths = glob(os.path.join(config["data_path"], "*"))
|
||||||
user_paths = [path for path in user_paths if os.path.isdir(path)]
|
user_paths = [path for path in user_paths if os.path.isdir(path)]
|
||||||
users = [path.split("/")[-1] for path in user_paths]
|
users = [path.split("/")[-1] for path in user_paths]
|
||||||
|
|
||||||
print(f"[DATA LOADING] Found {len(users)} users: {users}")
|
print(f"[DATA LOADING] Found {len(users)} users: {users}")
|
||||||
|
print(f"[DATA LOADING] Selection criteria: {selection_criteria}")
|
||||||
|
print(f"[DATA LOADING] Num examples per class: {num_examples}")
|
||||||
|
|
||||||
max_workers = config.get("data_workers", 96)
|
max_workers = config.get("data_workers", 96)
|
||||||
print(f"[DATA LOADING] Using {max_workers} workers for data loading")
|
print(f"[DATA LOADING] Using {max_workers} workers for data loading")
|
||||||
|
|
||||||
@@ -74,7 +133,10 @@ async def load_data_parallel(config):
|
|||||||
config["data_path"],
|
config["data_path"],
|
||||||
user,
|
user,
|
||||||
seed,
|
seed,
|
||||||
config["log_path"]
|
config["log_path"],
|
||||||
|
selection_criteria,
|
||||||
|
num_examples,
|
||||||
|
embedding_index,
|
||||||
)
|
)
|
||||||
futures.append(future)
|
futures.append(future)
|
||||||
|
|
||||||
@@ -89,10 +151,14 @@ async def load_data_parallel(config):
|
|||||||
else:
|
else:
|
||||||
all_tasks.extend(result)
|
all_tasks.extend(result)
|
||||||
|
|
||||||
|
print(f"[DATA LOADING] Total tasks: {len(all_tasks)}")
|
||||||
return all_tasks
|
return all_tasks
|
||||||
|
|
||||||
|
|
||||||
async def run_parallel(kwargs_list, model_pool, config):
|
async def run_parallel(kwargs_list: List[Dict[str, Any]], model_pool, config: Dict[str, Any]) -> None:
|
||||||
|
"""Run classification tasks in parallel using the model pool."""
|
||||||
|
print(f"[EXECUTION] Starting {len(kwargs_list)} classification tasks...")
|
||||||
|
|
||||||
tasks = []
|
tasks = []
|
||||||
for kwargs in kwargs_list:
|
for kwargs in kwargs_list:
|
||||||
agent = SensingAgent(
|
agent = SensingAgent(
|
||||||
@@ -108,15 +174,121 @@ async def run_parallel(kwargs_list, model_pool, config):
|
|||||||
task = asyncio.create_task(agent.solve(kwargs["sample"], kwargs["examples"], kwargs["ground_truth"]))
|
task = asyncio.create_task(agent.solve(kwargs["sample"], kwargs["examples"], kwargs["ground_truth"]))
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
print("[EXECUTION] All tasks completed")
|
||||||
|
|
||||||
|
|
||||||
def run(config_path):
|
def run(config_path: str) -> None:
|
||||||
|
"""
|
||||||
|
Main entry point for running experiments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: Path to YAML configuration file
|
||||||
|
"""
|
||||||
|
print(f"[MAIN] Loading config from: {config_path}")
|
||||||
config = yaml.load(open(config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
|
config = yaml.load(open(config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("EXPERIMENT CONFIGURATION")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f" Data path: {config.get('data_path', 'N/A')}")
|
||||||
|
print(f" Log path: {config.get('log_path', 'N/A')}")
|
||||||
|
print(f" Selection criteria: {config.get('selection_criteria', 'out_random')}")
|
||||||
|
print(f" Num examples: {config.get('num_examples', 1)}")
|
||||||
|
print(f" Num seeds: {config.get('num_seeds', 1)}")
|
||||||
|
print(f" Embedding path: {config.get('embedding_path', 'N/A')}")
|
||||||
|
print(f" Num models: {len(config.get('models', []))}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
model_pool = load_models(config["models"])
|
model_pool = load_models(config["models"])
|
||||||
|
|
||||||
kwargs_list = asyncio.run(load_data_parallel(config))
|
kwargs_list = asyncio.run(load_data_parallel(config))
|
||||||
|
|
||||||
|
if len(kwargs_list) == 0:
|
||||||
|
print("[ERROR] No valid tasks to run. Check data paths and configuration.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Warmup models
|
||||||
|
print("[MAIN] Warming up models...")
|
||||||
model_pool.warmup()
|
model_pool.warmup()
|
||||||
|
|
||||||
|
# Run experiments
|
||||||
|
print("[MAIN] Starting experiments...")
|
||||||
|
start_time = time.time()
|
||||||
asyncio.run(run_parallel(kwargs_list, model_pool, config))
|
asyncio.run(run_parallel(kwargs_list, model_pool, config))
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
|
print(f"[MAIN] Experiment completed in {elapsed:.2f} seconds")
|
||||||
|
print(f"[MAIN] Results saved to: {config['log_path']}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_comparison(
|
||||||
|
base_config_path: str,
|
||||||
|
criteria_list: str = "out_random,in_random,out_similar,in_similar",
|
||||||
|
embedding_path: str = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Run experiments comparing multiple selection criteria.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_config_path: Path to base YAML config file
|
||||||
|
criteria_list: Comma-separated list of selection criteria to compare
|
||||||
|
embedding_path: Path to embeddings (required for *_similar criteria)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
python run.py compare config/sleepedf.yaml \\
|
||||||
|
--criteria_list="out_random,out_similar" \\
|
||||||
|
--embedding_path="./embeddings_full"
|
||||||
|
"""
|
||||||
|
|
||||||
|
base_config = yaml.load(open(base_config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
|
||||||
|
criteria = [c.strip() for c in criteria_list.split(",")]
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("COMPARISON EXPERIMENT")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f" Selection criteria to compare: {criteria}")
|
||||||
|
print(f" Embedding path: {embedding_path}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
for criterion in criteria:
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"Running experiment: {criterion}")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
|
||||||
|
config = base_config.copy()
|
||||||
|
config["selection_criteria"] = criterion
|
||||||
|
|
||||||
|
base_log_path = config["log_path"]
|
||||||
|
if not base_log_path.endswith(criterion):
|
||||||
|
config["log_path"] = f"{base_log_path}_{criterion}"
|
||||||
|
|
||||||
|
if criterion in ["out_similar", "in_similar"]:
|
||||||
|
if embedding_path:
|
||||||
|
config["embedding_path"] = embedding_path
|
||||||
|
else:
|
||||||
|
print(f"[WARNING] {criterion} requires --embedding_path argument")
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_pool = load_models(config["models"])
|
||||||
|
|
||||||
|
kwargs_list = asyncio.run(load_data_parallel(config))
|
||||||
|
|
||||||
|
if len(kwargs_list) == 0:
|
||||||
|
print(f"[WARNING] No tasks for {criterion}, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_pool.warmup()
|
||||||
|
asyncio.run(run_parallel(kwargs_list, model_pool, config))
|
||||||
|
|
||||||
|
print(f"[DONE] Results for {criterion} saved to: {config['log_path']}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("ALL COMPARISON EXPERIMENTS COMPLETED")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
Fire(run)
|
Fire({
|
||||||
|
"run": run,
|
||||||
|
"compare": run_comparison,
|
||||||
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user