260119:chronos_2_based_embedding_running
This commit is contained in:
@@ -37,8 +37,9 @@ Date: 2026-01-09
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from glob import glob
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from typing import Dict, Any, List, Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -48,6 +49,10 @@ from chronos import BaseChronosPipeline, Chronos2Pipeline
|
||||
from fire import Fire
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
# Add parent directories to path for importing metadata loader
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
from preprocess.sleepedf_metadata import SleepEDFMetadata, DEFAULT_METADATA_PATH
|
||||
|
||||
# pipeline: Chronos2Pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map="cuda")
|
||||
|
||||
# =============================================================================
|
||||
@@ -234,7 +239,8 @@ class Chronos_2_Embedder:
|
||||
def extract_embeddings(
|
||||
self,
|
||||
data_root: str,
|
||||
batch_size: int = 32
|
||||
batch_size: int = 32,
|
||||
metadata_path: str = DEFAULT_METADATA_PATH,
|
||||
) -> Dataset:
|
||||
"""
|
||||
Extract embeddings from all sessions under the data root directory.
|
||||
@@ -247,20 +253,34 @@ class Chronos_2_Embedder:
|
||||
batch_size: Number of samples to process together.
|
||||
Larger = faster but more memory.
|
||||
32 is a good balance for most GPUs.
|
||||
metadata_path: Path to SC-subjects.xls for gender/age info
|
||||
|
||||
Returns:
|
||||
HuggingFace Dataset with columns:
|
||||
- user_id, session_id, idx, label (metadata)
|
||||
- gender, age (demographic metadata)
|
||||
- embedding (vector; dim depends on variate_fusion: 1024 for concat, 512 for mean)
|
||||
"""
|
||||
session_paths = self.discover_session_paths(data_root)
|
||||
print(f"[INFO] Discovered {len(session_paths)} sessions")
|
||||
|
||||
# Load metadata for gender/age information
|
||||
try:
|
||||
metadata = SleepEDFMetadata(metadata_path)
|
||||
has_metadata = True
|
||||
except FileNotFoundError:
|
||||
print(f"[WARNING] Metadata file not found: {metadata_path}")
|
||||
print(f"[WARNING] Gender/age information will not be available")
|
||||
metadata = None
|
||||
has_metadata = False
|
||||
|
||||
all_embeddings = []
|
||||
all_user_ids = []
|
||||
all_session_ids = []
|
||||
all_idxs = []
|
||||
all_labels = []
|
||||
all_genders = []
|
||||
all_ages = []
|
||||
|
||||
for user_id, session_id, session_path in session_paths:
|
||||
# Load HuggingFace dataset from disk
|
||||
@@ -282,22 +302,43 @@ class Chronos_2_Embedder:
|
||||
|
||||
# Collect embeddings and metadata
|
||||
for i in range(embeddings.shape[0]):
|
||||
user_id_str = str(batch["user_id"][i])
|
||||
all_embeddings.append(embeddings[i].tolist())
|
||||
all_user_ids.append(str(batch["user_id"][i]))
|
||||
all_user_ids.append(user_id_str)
|
||||
all_session_ids.append(str(batch["session_id"][i]))
|
||||
all_idxs.append(int(batch["idx"][i]))
|
||||
all_labels.append(str(batch["label"][i]))
|
||||
|
||||
# Add demographic metadata
|
||||
if has_metadata:
|
||||
info = metadata.get_info(user_id_str)
|
||||
if info:
|
||||
all_genders.append(info['gender'])
|
||||
all_ages.append(info['age'])
|
||||
else:
|
||||
all_genders.append('Unknown')
|
||||
all_ages.append(-1)
|
||||
else:
|
||||
all_genders.append('Unknown')
|
||||
all_ages.append(-1)
|
||||
|
||||
# Create HuggingFace Dataset
|
||||
result_dataset = Dataset.from_dict({
|
||||
"user_id": all_user_ids,
|
||||
"session_id": all_session_ids,
|
||||
"idx": all_idxs,
|
||||
"label": all_labels,
|
||||
"gender": all_genders,
|
||||
"age": all_ages,
|
||||
"embedding": all_embeddings,
|
||||
})
|
||||
|
||||
print(f"[INFO] Total samples: {len(result_dataset)}")
|
||||
if has_metadata:
|
||||
gender_counts = {}
|
||||
for g in all_genders:
|
||||
gender_counts[g] = gender_counts.get(g, 0) + 1
|
||||
print(f"[INFO] Gender distribution in extracted data: {gender_counts}")
|
||||
return result_dataset
|
||||
|
||||
def save_embeddings(
|
||||
@@ -460,6 +501,8 @@ class CLI:
|
||||
users: str = None,
|
||||
num_users: int = 0,
|
||||
labels: str = None,
|
||||
gender: str = None,
|
||||
metadata_path: str = DEFAULT_METADATA_PATH,
|
||||
) -> None:
|
||||
"""
|
||||
Visualize embeddings with t-SNE.
|
||||
@@ -470,7 +513,9 @@ class CLI:
|
||||
perplexity: t-SNE perplexity parameter (default: 30.0)
|
||||
users: Comma-separated user IDs to include (e.g., '00,01,02')
|
||||
num_users: Include only first N users, 0 = all (default: 0)
|
||||
labels: Comma-separated sleep stage labels to include (e.g., '0,1,2')
|
||||
labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
|
||||
gender: Filter by gender ('M', 'F', or None for all)
|
||||
metadata_path: Path to metadata file for gender lookup (if not in dataset)
|
||||
"""
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
@@ -489,6 +534,11 @@ class CLI:
|
||||
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
|
||||
print(f"[INFO] Selected first {num_users} users: {selected_users}")
|
||||
|
||||
# Filter by gender
|
||||
if gender:
|
||||
dataset = dataset.filter(lambda x: x.get("gender", "Unknown") == gender)
|
||||
print(f"[INFO] Filtered to gender: {gender}")
|
||||
|
||||
# Filter by sleep stage labels
|
||||
if labels:
|
||||
label_list = [l.strip() for l in labels.split(",")]
|
||||
@@ -497,6 +547,13 @@ class CLI:
|
||||
|
||||
print(f"[INFO] Total samples: {len(dataset)}")
|
||||
|
||||
# Print gender distribution
|
||||
if "gender" in dataset.column_names:
|
||||
gender_counts = {}
|
||||
for g in dataset["gender"]:
|
||||
gender_counts[g] = gender_counts.get(g, 0) + 1
|
||||
print(f"[INFO] Gender distribution: {gender_counts}")
|
||||
|
||||
# Extract embeddings as numpy array for t-SNE
|
||||
embeddings = np.array(dataset["embedding"])
|
||||
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
||||
num_seeds: 10
|
||||
num_examples: 1
|
||||
|
||||
# Selection criteria: out_random | in_random | out_similar | in_similar
|
||||
selection_criteria: "out_random"
|
||||
|
||||
# Embedding path (not required for random criteria)
|
||||
# embedding_path: "/home/ssum/tsllm_personalization_icl/embeddings_20users"
|
||||
|
||||
models:
|
||||
- ollama:url:joy.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:joy.kaist.ac.kr:11438/gpt-oss:20b
|
||||
@@ -17,4 +25,8 @@ models:
|
||||
- ollama:url:iu.kaist.ac.kr:11442/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11443/gpt-oss:20b
|
||||
- ollama:url:iu.kaist.ac.kr:11444/gpt-oss:20b
|
||||
log_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/logs/SleepEDF_out_random"
|
||||
|
||||
# chronos_base : "/mnt/sting/ssum/sleepedf_chronos_base_result"
|
||||
# out_random : "/mnt/sting/ssum/sleepedf_chronos_result_outrandom"
|
||||
# in_random : "/mnt/sting/ssum/sleepedf_chronos_result"
|
||||
log_path: "/mnt/sting/ssum/sleepedf_chronos_result"
|
||||
@@ -3,11 +3,25 @@ import json
|
||||
import datasets
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .embedding_index import EmbeddingIndex
|
||||
|
||||
|
||||
class DataLoader:
|
||||
def __init__(self, data_path, user_id, selection_criteria="out_random", num_examples=1):
|
||||
def __init__(
|
||||
self,
|
||||
data_path,
|
||||
user_id,
|
||||
selection_criteria="out_random",
|
||||
num_examples=1,
|
||||
embedding_index: Optional["EmbeddingIndex"] = None,
|
||||
):
|
||||
self.is_valid = False
|
||||
self.embedding_index = embedding_index
|
||||
self.data_path = data_path
|
||||
|
||||
if not os.path.exists(os.path.join(data_path, "info.json")):
|
||||
return
|
||||
if not os.path.exists(os.path.join(data_path, f"{user_id}", "1")):
|
||||
@@ -15,6 +29,10 @@ class DataLoader:
|
||||
if not os.path.exists(os.path.join(data_path, f"{user_id}", "2")):
|
||||
return
|
||||
|
||||
if selection_criteria in ["out_similar", "in_similar"] and embedding_index is None:
|
||||
print(f"[WARNING] {selection_criteria} requires embedding_index, falling back to random")
|
||||
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
|
||||
|
||||
self.metadata = json.load(open(os.path.join(data_path, "info.json"), "r", encoding="utf-8"))
|
||||
self.test_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user_id}", "2"))
|
||||
self.example_dataset = datasets.Dataset.from_list([])
|
||||
@@ -34,9 +52,19 @@ class DataLoader:
|
||||
self.num_examples = num_examples
|
||||
|
||||
self.classes = sorted(list(self.metadata["class"].keys()))
|
||||
self.selected_examples = self.sample_examples()
|
||||
if self.selected_examples is None:
|
||||
return
|
||||
|
||||
# Build lookup index for fast example retrieval: (user_id, idx) -> dataset_index
|
||||
self._example_lookup = {}
|
||||
for i, example in enumerate(self.example_dataset):
|
||||
key = (str(example["user_id"]), int(example["idx"]))
|
||||
self._example_lookup[key] = i
|
||||
|
||||
if selection_criteria in ["out_similar", "in_similar"]:
|
||||
self.selected_examples = None
|
||||
else:
|
||||
self.selected_examples = self.sample_examples()
|
||||
if self.selected_examples is None:
|
||||
return
|
||||
|
||||
self.is_valid = True
|
||||
|
||||
@@ -44,11 +72,95 @@ class DataLoader:
|
||||
return len(self.test_dataset)
|
||||
|
||||
def __getitem__(self, idx):
    """Return the test sample at ``idx`` together with its examples.

    For the similarity-based criteria (``out_similar`` / ``in_similar``) the
    examples are chosen per sample through ``sample_similar_examples``. For
    any other criterion the examples stored in ``self.selected_examples`` are
    returned unchanged.

    Args:
        idx: Index into ``self.test_dataset``.

    Returns:
        Tuple ``(sample, examples)``.
    """
    # A leftover early ``return`` used to sit here and made the
    # similarity branch below unreachable.
    sample = self.test_dataset[idx]
    if self.selection_criteria in ("out_similar", "in_similar"):
        examples = self.sample_similar_examples(sample)
    else:
        examples = self.selected_examples
    return sample, examples
