Removing Embedding Functions
This commit is contained in:
207
sc/run_sc.py
207
sc/run_sc.py
@@ -6,22 +6,8 @@ This module implements Self-Consistency methodology for sleep stage classificati
|
||||
- Use majority voting for final answer
|
||||
- Support confidence-based tie-breaking
|
||||
|
||||
Selection Criteria:
|
||||
- out_random: Random selection from different users (baseline)
|
||||
- in_random: Random selection from same user (personalization baseline)
|
||||
- out_similar: Chronos-2 embedding similarity-based selection
|
||||
- out_metadata: Gower distance-based selection (gender, age)
|
||||
|
||||
Usage:
|
||||
# Run single experiment
|
||||
python -m sc.run_sc run sc/config/sleepedf_sc.yaml
|
||||
|
||||
# Quick test with limited tasks
|
||||
python -m sc.run_sc test sc/config/sleepedf_sc.yaml --max_tasks=5
|
||||
|
||||
# Run from project root
|
||||
cd /home/ssum/tsllm_personalization_icl
|
||||
python -m sc.run_sc run sc/config/sleepedf_sc.yaml
|
||||
python -m sc.run_sc sc/config/sleepedf_sc.yaml
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -42,75 +28,10 @@ from fire import Fire
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from core.data_loader import DataLoader
|
||||
from core.embedding_index import EmbeddingIndex, create_embedding_index
|
||||
from core.metadata_index import MetadataIndex, create_metadata_index
|
||||
from sc.core.sc_agent import SelfConsistencyAgent
|
||||
from sc.core.model_utils import load_models_with_temperature
|
||||
|
||||
|
||||
# Valid selection criteria
|
||||
VALID_CRITERIA = ["out_random", "in_random", "out_similar", "out_metadata"]
|
||||
|
||||
# Global indices
|
||||
_EMBEDDING_INDEX: Optional[EmbeddingIndex] = None
|
||||
_METADATA_INDEX: Optional[MetadataIndex] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Index Initialization
|
||||
# =============================================================================
|
||||
|
||||
def init_embedding_index(embedding_path: str) -> Optional[EmbeddingIndex]:
|
||||
"""
|
||||
Initialize embedding index for out_similar selection.
|
||||
|
||||
Args:
|
||||
embedding_path: Path to pre-computed Chronos-2 embeddings
|
||||
|
||||
Returns:
|
||||
EmbeddingIndex instance or None if initialization fails
|
||||
"""
|
||||
global _EMBEDDING_INDEX
|
||||
|
||||
if _EMBEDDING_INDEX is not None:
|
||||
return _EMBEDDING_INDEX
|
||||
|
||||
_EMBEDDING_INDEX = create_embedding_index(embedding_path)
|
||||
return _EMBEDDING_INDEX
|
||||
|
||||
|
||||
def init_metadata_index(
|
||||
data_path: str,
|
||||
metadata_path: str,
|
||||
weight_gender: float = 1.0,
|
||||
weight_age: float = 1.0,
|
||||
) -> Optional[MetadataIndex]:
|
||||
"""
|
||||
Initialize metadata index for out_metadata selection.
|
||||
|
||||
Args:
|
||||
data_path: Path to SleepEDF dataset
|
||||
metadata_path: Path to SC-subjects.xls metadata file
|
||||
weight_gender: Weight for gender distance in Gower calculation
|
||||
weight_age: Weight for age distance in Gower calculation
|
||||
|
||||
Returns:
|
||||
MetadataIndex instance or None if initialization fails
|
||||
"""
|
||||
global _METADATA_INDEX
|
||||
|
||||
if _METADATA_INDEX is not None:
|
||||
return _METADATA_INDEX
|
||||
|
||||
_METADATA_INDEX = create_metadata_index(
|
||||
data_path=data_path,
|
||||
metadata_path=metadata_path,
|
||||
weight_gender=weight_gender,
|
||||
weight_age=weight_age,
|
||||
)
|
||||
return _METADATA_INDEX
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data Loading
|
||||
# =============================================================================
|
||||
@@ -120,11 +41,8 @@ def load_user_data_sync(
|
||||
user: str,
|
||||
seed: int,
|
||||
log_path_base: str,
|
||||
selection_criteria: str = "out_random",
|
||||
num_examples: int = 1,
|
||||
sample_rate: int = 10,
|
||||
embedding_index: Optional[EmbeddingIndex] = None,
|
||||
metadata_index: Optional[MetadataIndex] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load data for a single user synchronously.
|
||||
@@ -134,26 +52,21 @@ def load_user_data_sync(
|
||||
user: User ID
|
||||
seed: Random seed for reproducibility
|
||||
log_path_base: Base path for logging
|
||||
selection_criteria: Example selection strategy
|
||||
num_examples: Number of examples per class for ICL
|
||||
num_examples: Number of ICL examples per class
|
||||
sample_rate: Process every Nth sample (1=all, 10=10%)
|
||||
embedding_index: EmbeddingIndex for out_similar (optional)
|
||||
metadata_index: MetadataIndex for out_metadata (optional)
|
||||
|
||||
Returns:
|
||||
List of task dictionaries containing sample info, examples, and metadata
|
||||
List of task dictionaries
|
||||
"""
|
||||
print(f"[DATA] Loading: user={user}, seed={seed}, criteria={selection_criteria}")
|
||||
print(f"[DATA] Loading: user={user}, seed={seed}")
|
||||
|
||||
np.random.seed(seed)
|
||||
|
||||
data_loader = DataLoader(
|
||||
data_path=data_path,
|
||||
user_id=user,
|
||||
selection_criteria=selection_criteria,
|
||||
selection_criteria="out_random",
|
||||
num_examples=num_examples,
|
||||
embedding_index=embedding_index,
|
||||
metadata_index=metadata_index,
|
||||
)
|
||||
|
||||
if not data_loader.is_valid:
|
||||
@@ -197,55 +110,19 @@ async def load_data_parallel(config: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
config: Experiment configuration dictionary
|
||||
|
||||
Returns:
|
||||
List of all task dictionaries across all users and seeds
|
||||
List of task dictionaries for all users and seeds
|
||||
"""
|
||||
print("[DATA] Starting parallel data loading...")
|
||||
|
||||
# Extract parameters from config
|
||||
data_path = config["data_path"]
|
||||
selection_criteria = config.get("selection_criteria", "out_random")
|
||||
num_examples = config.get("num_examples", 1)
|
||||
sample_rate = config.get("sample_rate", 10)
|
||||
|
||||
# Validate selection criteria
|
||||
if selection_criteria not in VALID_CRITERIA:
|
||||
print(f"[WARNING] Invalid criteria '{selection_criteria}', using out_random")
|
||||
selection_criteria = "out_random"
|
||||
|
||||
# Initialize required indices based on selection criteria
|
||||
embedding_index = None
|
||||
metadata_index = None
|
||||
|
||||
if selection_criteria == "out_similar":
|
||||
embedding_path = config.get("embedding_path")
|
||||
if embedding_path:
|
||||
embedding_index = init_embedding_index(embedding_path)
|
||||
if embedding_index is None:
|
||||
print("[WARNING] out_similar requires embedding_path, falling back to out_random")
|
||||
selection_criteria = "out_random"
|
||||
|
||||
elif selection_criteria == "out_metadata":
|
||||
metadata_path = config.get("metadata_path")
|
||||
weight_gender = config.get("weight_gender", 1.0)
|
||||
weight_age = config.get("weight_age", 1.0)
|
||||
|
||||
if metadata_path:
|
||||
metadata_index = init_metadata_index(
|
||||
data_path=data_path,
|
||||
metadata_path=metadata_path,
|
||||
weight_gender=weight_gender,
|
||||
weight_age=weight_age,
|
||||
)
|
||||
if metadata_index is None:
|
||||
print("[WARNING] out_metadata requires metadata_path, falling back to out_random")
|
||||
selection_criteria = "out_random"
|
||||
|
||||
# Collect user list from data directory
|
||||
# Collect user list
|
||||
user_paths = glob(os.path.join(data_path, "*"))
|
||||
users = [os.path.basename(p) for p in user_paths if os.path.isdir(p)]
|
||||
|
||||
print(f"[DATA] Found {len(users)} users")
|
||||
print(f"[DATA] Selection criteria: {selection_criteria}")
|
||||
print(f"[DATA] Num examples per class: {num_examples}")
|
||||
print(f"[DATA] Sample rate: 1/{sample_rate}")
|
||||
|
||||
@@ -266,11 +143,8 @@ async def load_data_parallel(config: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
user,
|
||||
seed,
|
||||
config["log_path"],
|
||||
selection_criteria,
|
||||
num_examples,
|
||||
sample_rate,
|
||||
embedding_index,
|
||||
metadata_index,
|
||||
)
|
||||
futures.append(future)
|
||||
|
||||
@@ -357,12 +231,10 @@ async def run_parallel(
|
||||
print(f"[RUN] Self-Consistency samples per task: {num_sc_samples}")
|
||||
|
||||
# Execute all tasks asynchronously
|
||||
async_tasks = []
|
||||
for task in tasks:
|
||||
async_task = asyncio.create_task(
|
||||
run_single_task(task, model_pool, num_sc_samples)
|
||||
)
|
||||
async_tasks.append(async_task)
|
||||
async_tasks = [
|
||||
asyncio.create_task(run_single_task(task, model_pool, num_sc_samples))
|
||||
for task in tasks
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*async_tasks, return_exceptions=True)
|
||||
|
||||
@@ -417,9 +289,10 @@ def compute_statistics(results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
if r.get("is_correct", False):
|
||||
class_correct[gt] = class_correct.get(gt, 0) + 1
|
||||
|
||||
class_accuracy = {}
|
||||
for cls in class_total:
|
||||
class_accuracy[cls] = class_correct.get(cls, 0) / class_total[cls]
|
||||
class_accuracy = {
|
||||
cls: class_correct.get(cls, 0) / class_total[cls]
|
||||
for cls in class_total
|
||||
}
|
||||
|
||||
# High consistency accuracy analysis
|
||||
high_consistency_results = [r for r in results if r.get("consistency", 0) >= 0.8]
|
||||
@@ -484,7 +357,7 @@ def save_results(
|
||||
# CLI Commands
|
||||
# =============================================================================
|
||||
|
||||
def run(config_path: str) -> None:
|
||||
def main(config_path: str) -> None:
|
||||
"""
|
||||
Run Self-Consistency experiment.
|
||||
|
||||
@@ -492,7 +365,7 @@ def run(config_path: str) -> None:
|
||||
config_path: Path to YAML configuration file
|
||||
|
||||
Example:
|
||||
python -m sc.run_sc run sc/config/sleepedf_sc.yaml
|
||||
python -m sc.run_sc sc/config/sleepedf_sc.yaml
|
||||
"""
|
||||
print(f"[MAIN] Loading config: {config_path}")
|
||||
config = yaml.load(open(config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
|
||||
@@ -508,7 +381,6 @@ def run(config_path: str) -> None:
|
||||
print("=" * 60)
|
||||
print(f" Data path: {config.get('data_path', 'N/A')}")
|
||||
print(f" Log path: {config.get('log_path', 'N/A')}")
|
||||
print(f" Selection criteria: {config.get('selection_criteria', 'out_random')}")
|
||||
print(f" Num ICL examples: {config.get('num_examples', 1)}")
|
||||
print(f" Num seeds: {config.get('num_seeds', 1)}")
|
||||
print(f" Num SC samples: {config.get('num_sc_samples', 5)}")
|
||||
@@ -562,48 +434,5 @@ def run(config_path: str) -> None:
|
||||
print(f"[MAIN] Results saved to: {config['log_path']}")
|
||||
|
||||
|
||||
def test(config_path: str, max_tasks: int = 5) -> None:
|
||||
"""
|
||||
Run quick test with limited tasks for debugging.
|
||||
|
||||
Args:
|
||||
config_path: Path to YAML configuration file
|
||||
max_tasks: Maximum number of tasks to run
|
||||
|
||||
Example:
|
||||
python -m sc.run_sc test sc/config/sleepedf_sc.yaml --max_tasks=5
|
||||
"""
|
||||
print(f"[TEST] Running quick test with max {max_tasks} tasks")
|
||||
|
||||
config = yaml.load(open(config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
|
||||
config["num_seeds"] = 1 # Use single seed for testing
|
||||
config["log_path"] = config.get("log_path", "./test_results") + "_test"
|
||||
|
||||
# Load models with specified temperature
|
||||
temperature = config.get("temperature", 0.0)
|
||||
model_pool = load_models_with_temperature(config["models"], temperature=temperature)
|
||||
|
||||
# Load data (limited)
|
||||
tasks = asyncio.run(load_data_parallel(config))
|
||||
tasks = tasks[:max_tasks]
|
||||
|
||||
if len(tasks) == 0:
|
||||
print("[ERROR] No valid tasks.")
|
||||
return
|
||||
|
||||
print(f"[TEST] Running {len(tasks)} tasks...")
|
||||
model_pool.warmup()
|
||||
|
||||
results = asyncio.run(run_parallel(tasks, model_pool, config))
|
||||
|
||||
# Display brief results
|
||||
stats = compute_statistics(results)
|
||||
print(f"\n[TEST] Accuracy: {stats.get('accuracy', 0):.4f}")
|
||||
print(f"[TEST] Avg Consistency: {stats.get('avg_consistency', 0):.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire({
|
||||
"run": run,
|
||||
"test": test,
|
||||
})
|
||||
Fire(main)
|
||||
|
||||
Reference in New Issue
Block a user