Removing Embedding Functions

This commit is contained in:
ssum21
2026-01-23 13:12:17 +09:00
parent 6d0bd3ddcd
commit 622fc629be

View File

@@ -6,22 +6,8 @@ This module implements Self-Consistency methodology for sleep stage classificati
- Use majority voting for final answer
- Support confidence-based tie-breaking
Selection Criteria:
- out_random: Random selection from different users (baseline)
- in_random: Random selection from same user (personalization baseline)
- out_similar: Chronos-2 embedding similarity-based selection
- out_metadata: Gower distance-based selection (gender, age)
Usage:
# Run single experiment
python -m sc.run_sc run sc/config/sleepedf_sc.yaml
# Quick test with limited tasks
python -m sc.run_sc test sc/config/sleepedf_sc.yaml --max_tasks=5
# Run from project root
cd /home/ssum/tsllm_personalization_icl
python -m sc.run_sc run sc/config/sleepedf_sc.yaml
python -m sc.run_sc sc/config/sleepedf_sc.yaml
"""
import os
@@ -42,75 +28,10 @@ from fire import Fire
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.data_loader import DataLoader
from core.embedding_index import EmbeddingIndex, create_embedding_index
from core.metadata_index import MetadataIndex, create_metadata_index
from sc.core.sc_agent import SelfConsistencyAgent
from sc.core.model_utils import load_models_with_temperature
# Selection strategies accepted for the "selection_criteria" config value;
# load_data_parallel falls back to "out_random" for anything not listed here.
VALID_CRITERIA = ["out_random", "in_random", "out_similar", "out_metadata"]
# Module-level caches for the example-selection indices. They start as None
# and are filled on first use by init_embedding_index / init_metadata_index,
# so repeated calls reuse the same index object.
_EMBEDDING_INDEX: Optional[EmbeddingIndex] = None
_METADATA_INDEX: Optional[MetadataIndex] = None
# =============================================================================
# Index Initialization
# =============================================================================
def init_embedding_index(embedding_path: str) -> Optional[EmbeddingIndex]:
    """
    Return the shared embedding index used for out_similar selection.

    The index is built through create_embedding_index on the first call and
    cached in the module-level _EMBEDDING_INDEX; later calls return the cached
    object and ignore embedding_path. If the factory yields None, nothing is
    cached and the next call tries again.

    Args:
        embedding_path: Path to pre-computed Chronos-2 embeddings

    Returns:
        The cached EmbeddingIndex, or None if create_embedding_index
        returned None
    """
    global _EMBEDDING_INDEX

    # Build lazily: only hit the factory while no index has been cached yet.
    if _EMBEDDING_INDEX is None:
        _EMBEDDING_INDEX = create_embedding_index(embedding_path)
    return _EMBEDDING_INDEX
def init_metadata_index(
    data_path: str,
    metadata_path: str,
    weight_gender: float = 1.0,
    weight_age: float = 1.0,
) -> Optional[MetadataIndex]:
    """
    Return the shared metadata index used for out_metadata selection.

    The index is built through create_metadata_index on the first call and
    cached in the module-level _METADATA_INDEX; later calls return the cached
    object and ignore their arguments. If the factory yields None, nothing is
    cached and the next call tries again.

    Args:
        data_path: Path to SleepEDF dataset
        metadata_path: Path to SC-subjects.xls metadata file
        weight_gender: Weight for gender distance in Gower calculation
        weight_age: Weight for age distance in Gower calculation

    Returns:
        The cached MetadataIndex, or None if create_metadata_index
        returned None
    """
    global _METADATA_INDEX

    # Build lazily: only hit the factory while no index has been cached yet.
    if _METADATA_INDEX is None:
        _METADATA_INDEX = create_metadata_index(
            data_path=data_path,
            metadata_path=metadata_path,
            weight_gender=weight_gender,
            weight_age=weight_age,
        )
    return _METADATA_INDEX
# =============================================================================
# Data Loading
# =============================================================================
@@ -120,11 +41,8 @@ def load_user_data_sync(
user: str,
seed: int,
log_path_base: str,
selection_criteria: str = "out_random",
num_examples: int = 1,
sample_rate: int = 10,
embedding_index: Optional[EmbeddingIndex] = None,
metadata_index: Optional[MetadataIndex] = None,
) -> List[Dict[str, Any]]:
"""
Load data for a single user synchronously.
@@ -134,26 +52,21 @@ def load_user_data_sync(
user: User ID
seed: Random seed for reproducibility
log_path_base: Base path for logging
selection_criteria: Example selection strategy
num_examples: Number of examples per class for ICL
num_examples: Number of ICL examples per class
sample_rate: Process every Nth sample (1=all, 10=10%)
embedding_index: EmbeddingIndex for out_similar (optional)
metadata_index: MetadataIndex for out_metadata (optional)
Returns:
List of task dictionaries containing sample info, examples, and metadata
List of task dictionaries
"""
print(f"[DATA] Loading: user={user}, seed={seed}, criteria={selection_criteria}")
print(f"[DATA] Loading: user={user}, seed={seed}")
np.random.seed(seed)
data_loader = DataLoader(
data_path=data_path,
user_id=user,
selection_criteria=selection_criteria,
selection_criteria="out_random",
num_examples=num_examples,
embedding_index=embedding_index,
metadata_index=metadata_index,
)
if not data_loader.is_valid:
@@ -197,55 +110,19 @@ async def load_data_parallel(config: Dict[str, Any]) -> List[Dict[str, Any]]:
config: Experiment configuration dictionary
Returns:
List of all task dictionaries across all users and seeds
List of task dictionaries for all users and seeds
"""
print("[DATA] Starting parallel data loading...")
# Extract parameters from config
data_path = config["data_path"]
selection_criteria = config.get("selection_criteria", "out_random")
num_examples = config.get("num_examples", 1)
sample_rate = config.get("sample_rate", 10)
# Validate selection criteria
if selection_criteria not in VALID_CRITERIA:
print(f"[WARNING] Invalid criteria '{selection_criteria}', using out_random")
selection_criteria = "out_random"
# Initialize required indices based on selection criteria
embedding_index = None
metadata_index = None
if selection_criteria == "out_similar":
embedding_path = config.get("embedding_path")
if embedding_path:
embedding_index = init_embedding_index(embedding_path)
if embedding_index is None:
print("[WARNING] out_similar requires embedding_path, falling back to out_random")
selection_criteria = "out_random"
elif selection_criteria == "out_metadata":
metadata_path = config.get("metadata_path")
weight_gender = config.get("weight_gender", 1.0)
weight_age = config.get("weight_age", 1.0)
if metadata_path:
metadata_index = init_metadata_index(
data_path=data_path,
metadata_path=metadata_path,
weight_gender=weight_gender,
weight_age=weight_age,
)
if metadata_index is None:
print("[WARNING] out_metadata requires metadata_path, falling back to out_random")
selection_criteria = "out_random"
# Collect user list from data directory
# Collect user list
user_paths = glob(os.path.join(data_path, "*"))
users = [os.path.basename(p) for p in user_paths if os.path.isdir(p)]
print(f"[DATA] Found {len(users)} users")
print(f"[DATA] Selection criteria: {selection_criteria}")
print(f"[DATA] Num examples per class: {num_examples}")
print(f"[DATA] Sample rate: 1/{sample_rate}")
@@ -266,11 +143,8 @@ async def load_data_parallel(config: Dict[str, Any]) -> List[Dict[str, Any]]:
user,
seed,
config["log_path"],
selection_criteria,
num_examples,
sample_rate,
embedding_index,
metadata_index,
)
futures.append(future)
@@ -357,12 +231,10 @@ async def run_parallel(
print(f"[RUN] Self-Consistency samples per task: {num_sc_samples}")
# Execute all tasks asynchronously
async_tasks = []
for task in tasks:
async_task = asyncio.create_task(
run_single_task(task, model_pool, num_sc_samples)
)
async_tasks.append(async_task)
async_tasks = [
asyncio.create_task(run_single_task(task, model_pool, num_sc_samples))
for task in tasks
]
results = await asyncio.gather(*async_tasks, return_exceptions=True)
@@ -417,9 +289,10 @@ def compute_statistics(results: List[Dict[str, Any]]) -> Dict[str, Any]:
if r.get("is_correct", False):
class_correct[gt] = class_correct.get(gt, 0) + 1
class_accuracy = {}
for cls in class_total:
class_accuracy[cls] = class_correct.get(cls, 0) / class_total[cls]
class_accuracy = {
cls: class_correct.get(cls, 0) / class_total[cls]
for cls in class_total
}
# High consistency accuracy analysis
high_consistency_results = [r for r in results if r.get("consistency", 0) >= 0.8]
@@ -484,7 +357,7 @@ def save_results(
# CLI Commands
# =============================================================================
def run(config_path: str) -> None:
def main(config_path: str) -> None:
"""
Run Self-Consistency experiment.
@@ -492,7 +365,7 @@ def run(config_path: str) -> None:
config_path: Path to YAML configuration file
Example:
python -m sc.run_sc run sc/config/sleepedf_sc.yaml
python -m sc.run_sc sc/config/sleepedf_sc.yaml
"""
print(f"[MAIN] Loading config: {config_path}")
config = yaml.load(open(config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
@@ -508,7 +381,6 @@ def run(config_path: str) -> None:
print("=" * 60)
print(f" Data path: {config.get('data_path', 'N/A')}")
print(f" Log path: {config.get('log_path', 'N/A')}")
print(f" Selection criteria: {config.get('selection_criteria', 'out_random')}")
print(f" Num ICL examples: {config.get('num_examples', 1)}")
print(f" Num seeds: {config.get('num_seeds', 1)}")
print(f" Num SC samples: {config.get('num_sc_samples', 5)}")
@@ -562,48 +434,5 @@ def run(config_path: str) -> None:
print(f"[MAIN] Results saved to: {config['log_path']}")
def test(config_path: str, max_tasks: int = 5) -> None:
    """
    Run quick test with limited tasks for debugging.

    Loads the YAML config, forces a single seed, appends "_test" to the
    configured log path, then runs at most ``max_tasks`` tasks and prints the
    accuracy and average consistency reported by compute_statistics.

    Args:
        config_path: Path to YAML configuration file
        max_tasks: Maximum number of tasks to run

    Example:
        python -m sc.run_sc test sc/config/sleepedf_sc.yaml --max_tasks=5
    """
    print(f"[TEST] Running quick test with max {max_tasks} tasks")

    # Use a context manager so the config file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = yaml.load(config_file, Loader=yaml.SafeLoader)

    config["num_seeds"] = 1  # Use single seed for testing
    config["log_path"] = config.get("log_path", "./test_results") + "_test"

    # Load models with specified temperature
    temperature = config.get("temperature", 0.0)
    model_pool = load_models_with_temperature(config["models"], temperature=temperature)

    # Load data (limited)
    tasks = asyncio.run(load_data_parallel(config))
    tasks = tasks[:max_tasks]

    if not tasks:
        print("[ERROR] No valid tasks.")
        return

    print(f"[TEST] Running {len(tasks)} tasks...")
    model_pool.warmup()
    results = asyncio.run(run_parallel(tasks, model_pool, config))

    # Display brief results
    stats = compute_statistics(results)
    print(f"\n[TEST] Accuracy: {stats.get('accuracy', 0):.4f}")
    print(f"[TEST] Avg Consistency: {stats.get('avg_consistency', 0):.4f}")
if __name__ == "__main__":
Fire({
"run": run,
"test": test,
})
Fire(main)