|
||||
|
||||
def __iter__(self):
    """Yield ``(sample, examples)`` once for every test sample.

    For the similarity-based criteria (``out_similar`` / ``in_similar``) the
    examples are chosen per sample through ``sample_similar_examples``. For
    any other criterion the examples stored in ``self.selected_examples`` are
    yielded with every sample.
    """
    per_sample = self.selection_criteria in ("out_similar", "in_similar")
    for sample in self.test_dataset:
        # A leftover unconditional ``yield`` used to precede this branch and
        # emitted every sample twice.
        if per_sample:
            examples = self.sample_similar_examples(sample)
        else:
            examples = self.selected_examples
        yield sample, examples
|
||||
|
||||
def sample_similar_examples(self, sample):
    """
    Select examples by embedding similarity to the given test sample.

    The sample's embedding is looked up in ``self.embedding_index`` by
    (user_id, session_id, idx). The index is then asked for up to
    ``self.num_examples`` neighbours per class, restricted to session "1".
    With ``out_similar`` the current user is excluded from the search;
    otherwise the search is limited to the current user.

    Falls back to ``self.sample_examples()`` when there is no embedding
    index, when the sample has no embedding, or when none of the returned
    neighbours can be resolved in the example dataset.

    Args:
        sample: The test sample to find similar examples for; must provide
            "user_id", "session_id" and "idx".

    Returns:
        HuggingFace Dataset built from the resolved neighbours, or whatever
        ``self.sample_examples()`` returns on fallback.
    """
    if self.embedding_index is None:
        return self.sample_examples()

    query_embedding = self.embedding_index.get_embedding_by_key(
        user_id=str(sample["user_id"]),
        session_id=str(sample["session_id"]),
        idx=int(sample["idx"]),
    )
    if query_embedding is None:
        print(f"[WARNING] No embedding found for sample {sample['user_id']}/{sample['session_id']}/{sample['idx']}")
        return self.sample_examples()

    own_user = str(self.user_id)
    if self.selection_criteria == "out_similar":
        exclude_user, include_user = own_user, None
    else:
        exclude_user, include_user = None, own_user

    similar_per_class = self.embedding_index.find_similar_per_class(
        query_embedding=query_embedding,
        classes=self.classes,
        k_per_class=self.num_examples,
        exclude_user=exclude_user,
        include_user=include_user,
        filter_session="1",
    )

    # Only the per-neighbour metadata is needed; the class key, the global
    # index and the similarity score are not used here.
    example_list = []
    for similar_samples in similar_per_class.values():
        for _global_idx, _similarity, metadata in similar_samples:
            example = self._find_example_by_metadata(metadata)
            if example is not None:
                example_list.append(example)

    if not example_list:
        print("[WARNING] No similar examples found, falling back to random")
        return self.sample_examples()

    return datasets.Dataset.from_list(example_list)
|
||||
|
||||
def _find_example_by_metadata(self, metadata):
    """Resolve neighbour metadata to a row of ``self.example_dataset``.

    The row position is taken from ``self._example_lookup``, which is keyed
    by ``(str(user_id), int(idx))``.

    Args:
        metadata: Mapping that provides "user_id" and "idx".

    Returns:
        A dict with the row's "user_id", "session_id", "idx", "label",
        "features" and "data" ("data" defaults to ``{}`` when the row has
        none), or ``None`` when the key is not in the lookup.
    """
    key = (str(metadata["user_id"]), int(metadata["idx"]))

    # Single dictionary access; compare against None explicitly because
    # row position 0 is a valid hit.
    dataset_idx = self._example_lookup.get(key)
    if dataset_idx is None:
        return None

    example = self.example_dataset[dataset_idx]
    return {
        "user_id": example["user_id"],
        "session_id": example["session_id"],
        "idx": example["idx"],
        "label": example["label"],
        "features": example["features"],
        "data": example.get("data", {}),
    }
|
||||
|
||||
def sample_examples(self):
|
||||
example_dataset = datasets.Dataset.from_list([])
|
||||
|
||||
190
run.py
190
run.py
@@ -5,7 +5,7 @@ import yaml
|
||||
import json
|
||||
import numpy as np
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Dict, Any
|
||||
from typing import List, Dict, Any, Optional
|
||||
import time
|
||||
|
||||
from glob import glob
|
||||
@@ -14,11 +14,47 @@ from fire import Fire
|
||||
from core.model import load_models
|
||||
from core.data_loader import DataLoader
|
||||
from core.sensing_agent import SensingAgent
|
||||
from core.embedding_index import EmbeddingIndex, create_embedding_index
|
||||
|
||||
# Process-wide cache of the loaded embedding index and the path it came from.
_EMBEDDING_INDEX: Optional["EmbeddingIndex"] = None
_EMBEDDING_INDEX_PATH: Optional[str] = None


def init_embedding_index(embedding_path: Optional[str]) -> Optional["EmbeddingIndex"]:
    """Initialize global embedding index for similarity-based selection.

    Args:
        embedding_path: Location handed to ``create_embedding_index``.
            ``None`` clears the cached index.

    Returns:
        The cached index when one was already loaded from the same
        ``embedding_path``; otherwise the result of
        ``create_embedding_index(embedding_path)``. ``None`` when
        ``embedding_path`` is ``None``.
    """
    global _EMBEDDING_INDEX, _EMBEDDING_INDEX_PATH

    if embedding_path is None:
        _EMBEDDING_INDEX = None
        _EMBEDDING_INDEX_PATH = None
        return None

    # Reuse the cache only for the path it was built from; previously a
    # request for a different path silently returned the stale index.
    if _EMBEDDING_INDEX is not None and _EMBEDDING_INDEX_PATH == embedding_path:
        return _EMBEDDING_INDEX

    _EMBEDDING_INDEX = create_embedding_index(embedding_path)
    _EMBEDDING_INDEX_PATH = embedding_path
    return _EMBEDDING_INDEX
|
||||
|
||||
|
||||
def load_user_data_sync(data_path, user, seed, log_path_base):
|
||||
print(f"[DATA LOADING] Starting data loading for user: {user}, seed: {seed}")
|
||||
data_loader = DataLoader(data_path, user, selection_criteria="out_random", num_examples=1)
|
||||
def load_user_data_sync(
|
||||
data_path: str,
|
||||
user: str,
|
||||
seed: int,
|
||||
log_path_base: str,
|
||||
selection_criteria: str = "out_random",
|
||||
num_examples: int = 1,
|
||||
embedding_index: Optional[EmbeddingIndex] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Load data for a single user with specified selection criteria."""
|
||||
print(f"[DATA LOADING] Starting: user={user}, seed={seed}, criteria={selection_criteria}")
|
||||
|
||||
# Set random seed for reproducibility
|
||||
np.random.seed(seed)
|
||||
|
||||
data_loader = DataLoader(
|
||||
data_path,
|
||||
user,
|
||||
selection_criteria=selection_criteria,
|
||||
num_examples=num_examples,
|
||||
embedding_index=embedding_index,
|
||||
)
|
||||
if not data_loader.is_valid:
|
||||
print(f"[DATA LOADING] Skipping invalid user: {user}")
|
||||
return []
|
||||
@@ -52,13 +88,36 @@ def load_user_data_sync(data_path, user, seed, log_path_base):
|
||||
return tasks
|
||||
|
||||
|
||||
async def load_data_parallel(config):
|
||||
async def load_data_parallel(config: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Load data for all users in parallel with configurable selection criteria."""
|
||||
print("[DATA LOADING] Starting parallel data loading...")
|
||||
|
||||
# Get configuration parameters
|
||||
selection_criteria = config.get("selection_criteria", "out_random")
|
||||
embedding_path = config.get("embedding_path", None)
|
||||
num_examples = config.get("num_examples", 1)
|
||||
|
||||
# Initialize embedding index if needed for similarity-based selection
|
||||
embedding_index = None
|
||||
if selection_criteria in ["out_similar", "in_similar"]:
|
||||
if embedding_path is None:
|
||||
print(f"[WARNING] {selection_criteria} requires embedding_path in config")
|
||||
print("[WARNING] Falling back to random selection")
|
||||
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
|
||||
else:
|
||||
embedding_index = init_embedding_index(embedding_path)
|
||||
if embedding_index is None:
|
||||
print(f"[WARNING] Failed to load embeddings, falling back to random")
|
||||
selection_criteria = "out_random" if "out" in selection_criteria else "in_random"
|
||||
|
||||
user_paths = glob(os.path.join(config["data_path"], "*"))
|
||||
user_paths = [path for path in user_paths if os.path.isdir(path)]
|
||||
users = [path.split("/")[-1] for path in user_paths]
|
||||
|
||||
print(f"[DATA LOADING] Found {len(users)} users: {users}")
|
||||
print(f"[DATA LOADING] Selection criteria: {selection_criteria}")
|
||||
print(f"[DATA LOADING] Num examples per class: {num_examples}")
|
||||
|
||||
max_workers = config.get("data_workers", 96)
|
||||
print(f"[DATA LOADING] Using {max_workers} workers for data loading")
|
||||
|
||||
@@ -74,7 +133,10 @@ async def load_data_parallel(config):
|
||||
config["data_path"],
|
||||
user,
|
||||
seed,
|
||||
config["log_path"]
|
||||
config["log_path"],
|
||||
selection_criteria,
|
||||
num_examples,
|
||||
embedding_index,
|
||||
)
|
||||
futures.append(future)
|
||||
|
||||
@@ -89,10 +151,14 @@ async def load_data_parallel(config):
|
||||
else:
|
||||
all_tasks.extend(result)
|
||||
|
||||
print(f"[DATA LOADING] Total tasks: {len(all_tasks)}")
|
||||
return all_tasks
|
||||
|
||||
|
||||
async def run_parallel(kwargs_list, model_pool, config):
|
||||
async def run_parallel(kwargs_list: List[Dict[str, Any]], model_pool, config: Dict[str, Any]) -> None:
|
||||
"""Run classification tasks in parallel using the model pool."""
|
||||
print(f"[EXECUTION] Starting {len(kwargs_list)} classification tasks...")
|
||||
|
||||
tasks = []
|
||||
for kwargs in kwargs_list:
|
||||
agent = SensingAgent(
|
||||
@@ -108,15 +174,121 @@ async def run_parallel(kwargs_list, model_pool, config):
|
||||
task = asyncio.create_task(agent.solve(kwargs["sample"], kwargs["examples"], kwargs["ground_truth"]))
|
||||
tasks.append(task)
|
||||
await asyncio.gather(*tasks)
|
||||
print("[EXECUTION] All tasks completed")
|
||||
|
||||
|
||||
def run(config_path: str) -> None:
    """
    Main entry point for running experiments.

    Loads the YAML config, prints a summary, builds the model pool, loads
    the tasks with ``load_data_parallel`` and, when there is at least one
    task, warms up the models and runs ``run_parallel``.

    Args:
        config_path: Path to YAML configuration file
    """
    print(f"[MAIN] Loading config from: {config_path}")
    # The context manager closes the file handle, which a bare open() call
    # would leave to the garbage collector.
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    banner = "=" * 60
    print(banner)
    print("EXPERIMENT CONFIGURATION")
    print(banner)
    print(f" Data path: {config.get('data_path', 'N/A')}")
    print(f" Log path: {config.get('log_path', 'N/A')}")
    print(f" Selection criteria: {config.get('selection_criteria', 'out_random')}")
    print(f" Num examples: {config.get('num_examples', 1)}")
    print(f" Num seeds: {config.get('num_seeds', 1)}")
    print(f" Embedding path: {config.get('embedding_path', 'N/A')}")
    print(f" Num models: {len(config.get('models', []))}")
    print(banner)

    model_pool = load_models(config["models"])

    kwargs_list = asyncio.run(load_data_parallel(config))

    if not kwargs_list:
        print("[ERROR] No valid tasks to run. Check data paths and configuration.")
        return

    # Warmup models
    print("[MAIN] Warming up models...")
    model_pool.warmup()

    # Run experiments
    print("[MAIN] Starting experiments...")
    start_time = time.time()
    asyncio.run(run_parallel(kwargs_list, model_pool, config))
    elapsed = time.time() - start_time

    print(f"[MAIN] Experiment completed in {elapsed:.2f} seconds")
    print(f"[MAIN] Results saved to: {config['log_path']}")
|
||||
|
||||
|
||||
def run_comparison(
    base_config_path: str,
    criteria_list: str = "out_random,in_random,out_similar,in_similar",
    embedding_path: Optional[str] = None,
) -> None:
    """
    Run experiments comparing multiple selection criteria.

    For each criterion a copy of the base config is made with
    ``selection_criteria`` replaced. Unless ``log_path`` already ends with
    the criterion name, ``_<criterion>`` is appended to it. Similarity-based
    criteria are skipped with a warning when no ``embedding_path`` is given.

    Args:
        base_config_path: Path to base YAML config file
        criteria_list: Comma-separated list of selection criteria to compare.
            A list or tuple of names is accepted as well.
        embedding_path: Path to embeddings (required for *_similar criteria)

    Example:
        python run.py compare config/sleepedf.yaml \\
            --criteria_list="out_random,out_similar" \\
            --embedding_path="./embeddings_full"
    """
    # The context manager closes the file handle, which a bare open() call
    # would leave to the garbage collector.
    with open(base_config_path, "r", encoding="utf-8") as config_file:
        base_config = yaml.safe_load(config_file)

    # NOTE(review): python-fire appears to parse a bare comma-separated flag
    # value into a tuple before it reaches this function - confirm against
    # the Fire guide. Both a string and a sequence are accepted here.
    if isinstance(criteria_list, str):
        raw_criteria = criteria_list.split(",")
    else:
        raw_criteria = list(criteria_list)
    # Drop blank entries left by stray commas such as "a,,b" or "a,".
    criteria = [str(c).strip() for c in raw_criteria if str(c).strip()]

    banner = "=" * 60
    print(banner)
    print("COMPARISON EXPERIMENT")
    print(banner)
    print(f" Selection criteria to compare: {criteria}")
    print(f" Embedding path: {embedding_path}")
    print(banner)

    for criterion in criteria:
        print(f"\n{'='*60}")
        print(f"Running experiment: {criterion}")
        print(f"{'='*60}")

        # Shallow copy is enough: only top-level keys are replaced below.
        config = base_config.copy()
        config["selection_criteria"] = criterion

        base_log_path = config["log_path"]
        if not base_log_path.endswith(criterion):
            config["log_path"] = f"{base_log_path}_{criterion}"

        if criterion in ("out_similar", "in_similar"):
            if not embedding_path:
                print(f"[WARNING] {criterion} requires --embedding_path argument")
                continue
            config["embedding_path"] = embedding_path

        model_pool = load_models(config["models"])

        kwargs_list = asyncio.run(load_data_parallel(config))

        if not kwargs_list:
            print(f"[WARNING] No tasks for {criterion}, skipping")
            continue

        model_pool.warmup()
        asyncio.run(run_parallel(kwargs_list, model_pool, config))

        print(f"[DONE] Results for {criterion} saved to: {config['log_path']}")

    print("\n" + "=" * 60)
    print("ALL COMPARISON EXPERIMENTS COMPLETED")
    print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Expose both entry points as sub-commands: `run` for a single
    # experiment and `compare` for the multi-criteria comparison. The
    # leftover single-function `Fire(run)` call is dropped so the CLI is
    # dispatched exactly once.
    Fire({
        "run": run,
        "compare": run_comparison,
    })
|
||||
|
||||
Reference in New Issue
Block a user