Modularize debug logging for SC runners
This commit is contained in:
426
sc/debug_log.py
Normal file
426
sc/debug_log.py
Normal file
@@ -0,0 +1,426 @@
|
||||
def _enabled(config=None):
|
||||
if config is None:
|
||||
return True
|
||||
return config.get("debug", True)
|
||||
|
||||
|
||||
def log(message, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(message)
|
||||
|
||||
|
||||
def warn_no_examples():
|
||||
print("[WARN] No examples found for dataloader. Skipping task.")
|
||||
|
||||
|
||||
def log_queue_state_before(processed_count, user_info, ex_queue, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"\n{'#'*60}")
|
||||
print(
|
||||
f"[Sample {processed_count}] {user_info} - Queue State BEFORE Processing "
|
||||
f"(Instance ID: {ex_queue.get_instance_id()}):"
|
||||
)
|
||||
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
|
||||
print(
|
||||
f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})"
|
||||
)
|
||||
print(f"{'#'*60}\n")
|
||||
|
||||
|
||||
def warn_no_agents(processed_count):
|
||||
print(f"[WARN] No agents added for sample {processed_count}. Skipping.")
|
||||
|
||||
|
||||
def warn_interpretation_failed(processed_count):
|
||||
print(f"[WARN] Interpretation failed for sample {processed_count}. Skipping.")
|
||||
|
||||
|
||||
def log_tracking(
|
||||
processed_count,
|
||||
user_info,
|
||||
answer,
|
||||
ground_truth,
|
||||
is_correct,
|
||||
avg_confidence,
|
||||
consistency,
|
||||
cumulative_accuracy,
|
||||
cumulative_correct,
|
||||
window_accuracy,
|
||||
recent_results,
|
||||
avg_confidence_so_far,
|
||||
config=None,
|
||||
):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"\n{'='*60}")
|
||||
print(f"[TRACKING] Sample {processed_count} | {user_info}")
|
||||
print(
|
||||
f" Answer: {answer} | GT: {ground_truth} | "
|
||||
f"{'✓ CORRECT' if is_correct else '✗ WRONG'}"
|
||||
)
|
||||
print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
|
||||
print(" ─────────────────────────────────────────────────────")
|
||||
print(
|
||||
f" Cumulative Accuracy: {cumulative_accuracy:.4f} "
|
||||
f"({cumulative_correct}/{processed_count + 1})"
|
||||
)
|
||||
print(
|
||||
f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}"
|
||||
)
|
||||
print(f" Avg Confidence (so far): {avg_confidence_so_far:.4f}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def log_queue_state_after(processed_count, ex_queue, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"\n[Sample {processed_count}] Queue State AFTER Update:")
|
||||
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
|
||||
print(
|
||||
f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})"
|
||||
)
|
||||
print()
|
||||
|
||||
|
||||
def warn_no_responses(processed_count):
|
||||
print(f"[WARN] No responses returned, falling back to basic update.")
|
||||
|
||||
|
||||
def log_final_queue_stats(user_info, survival_summary, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"\n{'#'*60}")
|
||||
print(f"[FINAL] Queue Survival Statistics for {user_info}")
|
||||
print(f" Total Evicted Cases: {survival_summary['total_evicted']}")
|
||||
print(f" Avg Survival: {survival_summary['avg_survival']:.2f} samples")
|
||||
print(f" Max Survival: {survival_summary['max_survival']} samples")
|
||||
print(f" Min Survival: {survival_summary['min_survival']} samples")
|
||||
print(f" Avg Usage Count: {survival_summary['avg_usage']:.2f}")
|
||||
print(f" Max Usage Count: {survival_summary['max_usage']}")
|
||||
print(f"{'#'*60}\n")
|
||||
|
||||
|
||||
def warn_no_user_dirs(data_path):
|
||||
print(f"[WARN] No user directories found in {data_path}")
|
||||
|
||||
|
||||
def log_found_users(users, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(
|
||||
f"[INFO] Found {len(users)} users: {users[:5]}"
|
||||
f"{'...' if len(users) > 5 else ''}"
|
||||
)
|
||||
|
||||
|
||||
def warn_skip_user_no_test_data(user):
|
||||
print(f"[WARN] Skipping user {user} - no test data available")
|
||||
|
||||
|
||||
def warn_skip_user_no_example_data(user):
|
||||
print(f"[WARN] Skipping user {user} - no example data available")
|
||||
|
||||
|
||||
def log_main_loading_config(config_path, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"[MAIN] Loading config: {config_path}")
|
||||
|
||||
|
||||
def log_main_config(config, config_enabled=True):
|
||||
if not config_enabled:
|
||||
return
|
||||
print("=" * 60)
|
||||
print("SELF-CONSISTENCY EXPERIMENT CONFIGURATION")
|
||||
print("=" * 60)
|
||||
print(f" Data path: {config.get('data_path', 'N/A')}")
|
||||
print(f" Log path: {config.get('log_path', 'N/A')}")
|
||||
print(f" Num ICL examples: {config.get('num_examples', 1)}")
|
||||
print(f" Num seeds: {config.get('num_seeds', 1)}")
|
||||
print(f" Num SC samples: {config.get('num_sc_samples', 5)}")
|
||||
print(f" Temperature: {config.get('temperature', 0.0)}")
|
||||
print(f" Sample rate: 1/{config.get('sample_rate', 10)}")
|
||||
print(f" Num models: {len(config.get('models', []))}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
def log_main_start(config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print("[MAIN] Starting experiments...")
|
||||
|
||||
|
||||
def error_task_failed(result):
|
||||
print(f"[ERROR] Task failed with exception: {result}")
|
||||
|
||||
|
||||
def log_total_results(count, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"[MAIN] Total results collected: {count}")
|
||||
|
||||
|
||||
def log_experiment_results(stats, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print("\n" + "=" * 60)
|
||||
print("EXPERIMENT RESULTS")
|
||||
print("=" * 60)
|
||||
print(f" Total samples: {stats.get('total_samples', 0)}")
|
||||
print(f" Accuracy: {stats.get('accuracy', 0):.4f}")
|
||||
print(f" Avg Confidence: {stats.get('avg_confidence', 0):.4f}")
|
||||
print(f" Avg Consistency: {stats.get('avg_consistency', 0):.4f}")
|
||||
print(
|
||||
" High Consistency (>=0.8) Accuracy: "
|
||||
f"{stats.get('high_consistency_accuracy', 0):.4f}"
|
||||
)
|
||||
print(f" High Consistency Samples: {stats.get('high_consistency_samples', 0)}")
|
||||
print("\n Class-wise Accuracy:")
|
||||
for cls, acc in stats.get("class_accuracy", {}).items():
|
||||
print(f" {cls}: {acc:.4f}")
|
||||
|
||||
|
||||
def log_temporal_analysis(temporal, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
if not temporal:
|
||||
return
|
||||
print("\n" + "-" * 60)
|
||||
print(" TEMPORAL ANALYSIS (Caching Effect)")
|
||||
print("-" * 60)
|
||||
print(f" First Half Accuracy: {temporal.get('first_half_accuracy', 0):.4f}")
|
||||
print(f" Second Half Accuracy: {temporal.get('second_half_accuracy', 0):.4f}")
|
||||
improvement = temporal.get("accuracy_improvement", 0)
|
||||
improvement_sign = "+" if improvement >= 0 else ""
|
||||
print(f" Improvement: {improvement_sign}{improvement:.4f}")
|
||||
quartiles = temporal.get("quartile_accuracies", [])
|
||||
if quartiles:
|
||||
print(
|
||||
f" Quartile Accuracies: Q1={quartiles[0]:.4f}"
|
||||
+ (f", Q2={quartiles[1]:.4f}" if len(quartiles) > 1 else "")
|
||||
+ (f", Q3={quartiles[2]:.4f}" if len(quartiles) > 2 else "")
|
||||
+ (f", Q4={quartiles[3]:.4f}" if len(quartiles) > 3 else "")
|
||||
)
|
||||
print(f"\n First Half Confidence: {temporal.get('first_half_confidence', 0):.4f}")
|
||||
print(f" Second Half Confidence: {temporal.get('second_half_confidence', 0):.4f}")
|
||||
conf_improvement = temporal.get("confidence_improvement", 0)
|
||||
conf_sign = "+" if conf_improvement >= 0 else ""
|
||||
print(f" Confidence Change: {conf_sign}{conf_improvement:.4f}")
|
||||
|
||||
|
||||
def log_queue_stats(queue_stats, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
if not queue_stats:
|
||||
return
|
||||
print("\n" + "-" * 60)
|
||||
print(" QUEUE SURVIVAL STATISTICS")
|
||||
print("-" * 60)
|
||||
print(f" Total Evicted Cases: {queue_stats.get('total_evicted', 0)}")
|
||||
print(f" Avg Survival: {queue_stats.get('avg_survival', 0):.2f} samples")
|
||||
print(f" Max Survival: {queue_stats.get('max_survival', 0)} samples")
|
||||
print(f" Min Survival: {queue_stats.get('min_survival', 0)} samples")
|
||||
print(f" Avg Usage Count: {queue_stats.get('avg_usage', 0):.2f}")
|
||||
print(f" Max Usage Count: {queue_stats.get('max_usage', 0)}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
def log_save_statistics(stats_path, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"[SAVE] Statistics saved to: {stats_path}")
|
||||
|
||||
|
||||
def log_save_results(results_path, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"[SAVE] Results saved to: {results_path}")
|
||||
|
||||
|
||||
def log_save_config(config_path, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"[SAVE] Config saved to: {config_path}")
|
||||
|
||||
|
||||
def log_results_saved(log_path, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"[MAIN] Results saved to: {log_path}")
|
||||
|
||||
|
||||
def log_policy_experiment_header(policy_label, user_id, shuffle_seed, total_samples, queue_size, config=None, policy_note=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"\n{'='*80}")
|
||||
print(f"{policy_label} QUEUE POLICY EXPERIMENT")
|
||||
print(f"User: {user_id} | Shuffle Seed: {shuffle_seed}")
|
||||
print(f"Total samples: {total_samples} | Queue size: {queue_size}")
|
||||
if policy_note:
|
||||
print(policy_note)
|
||||
print(f"{'='*80}\n")
|
||||
|
||||
|
||||
def log_policy_queue_state(policy_label, processed_count, user_id, ex_queue, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"\n{'#'*60}")
|
||||
print(f"[Sample {processed_count}] User {user_id} | {policy_label} Policy")
|
||||
print("Queue State BEFORE Processing:")
|
||||
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
|
||||
print(f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})")
|
||||
print(f"{'#'*60}\n")
|
||||
|
||||
|
||||
def log_policy_result(policy_label, processed_count, answer, ground_truth, is_correct, avg_confidence, consistency, cumulative_accuracy, cumulative_correct, window_accuracy, recent_results, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"\n{'='*60}")
|
||||
print(f"[RESULT] Sample {processed_count} | {policy_label} Policy")
|
||||
print(f" Answer: {answer} | GT: {ground_truth} | {'✓ CORRECT' if is_correct else '✗ WRONG'}")
|
||||
print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
|
||||
print(f" Cumulative Accuracy: {cumulative_accuracy:.4f} ({cumulative_correct}/{processed_count + 1})")
|
||||
print(f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def log_confidence_map(responses, confidence_map, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print("\n[CONFIDENCE MAP] Per-agent confidence scores:")
|
||||
for idx, conf in sorted(confidence_map.items()):
|
||||
ans = responses[idx].get("ANSWER", "?")
|
||||
print(f" Queue[{idx}]: answer={ans}, confidence={conf:.4f}")
|
||||
|
||||
|
||||
def log_consistency_map(responses, consistency_map, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print("\n[CONSISTENCY MAP] Per-agent consistency scores:")
|
||||
for idx, cons in sorted(consistency_map.items()):
|
||||
ans = responses[idx].get("ANSWER", "?")
|
||||
print(f" Queue[{idx}]: answer={ans}, consistency={cons:.4f}")
|
||||
|
||||
|
||||
def log_policy_queue_state_after(policy_label, processed_count, ex_queue, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"\n[Sample {processed_count}] Queue State AFTER {policy_label} Update:")
|
||||
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
|
||||
print(f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})")
|
||||
|
||||
|
||||
def log_final_policy_survival(policy_label, user_id, shuffle_seed, survival_summary, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"\n{'#'*60}")
|
||||
print(f"[FINAL] Queue Survival Statistics - {policy_label} Policy")
|
||||
print(f" User: {user_id} | Seed: {shuffle_seed}")
|
||||
print(f" Total Evicted: {survival_summary['total_evicted']}")
|
||||
print(f" Avg Survival: {survival_summary['avg_survival']:.2f} samples")
|
||||
print(f" Avg Usage: {survival_summary['avg_usage']:.2f}")
|
||||
print(f"{'#'*60}\n")
|
||||
|
||||
|
||||
def log_policy_main_header(policy_label, user_id, shuffle_seed, queue_size, num_models, log_path, config=None, policy_note=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print("=" * 80)
|
||||
print(f"{policy_label} QUEUE POLICY EXPERIMENT")
|
||||
print("=" * 80)
|
||||
print(f" User ID: {user_id}")
|
||||
print(f" Shuffle Seed: {shuffle_seed}")
|
||||
print(f" Queue Size: {queue_size}")
|
||||
print(f" SC Samples (Agents): {num_models}")
|
||||
print(f" Log Path: {log_path}")
|
||||
if policy_note:
|
||||
print(policy_note)
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
def log_policy_loading_data(user_id, shuffle_seed, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"\n[MAIN] Loading shuffled data for user {user_id}, seed {shuffle_seed}...")
|
||||
|
||||
|
||||
def log_policy_loading_models(config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print("\n[MAIN] Loading models...")
|
||||
|
||||
|
||||
def log_policy_start(config=None, label="experiment"):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"\n[MAIN] Starting {label}...")
|
||||
|
||||
|
||||
def log_policy_complete_summary(policy_label, user_id, shuffle_seed, stats, stage_accuracy, stage_counts, temporal, config=None, expected_note=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print("\n" + "=" * 80)
|
||||
print(f"EXPERIMENT COMPLETE - {policy_label} POLICY")
|
||||
print("=" * 80)
|
||||
print(f" User: {user_id} | Seed: {shuffle_seed}")
|
||||
print(f" Total Samples: {stats.get('total_samples', 0)}")
|
||||
print(f" Overall Accuracy: {stats.get('accuracy', 0):.4f}")
|
||||
print(f" Macro F1: {stats.get('macro_f1', 0):.4f}")
|
||||
print("\n Per-Stage Accuracy:")
|
||||
for stage, acc in stage_accuracy.items():
|
||||
count = stage_counts.get(stage, 0)
|
||||
print(f" {stage}: {acc:.4f} (n={count})")
|
||||
print("\n Temporal Analysis:")
|
||||
print(f" First Half: {temporal.get('first_half_accuracy', 0):.4f}")
|
||||
print(f" Second Half: {temporal.get('second_half_accuracy', 0):.4f}")
|
||||
print(f" Improvement: {temporal.get('improvement', 0):+.4f}")
|
||||
if expected_note:
|
||||
print(expected_note)
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
def log_queue_random_stats(user_id, shuffle_seed, sampler_stats, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"\n{'#'*60}")
|
||||
print("[FINAL] Queue Random Statistics")
|
||||
print(f" User: {user_id} | Seed: {shuffle_seed}")
|
||||
print(f" Total Steps: {sampler_stats['total_steps']}")
|
||||
print(f" Total Refreshed: {sampler_stats['total_refreshed']} example sets")
|
||||
print(f" Avg Refresh per Step: {sampler_stats['avg_refresh_per_step']} (always full)")
|
||||
print(f"{'#'*60}\n")
|
||||
|
||||
|
||||
def log_queue_random_sampler_init(example_count, classes, queue_size, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"[QueueRandomSampler] Initialized with {example_count} examples")
|
||||
print(f" Classes: {classes}")
|
||||
print(f" Queue size: {queue_size}")
|
||||
print(" Policy: ALL elements refreshed every step")
|
||||
|
||||
|
||||
def log_queue_random_queue_state(processed_count, user_id, queue_sampler, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"\n{'#'*60}")
|
||||
print(f"[Sample {processed_count}] User {user_id} | QUEUE RANDOM Policy")
|
||||
print("Queue State (ALL FRESH RANDOM samples):")
|
||||
for idx, ex_idcs in enumerate(queue_sampler):
|
||||
print(f" [{idx}] Example indices: {ex_idcs}")
|
||||
print(f"{'#'*60}\n")
|
||||
|
||||
|
||||
def log_queue_random_result(processed_count, answer, ground_truth, is_correct, avg_confidence, consistency, cumulative_accuracy, cumulative_correct, window_accuracy, recent_results, config=None):
|
||||
if not _enabled(config):
|
||||
return
|
||||
print(f"\n{'='*60}")
|
||||
print(f"[RESULT] Sample {processed_count} | QUEUE RANDOM Policy")
|
||||
print(f" Answer: {answer} | GT: {ground_truth} | {'✓ CORRECT' if is_correct else '✗ WRONG'}")
|
||||
print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
|
||||
print(f" Cumulative Accuracy: {cumulative_accuracy:.4f} ({cumulative_correct}/{processed_count + 1})")
|
||||
print(f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}")
|
||||
print(" [NOTE] Queue will be FULLY REFRESHED for next sample")
|
||||
print(f"{'='*60}\n")
|
||||
@@ -32,6 +32,9 @@ from sc.core.scagent import SCAgent
|
||||
from sc.core.model import load_models
|
||||
from sc.core.queue import Queue
|
||||
from sc.core.agent_pool import AgentPool
|
||||
from sc import debug_log
|
||||
|
||||
log = debug_log.log
|
||||
|
||||
|
||||
async def run_confidence_experiment(
|
||||
@@ -65,7 +68,7 @@ async def run_confidence_experiment(
|
||||
|
||||
example_dataset = dataloader.get_examples()
|
||||
if len(example_dataset) == 0:
|
||||
print(f"[ERROR] No examples found for user {user_id}")
|
||||
log(f"[ERROR] No examples found for user {user_id}")
|
||||
return []
|
||||
|
||||
# Build class_indices for Queue initialization
|
||||
@@ -91,22 +94,19 @@ async def run_confidence_experiment(
|
||||
all_predictions = []
|
||||
all_ground_truths = []
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"CONFIDENCE-BASED QUEUE POLICY EXPERIMENT")
|
||||
print(f"User: {user_id} | Shuffle Seed: {shuffle_seed}")
|
||||
print(f"Total samples: {len(dataloader)} | Queue size: {queue_size}")
|
||||
print(f"{'='*80}\n")
|
||||
debug_log.log_policy_experiment_header(
|
||||
"CONFIDENCE-BASED",
|
||||
user_id,
|
||||
shuffle_seed,
|
||||
len(dataloader),
|
||||
queue_size,
|
||||
)
|
||||
|
||||
for processed_count, sample in enumerate(dataloader):
|
||||
ex_queue.set_current_time(processed_count)
|
||||
|
||||
# Log queue state before processing
|
||||
print(f"\n{'#'*60}")
|
||||
print(f"[Sample {processed_count}] User {user_id} | CONFIDENCE Policy")
|
||||
print(f"Queue State BEFORE Processing:")
|
||||
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
|
||||
print(f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})")
|
||||
print(f"{'#'*60}\n")
|
||||
debug_log.log_policy_queue_state("CONFIDENCE", processed_count, user_id, ex_queue)
|
||||
|
||||
# Create agent pool
|
||||
agent_pool = AgentPool(log_path=config["log_path"])
|
||||
@@ -128,26 +128,26 @@ async def run_confidence_experiment(
|
||||
)
|
||||
agent_pool.add_agent(agent)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Failed to create agents: {e}")
|
||||
log(f"[ERROR] Failed to create agents: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
if len(agent_pool.agents) == 0:
|
||||
print(f"[WARN] No agents created for sample {processed_count}")
|
||||
log(f"[WARN] No agents created for sample {processed_count}")
|
||||
continue
|
||||
|
||||
# Run parallel interpretation
|
||||
try:
|
||||
interpretation_result = await agent_pool.run_parallel_interpretation()
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Interpretation failed: {e}")
|
||||
log(f"[ERROR] Interpretation failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
if interpretation_result is None:
|
||||
print(f"[WARN] Interpretation failed for sample {processed_count}")
|
||||
log(f"[WARN] Interpretation failed for sample {processed_count}")
|
||||
continue
|
||||
|
||||
answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
|
||||
@@ -171,13 +171,19 @@ async def run_confidence_experiment(
|
||||
all_ground_truths.append(ground_truth)
|
||||
|
||||
# Performance logging
|
||||
print(f"\n{'='*60}")
|
||||
print(f"[RESULT] Sample {processed_count} | CONFIDENCE Policy")
|
||||
print(f" Answer: {answer} | GT: {ground_truth} | {'✓ CORRECT' if is_correct else '✗ WRONG'}")
|
||||
print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
|
||||
print(f" Cumulative Accuracy: {cumulative_accuracy:.4f} ({cumulative_correct}/{processed_count + 1})")
|
||||
print(f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}")
|
||||
print(f"{'='*60}\n")
|
||||
debug_log.log_policy_result(
|
||||
"CONFIDENCE",
|
||||
processed_count,
|
||||
answer,
|
||||
ground_truth,
|
||||
is_correct,
|
||||
avg_confidence,
|
||||
consistency,
|
||||
cumulative_accuracy,
|
||||
cumulative_correct,
|
||||
window_accuracy,
|
||||
recent_results,
|
||||
)
|
||||
|
||||
# CONFIDENCE-BASED Queue Update
|
||||
if responses:
|
||||
@@ -186,18 +192,13 @@ async def run_confidence_experiment(
|
||||
for idx, response in responses.items():
|
||||
confidence_map[idx] = response.get("CONFIDENCE", 0.0)
|
||||
|
||||
print(f"\n[CONFIDENCE MAP] Per-agent confidence scores:")
|
||||
for idx, conf in sorted(confidence_map.items()):
|
||||
ans = responses[idx].get("ANSWER", "?")
|
||||
print(f" Queue[{idx}]: answer={ans}, confidence={conf:.4f}")
|
||||
debug_log.log_confidence_map(responses, confidence_map)
|
||||
|
||||
# Update queue by confidence (evict lowest, add new random)
|
||||
ex_queue.update_by_confidence(confidence_map)
|
||||
ex_queue.increment_usage(list(responses.keys()))
|
||||
|
||||
print(f"\n[Sample {processed_count}] Queue State AFTER Confidence Update:")
|
||||
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
|
||||
print(f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})")
|
||||
debug_log.log_policy_queue_state_after("Confidence", processed_count, ex_queue)
|
||||
|
||||
# Store result
|
||||
result = {
|
||||
@@ -219,13 +220,7 @@ async def run_confidence_experiment(
|
||||
|
||||
# Final statistics
|
||||
survival_summary = ex_queue.get_survival_summary()
|
||||
print(f"\n{'#'*60}")
|
||||
print(f"[FINAL] Queue Survival Statistics - CONFIDENCE Policy")
|
||||
print(f" User: {user_id} | Seed: {shuffle_seed}")
|
||||
print(f" Total Evicted: {survival_summary['total_evicted']}")
|
||||
print(f" Avg Survival: {survival_summary['avg_survival']:.2f} samples")
|
||||
print(f" Avg Usage: {survival_summary['avg_usage']:.2f}")
|
||||
print(f"{'#'*60}\n")
|
||||
debug_log.log_final_policy_survival("CONFIDENCE", user_id, shuffle_seed, survival_summary)
|
||||
|
||||
if results:
|
||||
results[-1]["queue_survival_stats"] = survival_summary
|
||||
@@ -361,19 +356,19 @@ def save_results(
|
||||
stats_path = os.path.join(output_dir, "statistics.json")
|
||||
with open(stats_path, "w", encoding="utf-8") as f:
|
||||
json.dump(stats, f, indent=2, ensure_ascii=False)
|
||||
print(f"[SAVE] Statistics: {stats_path}")
|
||||
log(f"[SAVE] Statistics: {stats_path}")
|
||||
|
||||
# Save results
|
||||
results_path = os.path.join(output_dir, "results.json")
|
||||
with open(results_path, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||
print(f"[SAVE] Results: {results_path}")
|
||||
log(f"[SAVE] Results: {results_path}")
|
||||
|
||||
# Save config
|
||||
config_path = os.path.join(output_dir, "config.yaml")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
|
||||
print(f"[SAVE] Config: {config_path}")
|
||||
log(f"[SAVE] Config: {config_path}")
|
||||
|
||||
|
||||
def main(
|
||||
@@ -393,7 +388,7 @@ def main(
|
||||
python -m sc.run_confidence --user_id=5 --shuffle_seed=42
|
||||
python -m sc.run_confidence --user_id=10 --shuffle_seed=123
|
||||
"""
|
||||
print(f"[MAIN] Loading config: {config_path}")
|
||||
log(f"[MAIN] Loading config: {config_path}")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = yaml.load(f, Loader=yaml.SafeLoader)
|
||||
|
||||
@@ -407,18 +402,17 @@ def main(
|
||||
os.makedirs(config["log_path"], exist_ok=True)
|
||||
|
||||
# Print experiment info
|
||||
print("=" * 80)
|
||||
print("CONFIDENCE-BASED QUEUE POLICY EXPERIMENT")
|
||||
print("=" * 80)
|
||||
print(f" User ID: {user_id}")
|
||||
print(f" Shuffle Seed: {shuffle_seed}")
|
||||
print(f" Queue Size: {config.get('queue_size', 5)}")
|
||||
print(f" SC Samples (Agents): {len(config.get('models', []))}")
|
||||
print(f" Log Path: {config['log_path']}")
|
||||
print("=" * 80)
|
||||
debug_log.log_policy_main_header(
|
||||
"CONFIDENCE-BASED",
|
||||
user_id,
|
||||
shuffle_seed,
|
||||
config.get("queue_size", 5),
|
||||
len(config.get("models", [])),
|
||||
config["log_path"],
|
||||
)
|
||||
|
||||
# Load shuffled data
|
||||
print(f"\n[MAIN] Loading shuffled data for user {user_id}, seed {shuffle_seed}...")
|
||||
debug_log.log_policy_loading_data(user_id, shuffle_seed)
|
||||
dataloader = ShuffledDataLoader(
|
||||
data_path=config["data_path"],
|
||||
user_id=user_id,
|
||||
@@ -427,7 +421,7 @@ def main(
|
||||
)
|
||||
|
||||
# Load models
|
||||
print(f"\n[MAIN] Loading models...")
|
||||
debug_log.log_policy_loading_models()
|
||||
model_pool = load_models(
|
||||
config["models"],
|
||||
temperature=config.get("temperature", 0.0),
|
||||
@@ -435,7 +429,7 @@ def main(
|
||||
)
|
||||
|
||||
# Run experiment
|
||||
print(f"\n[MAIN] Starting experiment...")
|
||||
debug_log.log_policy_start(label="experiment")
|
||||
results = asyncio.run(run_confidence_experiment(
|
||||
dataloader=dataloader,
|
||||
model_pool=model_pool,
|
||||
@@ -449,27 +443,20 @@ def main(
|
||||
stats = compute_statistics(results, stages)
|
||||
|
||||
# Print final summary
|
||||
print("\n" + "=" * 80)
|
||||
print("EXPERIMENT COMPLETE - CONFIDENCE POLICY")
|
||||
print("=" * 80)
|
||||
print(f" User: {user_id} | Seed: {shuffle_seed}")
|
||||
print(f" Total Samples: {stats.get('total_samples', 0)}")
|
||||
print(f" Overall Accuracy: {stats.get('accuracy', 0):.4f}")
|
||||
print(f" Macro F1: {stats.get('macro_f1', 0):.4f}")
|
||||
print(f"\n Per-Stage Accuracy:")
|
||||
for stage, acc in stats.get("stage_accuracy", {}).items():
|
||||
count = stats.get("stage_sample_counts", {}).get(stage, 0)
|
||||
print(f" {stage}: {acc:.4f} (n={count})")
|
||||
print(f"\n Temporal Analysis:")
|
||||
temporal = stats.get("temporal_analysis", {})
|
||||
print(f" First Half: {temporal.get('first_half_accuracy', 0):.4f}")
|
||||
print(f" Second Half: {temporal.get('second_half_accuracy', 0):.4f}")
|
||||
print(f" Improvement: {temporal.get('improvement', 0):+.4f}")
|
||||
print("=" * 80)
|
||||
debug_log.log_policy_complete_summary(
|
||||
"CONFIDENCE",
|
||||
user_id,
|
||||
shuffle_seed,
|
||||
stats,
|
||||
stats.get("stage_accuracy", {}),
|
||||
stats.get("stage_sample_counts", {}),
|
||||
temporal,
|
||||
)
|
||||
|
||||
# Save results
|
||||
save_results(results, stats, config, user_id, shuffle_seed)
|
||||
print(f"\n[MAIN] Results saved to: {config['log_path']}")
|
||||
log(f"\n[MAIN] Results saved to: {config['log_path']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -34,6 +34,9 @@ from sc.core.scagent import SCAgent
|
||||
from sc.core.model import load_models
|
||||
from sc.core.queue import Queue
|
||||
from sc.core.agent_pool import AgentPool
|
||||
from sc import debug_log
|
||||
|
||||
log = debug_log.log
|
||||
|
||||
|
||||
async def run_consistency_experiment(
|
||||
@@ -69,7 +72,7 @@ async def run_consistency_experiment(
|
||||
|
||||
example_dataset = dataloader.get_examples()
|
||||
if len(example_dataset) == 0:
|
||||
print(f"[ERROR] No examples found for user {user_id}")
|
||||
log(f"[ERROR] No examples found for user {user_id}")
|
||||
return []
|
||||
|
||||
# Build class_indices for Queue initialization
|
||||
@@ -95,22 +98,19 @@ async def run_consistency_experiment(
|
||||
all_predictions = []
|
||||
all_ground_truths = []
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"CONSISTENCY-BASED QUEUE POLICY EXPERIMENT")
|
||||
print(f"User: {user_id} | Shuffle Seed: {shuffle_seed}")
|
||||
print(f"Total samples: {len(dataloader)} | Queue size: {queue_size}")
|
||||
print(f"{'='*80}\n")
|
||||
debug_log.log_policy_experiment_header(
|
||||
"CONSISTENCY-BASED",
|
||||
user_id,
|
||||
shuffle_seed,
|
||||
len(dataloader),
|
||||
queue_size,
|
||||
)
|
||||
|
||||
for processed_count, sample in enumerate(dataloader):
|
||||
ex_queue.set_current_time(processed_count)
|
||||
|
||||
# Log queue state before processing
|
||||
print(f"\n{'#'*60}")
|
||||
print(f"[Sample {processed_count}] User {user_id} | CONSISTENCY Policy")
|
||||
print(f"Queue State BEFORE Processing:")
|
||||
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
|
||||
print(f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})")
|
||||
print(f"{'#'*60}\n")
|
||||
debug_log.log_policy_queue_state("CONSISTENCY", processed_count, user_id, ex_queue)
|
||||
|
||||
# Create agent pool
|
||||
agent_pool = AgentPool(log_path=config["log_path"])
|
||||
@@ -132,26 +132,26 @@ async def run_consistency_experiment(
|
||||
)
|
||||
agent_pool.add_agent(agent)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Failed to create agents: {e}")
|
||||
log(f"[ERROR] Failed to create agents: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
if len(agent_pool.agents) == 0:
|
||||
print(f"[WARN] No agents created for sample {processed_count}")
|
||||
log(f"[WARN] No agents created for sample {processed_count}")
|
||||
continue
|
||||
|
||||
# Run parallel interpretation
|
||||
try:
|
||||
interpretation_result = await agent_pool.run_parallel_interpretation()
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Interpretation failed: {e}")
|
||||
log(f"[ERROR] Interpretation failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
if interpretation_result is None:
|
||||
print(f"[WARN] Interpretation failed for sample {processed_count}")
|
||||
log(f"[WARN] Interpretation failed for sample {processed_count}")
|
||||
continue
|
||||
|
||||
answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
|
||||
@@ -175,13 +175,19 @@ async def run_consistency_experiment(
|
||||
all_ground_truths.append(ground_truth)
|
||||
|
||||
# Performance logging
|
||||
print(f"\n{'='*60}")
|
||||
print(f"[RESULT] Sample {processed_count} | CONSISTENCY Policy")
|
||||
print(f" Answer: {answer} | GT: {ground_truth} | {'✓ CORRECT' if is_correct else '✗ WRONG'}")
|
||||
print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
|
||||
print(f" Cumulative Accuracy: {cumulative_accuracy:.4f} ({cumulative_correct}/{processed_count + 1})")
|
||||
print(f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}")
|
||||
print(f"{'='*60}\n")
|
||||
debug_log.log_policy_result(
|
||||
"CONSISTENCY",
|
||||
processed_count,
|
||||
answer,
|
||||
ground_truth,
|
||||
is_correct,
|
||||
avg_confidence,
|
||||
consistency,
|
||||
cumulative_accuracy,
|
||||
cumulative_correct,
|
||||
window_accuracy,
|
||||
recent_results,
|
||||
)
|
||||
|
||||
# CONSISTENCY-BASED Queue Update
|
||||
if responses:
|
||||
@@ -196,18 +202,13 @@ async def run_consistency_experiment(
|
||||
agent_consistency = agreement_count / len(all_answers) if all_answers else 0
|
||||
consistency_map[idx] = agent_consistency
|
||||
|
||||
print(f"\n[CONSISTENCY MAP] Per-agent consistency scores:")
|
||||
for idx, cons in sorted(consistency_map.items()):
|
||||
ans = responses[idx].get("ANSWER", "?")
|
||||
print(f" Queue[{idx}]: answer={ans}, consistency={cons:.4f}")
|
||||
debug_log.log_consistency_map(responses, consistency_map)
|
||||
|
||||
# Update queue by consistency (reuses confidence method with consistency scores)
|
||||
ex_queue.update_by_confidence(consistency_map)
|
||||
ex_queue.increment_usage(list(responses.keys()))
|
||||
|
||||
print(f"\n[Sample {processed_count}] Queue State AFTER Consistency Update:")
|
||||
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
|
||||
print(f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})")
|
||||
debug_log.log_policy_queue_state_after("Consistency", processed_count, ex_queue)
|
||||
|
||||
# Store result
|
||||
result = {
|
||||
@@ -229,13 +230,7 @@ async def run_consistency_experiment(
|
||||
|
||||
# Final statistics
|
||||
survival_summary = ex_queue.get_survival_summary()
|
||||
print(f"\n{'#'*60}")
|
||||
print(f"[FINAL] Queue Survival Statistics - CONSISTENCY Policy")
|
||||
print(f" User: {user_id} | Seed: {shuffle_seed}")
|
||||
print(f" Total Evicted: {survival_summary['total_evicted']}")
|
||||
print(f" Avg Survival: {survival_summary['avg_survival']:.2f} samples")
|
||||
print(f" Avg Usage: {survival_summary['avg_usage']:.2f}")
|
||||
print(f"{'#'*60}\n")
|
||||
debug_log.log_final_policy_survival("CONSISTENCY", user_id, shuffle_seed, survival_summary)
|
||||
|
||||
if results:
|
||||
results[-1]["queue_survival_stats"] = survival_summary
|
||||
@@ -371,19 +366,19 @@ def save_results(
|
||||
stats_path = os.path.join(output_dir, "statistics.json")
|
||||
with open(stats_path, "w", encoding="utf-8") as f:
|
||||
json.dump(stats, f, indent=2, ensure_ascii=False)
|
||||
print(f"[SAVE] Statistics: {stats_path}")
|
||||
log(f"[SAVE] Statistics: {stats_path}")
|
||||
|
||||
# Save results
|
||||
results_path = os.path.join(output_dir, "results.json")
|
||||
with open(results_path, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||
print(f"[SAVE] Results: {results_path}")
|
||||
log(f"[SAVE] Results: {results_path}")
|
||||
|
||||
# Save config
|
||||
config_path = os.path.join(output_dir, "config.yaml")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
|
||||
print(f"[SAVE] Config: {config_path}")
|
||||
log(f"[SAVE] Config: {config_path}")
|
||||
|
||||
|
||||
def main(
|
||||
@@ -403,7 +398,7 @@ def main(
|
||||
python -m sc.run_consistency --user_id=5 --shuffle_seed=42
|
||||
python -m sc.run_consistency --user_id=10 --shuffle_seed=123
|
||||
"""
|
||||
print(f"[MAIN] Loading config: {config_path}")
|
||||
log(f"[MAIN] Loading config: {config_path}")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = yaml.load(f, Loader=yaml.SafeLoader)
|
||||
|
||||
@@ -417,18 +412,17 @@ def main(
|
||||
os.makedirs(config["log_path"], exist_ok=True)
|
||||
|
||||
# Print experiment info
|
||||
print("=" * 80)
|
||||
print("CONSISTENCY-BASED QUEUE POLICY EXPERIMENT")
|
||||
print("=" * 80)
|
||||
print(f" User ID: {user_id}")
|
||||
print(f" Shuffle Seed: {shuffle_seed}")
|
||||
print(f" Queue Size: {config.get('queue_size', 5)}")
|
||||
print(f" SC Samples (Agents): {len(config.get('models', []))}")
|
||||
print(f" Log Path: {config['log_path']}")
|
||||
print("=" * 80)
|
||||
debug_log.log_policy_main_header(
|
||||
"CONSISTENCY-BASED",
|
||||
user_id,
|
||||
shuffle_seed,
|
||||
config.get("queue_size", 5),
|
||||
len(config.get("models", [])),
|
||||
config["log_path"],
|
||||
)
|
||||
|
||||
# Load shuffled data
|
||||
print(f"\n[MAIN] Loading shuffled data for user {user_id}, seed {shuffle_seed}...")
|
||||
debug_log.log_policy_loading_data(user_id, shuffle_seed)
|
||||
dataloader = ShuffledDataLoader(
|
||||
data_path=config["data_path"],
|
||||
user_id=user_id,
|
||||
@@ -437,7 +431,7 @@ def main(
|
||||
)
|
||||
|
||||
# Load models
|
||||
print(f"\n[MAIN] Loading models...")
|
||||
debug_log.log_policy_loading_models()
|
||||
model_pool = load_models(
|
||||
config["models"],
|
||||
temperature=config.get("temperature", 0.0),
|
||||
@@ -445,7 +439,7 @@ def main(
|
||||
)
|
||||
|
||||
# Run experiment
|
||||
print(f"\n[MAIN] Starting experiment...")
|
||||
debug_log.log_policy_start(label="experiment")
|
||||
results = asyncio.run(run_consistency_experiment(
|
||||
dataloader=dataloader,
|
||||
model_pool=model_pool,
|
||||
@@ -459,27 +453,20 @@ def main(
|
||||
stats = compute_statistics(results, stages)
|
||||
|
||||
# Print final summary
|
||||
print("\n" + "=" * 80)
|
||||
print("EXPERIMENT COMPLETE - CONSISTENCY POLICY")
|
||||
print("=" * 80)
|
||||
print(f" User: {user_id} | Seed: {shuffle_seed}")
|
||||
print(f" Total Samples: {stats.get('total_samples', 0)}")
|
||||
print(f" Overall Accuracy: {stats.get('accuracy', 0):.4f}")
|
||||
print(f" Macro F1: {stats.get('macro_f1', 0):.4f}")
|
||||
print(f"\n Per-Stage Accuracy:")
|
||||
for stage, acc in stats.get("stage_accuracy", {}).items():
|
||||
count = stats.get("stage_sample_counts", {}).get(stage, 0)
|
||||
print(f" {stage}: {acc:.4f} (n={count})")
|
||||
print(f"\n Temporal Analysis:")
|
||||
temporal = stats.get("temporal_analysis", {})
|
||||
print(f" First Half: {temporal.get('first_half_accuracy', 0):.4f}")
|
||||
print(f" Second Half: {temporal.get('second_half_accuracy', 0):.4f}")
|
||||
print(f" Improvement: {temporal.get('improvement', 0):+.4f}")
|
||||
print("=" * 80)
|
||||
debug_log.log_policy_complete_summary(
|
||||
"CONSISTENCY",
|
||||
user_id,
|
||||
shuffle_seed,
|
||||
stats,
|
||||
stats.get("stage_accuracy", {}),
|
||||
stats.get("stage_sample_counts", {}),
|
||||
temporal,
|
||||
)
|
||||
|
||||
# Save results
|
||||
save_results(results, stats, config, user_id, shuffle_seed)
|
||||
print(f"\n[MAIN] Results saved to: {config['log_path']}")
|
||||
log(f"\n[MAIN] Results saved to: {config['log_path']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
144
sc/run_sc.py
144
sc/run_sc.py
@@ -31,6 +31,7 @@ from sc.core.scagent import SCAgent
|
||||
from sc.core.model import load_models
|
||||
from sc.core.queue import Queue
|
||||
from sc.core.agent_pool import AgentPool
|
||||
from sc import debug_log
|
||||
|
||||
|
||||
async def run_single_task(
|
||||
@@ -55,7 +56,7 @@ async def run_single_task(
|
||||
|
||||
example_dataset = dataloader.get_examples()
|
||||
if len(example_dataset) == 0:
|
||||
print(f"[WARN] No examples found for dataloader. Skipping task.")
|
||||
debug_log.warn_no_examples()
|
||||
return []
|
||||
|
||||
# Build class_indices: Dict[str, List[int]] for Queue initialization
|
||||
@@ -80,12 +81,8 @@ async def run_single_task(
|
||||
|
||||
ex_queue.set_current_time(processed_count) # Track current sample index for survival time
|
||||
|
||||
print(f"\n{'#'*60}")
|
||||
user_info = f"User: {user_id}" if user_id else "Unknown User"
|
||||
print(f"[Sample {processed_count}] {user_info} - Queue State BEFORE Processing (Instance ID: {ex_queue.get_instance_id()}):")
|
||||
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
|
||||
print(f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})")
|
||||
print(f"{'#'*60}\n")
|
||||
debug_log.log_queue_state_before(processed_count, user_info, ex_queue, config)
|
||||
|
||||
agent_pool = AgentPool(log_path=config["log_path"])
|
||||
for queue_idx, ex_idcs in enumerate(ex_queue):
|
||||
@@ -108,14 +105,14 @@ async def run_single_task(
|
||||
|
||||
# Check if any agents were added
|
||||
if len(agent_pool.agents) == 0:
|
||||
print(f"[WARN] No agents added for sample {processed_count}. Skipping.")
|
||||
debug_log.warn_no_agents(processed_count)
|
||||
continue
|
||||
|
||||
interpretation_result = await agent_pool.run_parallel_interpretation()
|
||||
|
||||
# Handle case where interpretation failed
|
||||
if interpretation_result is None:
|
||||
print(f"[WARN] Interpretation failed for sample {processed_count}. Skipping.")
|
||||
debug_log.warn_interpretation_failed(processed_count)
|
||||
continue
|
||||
|
||||
answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
|
||||
@@ -135,16 +132,21 @@ async def run_single_task(
|
||||
confidence_history.append(avg_confidence)
|
||||
avg_confidence_so_far = sum(confidence_history) / len(confidence_history)
|
||||
|
||||
# Debug: Performance tracking over time
|
||||
print(f"\n{'='*60}")
|
||||
print(f"[TRACKING] Sample {processed_count} | {user_info}")
|
||||
print(f" Answer: {answer} | GT: {ground_truth} | {'✓ CORRECT' if is_correct else '✗ WRONG'}")
|
||||
print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
|
||||
print(f" ─────────────────────────────────────────────────────")
|
||||
print(f" Cumulative Accuracy: {cumulative_accuracy:.4f} ({cumulative_correct}/{processed_count + 1})")
|
||||
print(f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}")
|
||||
print(f" Avg Confidence (so far): {avg_confidence_so_far:.4f}")
|
||||
print(f"{'='*60}\n")
|
||||
debug_log.log_tracking(
|
||||
processed_count,
|
||||
user_info,
|
||||
answer,
|
||||
ground_truth,
|
||||
is_correct,
|
||||
avg_confidence,
|
||||
consistency,
|
||||
cumulative_accuracy,
|
||||
cumulative_correct,
|
||||
window_accuracy,
|
||||
recent_results,
|
||||
avg_confidence_so_far,
|
||||
config,
|
||||
)
|
||||
|
||||
# Update queue based on Confidence (Priority Queue)
|
||||
if responses:
|
||||
@@ -153,13 +155,9 @@ async def run_single_task(
|
||||
ex_queue.update_by_confidence(confidence_map)
|
||||
ex_queue.increment_usage(list(responses.keys()))
|
||||
|
||||
# Debug: Queue state after update
|
||||
print(f"\n[Sample {processed_count}] Queue State AFTER Update:")
|
||||
for idx, case_info in enumerate(ex_queue.get_state_with_stats()):
|
||||
print(f" [{idx}] {case_info['case']} (age={case_info['age']}, used={case_info['usage']})")
|
||||
print()
|
||||
debug_log.log_queue_state_after(processed_count, ex_queue, config)
|
||||
elif queue_idcs:
|
||||
print(f"[WARN] No responses returned, falling back to basic update.")
|
||||
debug_log.warn_no_responses(processed_count)
|
||||
|
||||
result = {
|
||||
"sample_idx": processed_count,
|
||||
@@ -177,15 +175,7 @@ async def run_single_task(
|
||||
|
||||
# Final Queue survival statistics
|
||||
survival_summary = ex_queue.get_survival_summary()
|
||||
print(f"\n{'#'*60}")
|
||||
print(f"[FINAL] Queue Survival Statistics for {user_info}")
|
||||
print(f" Total Evicted Cases: {survival_summary['total_evicted']}")
|
||||
print(f" Avg Survival: {survival_summary['avg_survival']:.2f} samples")
|
||||
print(f" Max Survival: {survival_summary['max_survival']} samples")
|
||||
print(f" Min Survival: {survival_summary['min_survival']} samples")
|
||||
print(f" Avg Usage Count: {survival_summary['avg_usage']:.2f}")
|
||||
print(f" Max Usage Count: {survival_summary['max_usage']}")
|
||||
print(f"{'#'*60}\n")
|
||||
debug_log.log_final_queue_stats(user_info, survival_summary, config)
|
||||
|
||||
# Add Queue statistics to results (in the last result)
|
||||
if results:
|
||||
@@ -215,9 +205,9 @@ async def run_parallel(
|
||||
# Filter to only include directories (exclude files like info.json)
|
||||
users = [os.path.basename(p) for p in user_paths if os.path.isdir(p) and os.path.basename(p) != "info.json"]
|
||||
if not users:
|
||||
print(f"[WARN] No user directories found in {data_path}")
|
||||
debug_log.warn_no_user_dirs(data_path)
|
||||
return []
|
||||
print(f"[INFO] Found {len(users)} users: {users[:5]}{'...' if len(users) > 5 else ''}")
|
||||
debug_log.log_found_users(users, config)
|
||||
seeds = range(config.get("num_seeds", 1))
|
||||
|
||||
tasks = []
|
||||
@@ -233,10 +223,10 @@ async def run_parallel(
|
||||
)
|
||||
# Check if dataloader was properly initialized
|
||||
if not hasattr(dataloader, 'test_dataset') or len(dataloader) == 0:
|
||||
print(f"[WARN] Skipping user {user} - no test data available")
|
||||
debug_log.warn_skip_user_no_test_data(user)
|
||||
continue
|
||||
if len(dataloader.get_examples()) == 0:
|
||||
print(f"[WARN] Skipping user {user} - no example data available")
|
||||
debug_log.warn_skip_user_no_example_data(user)
|
||||
continue
|
||||
|
||||
task = asyncio.create_task(
|
||||
@@ -388,7 +378,7 @@ def save_results(
|
||||
stats_path = os.path.join(log_path, "statistics.json")
|
||||
with open(stats_path, "w", encoding="utf-8") as f:
|
||||
json.dump(stats, f, indent=2, ensure_ascii=False)
|
||||
print(f"[SAVE] Statistics saved to: {stats_path}")
|
||||
debug_log.log_save_statistics(stats_path, config)
|
||||
|
||||
# Save all results
|
||||
results_to_save = []
|
||||
@@ -399,13 +389,13 @@ def save_results(
|
||||
results_path = os.path.join(log_path, "all_results.json")
|
||||
with open(results_path, "w", encoding="utf-8") as f:
|
||||
json.dump(results_to_save, f, indent=2, ensure_ascii=False)
|
||||
print(f"[SAVE] Results saved to: {results_path}")
|
||||
debug_log.log_save_results(results_path, config)
|
||||
|
||||
# Save configuration for reproducibility
|
||||
config_path = os.path.join(log_path, "config.yaml")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
|
||||
print(f"[SAVE] Config saved to: {config_path}")
|
||||
debug_log.log_save_config(config_path, config)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -422,7 +412,7 @@ def main(config_path: str) -> None:
|
||||
Example:
|
||||
python -m sc.run_sc sc/config/sleepedf_sc.yaml
|
||||
"""
|
||||
print(f"[MAIN] Loading config: {config_path}")
|
||||
debug_log.log_main_loading_config(config_path)
|
||||
config = yaml.load(open(config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
|
||||
|
||||
# Add timestamp to log path for unique experiment runs
|
||||
@@ -431,18 +421,7 @@ def main(config_path: str) -> None:
|
||||
config["log_path"] = f"{config['log_path']}_{timestamp}"
|
||||
|
||||
# Print experiment configuration
|
||||
print("=" * 60)
|
||||
print("SELF-CONSISTENCY EXPERIMENT CONFIGURATION")
|
||||
print("=" * 60)
|
||||
print(f" Data path: {config.get('data_path', 'N/A')}")
|
||||
print(f" Log path: {config.get('log_path', 'N/A')}")
|
||||
print(f" Num ICL examples: {config.get('num_examples', 1)}")
|
||||
print(f" Num seeds: {config.get('num_seeds', 1)}")
|
||||
print(f" Num SC samples: {config.get('num_sc_samples', 5)}")
|
||||
print(f" Temperature: {config.get('temperature', 0.0)}")
|
||||
print(f" Sample rate: 1/{config.get('sample_rate', 10)}")
|
||||
print(f" Num models: {len(config.get('models', []))}")
|
||||
print("=" * 60)
|
||||
debug_log.log_main_config(config, config.get("debug", True))
|
||||
|
||||
model_pool = load_models(
|
||||
config["models"],
|
||||
@@ -451,7 +430,7 @@ def main(config_path: str) -> None:
|
||||
)
|
||||
|
||||
# Run experiments
|
||||
print("[MAIN] Starting experiments...")
|
||||
debug_log.log_main_start(config)
|
||||
all_results = asyncio.run(run_parallel(config, model_pool))
|
||||
|
||||
# Flatten results: run_parallel returns list of lists (one per user/seed)
|
||||
@@ -459,74 +438,29 @@ def main(config_path: str) -> None:
|
||||
flattened_results = []
|
||||
for result in all_results:
|
||||
if isinstance(result, Exception):
|
||||
print(f"[ERROR] Task failed with exception: {result}")
|
||||
debug_log.error_task_failed(result)
|
||||
continue
|
||||
if isinstance(result, list):
|
||||
flattened_results.extend(result)
|
||||
else:
|
||||
flattened_results.append(result)
|
||||
|
||||
print(f"[MAIN] Total results collected: {len(flattened_results)}")
|
||||
debug_log.log_total_results(len(flattened_results), config)
|
||||
|
||||
# Compute and display statistics
|
||||
print("\n" + "=" * 60)
|
||||
print("EXPERIMENT RESULTS")
|
||||
print("=" * 60)
|
||||
stats = compute_statistics(flattened_results)
|
||||
|
||||
print(f" Total samples: {stats.get('total_samples', 0)}")
|
||||
print(f" Accuracy: {stats.get('accuracy', 0):.4f}")
|
||||
print(f" Avg Confidence: {stats.get('avg_confidence', 0):.4f}")
|
||||
print(f" Avg Consistency: {stats.get('avg_consistency', 0):.4f}")
|
||||
print(f" High Consistency (>=0.8) Accuracy: {stats.get('high_consistency_accuracy', 0):.4f}")
|
||||
print(f" High Consistency Samples: {stats.get('high_consistency_samples', 0)}")
|
||||
|
||||
print("\n Class-wise Accuracy:")
|
||||
for cls, acc in stats.get("class_accuracy", {}).items():
|
||||
print(f" {cls}: {acc:.4f}")
|
||||
debug_log.log_experiment_results(stats, config)
|
||||
|
||||
# Time-based analysis output
|
||||
temporal = stats.get("temporal_analysis", {})
|
||||
if temporal:
|
||||
print("\n" + "-" * 60)
|
||||
print(" TEMPORAL ANALYSIS (Caching Effect)")
|
||||
print("-" * 60)
|
||||
print(f" First Half Accuracy: {temporal.get('first_half_accuracy', 0):.4f}")
|
||||
print(f" Second Half Accuracy: {temporal.get('second_half_accuracy', 0):.4f}")
|
||||
improvement = temporal.get('accuracy_improvement', 0)
|
||||
improvement_sign = "+" if improvement >= 0 else ""
|
||||
print(f" Improvement: {improvement_sign}{improvement:.4f}")
|
||||
|
||||
quartiles = temporal.get('quartile_accuracies', [])
|
||||
if quartiles:
|
||||
print(f" Quartile Accuracies: Q1={quartiles[0]:.4f}" +
|
||||
(f", Q2={quartiles[1]:.4f}" if len(quartiles) > 1 else "") +
|
||||
(f", Q3={quartiles[2]:.4f}" if len(quartiles) > 2 else "") +
|
||||
(f", Q4={quartiles[3]:.4f}" if len(quartiles) > 3 else ""))
|
||||
|
||||
print(f"\n First Half Confidence: {temporal.get('first_half_confidence', 0):.4f}")
|
||||
print(f" Second Half Confidence: {temporal.get('second_half_confidence', 0):.4f}")
|
||||
conf_improvement = temporal.get('confidence_improvement', 0)
|
||||
conf_sign = "+" if conf_improvement >= 0 else ""
|
||||
print(f" Confidence Change: {conf_sign}{conf_improvement:.4f}")
|
||||
debug_log.log_temporal_analysis(temporal, config)
|
||||
|
||||
queue_stats = stats.get("queue_stats", {})
|
||||
if queue_stats:
|
||||
print("\n" + "-" * 60)
|
||||
print(" QUEUE SURVIVAL STATISTICS")
|
||||
print("-" * 60)
|
||||
print(f" Total Evicted Cases: {queue_stats.get('total_evicted', 0)}")
|
||||
print(f" Avg Survival: {queue_stats.get('avg_survival', 0):.2f} samples")
|
||||
print(f" Max Survival: {queue_stats.get('max_survival', 0)} samples")
|
||||
print(f" Min Survival: {queue_stats.get('min_survival', 0)} samples")
|
||||
print(f" Avg Usage Count: {queue_stats.get('avg_usage', 0):.2f}")
|
||||
print(f" Max Usage Count: {queue_stats.get('max_usage', 0)}")
|
||||
|
||||
print("=" * 60)
|
||||
debug_log.log_queue_stats(queue_stats, config)
|
||||
|
||||
# Save results
|
||||
save_results(flattened_results, stats, config)
|
||||
print(f"[MAIN] Results saved to: {config['log_path']}")
|
||||
debug_log.log_results_saved(config["log_path"], config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -42,6 +42,9 @@ from sc.preprocess.shuffle_data import ShuffledDataLoader, load_shuffled_data
|
||||
from sc.core.scagent import SCAgent
|
||||
from sc.core.model import load_models
|
||||
from sc.core.agent_pool import AgentPool
|
||||
from sc import debug_log
|
||||
|
||||
log = debug_log.log
|
||||
|
||||
|
||||
class QueueRandomSampler:
|
||||
@@ -85,10 +88,11 @@ class QueueRandomSampler:
|
||||
# Tracking for statistics (mimics Queue class interface)
|
||||
self._total_refreshed = 0
|
||||
|
||||
print(f"[QueueRandomSampler] Initialized with {len(example_dataset)} examples")
|
||||
print(f" Classes: {self.classes}")
|
||||
print(f" Queue size: {queue_size}")
|
||||
print(f" Policy: ALL elements refreshed every step")
|
||||
debug_log.log_queue_random_sampler_init(
|
||||
len(example_dataset),
|
||||
self.classes,
|
||||
queue_size,
|
||||
)
|
||||
|
||||
def _sample_one_set(self) -> List[int]:
|
||||
"""Sample one ICL example set (one example per class)."""
|
||||
@@ -167,7 +171,7 @@ async def run_queue_random_experiment(
|
||||
|
||||
example_dataset = dataloader.get_examples()
|
||||
if len(example_dataset) == 0:
|
||||
print(f"[ERROR] No examples found for user {user_id}")
|
||||
log(f"[ERROR] No examples found for user {user_id}")
|
||||
return []
|
||||
|
||||
# Initialize Queue Random Sampler (instead of regular Queue)
|
||||
@@ -189,24 +193,21 @@ async def run_queue_random_experiment(
|
||||
all_predictions = []
|
||||
all_ground_truths = []
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"QUEUE RANDOM BASELINE EXPERIMENT")
|
||||
print(f"User: {user_id} | Shuffle Seed: {shuffle_seed}")
|
||||
print(f"Total samples: {len(dataloader)} | Queue size: {queue_size}")
|
||||
print(f"Policy: ALL {queue_size} queue elements refreshed EVERY step")
|
||||
print(f"{'='*80}\n")
|
||||
debug_log.log_policy_experiment_header(
|
||||
"QUEUE RANDOM BASELINE",
|
||||
user_id,
|
||||
shuffle_seed,
|
||||
len(dataloader),
|
||||
queue_size,
|
||||
policy_note=f"Policy: ALL {queue_size} queue elements refreshed EVERY step",
|
||||
)
|
||||
|
||||
for processed_count, sample in enumerate(dataloader):
|
||||
# CRITICAL: Refresh ALL queue elements before each inference
|
||||
queue_sampler.refresh_all()
|
||||
|
||||
# Log queue state (all new random samples)
|
||||
print(f"\n{'#'*60}")
|
||||
print(f"[Sample {processed_count}] User {user_id} | QUEUE RANDOM Policy")
|
||||
print(f"Queue State (ALL FRESH RANDOM samples):")
|
||||
for idx, ex_idcs in enumerate(queue_sampler):
|
||||
print(f" [{idx}] Example indices: {ex_idcs}")
|
||||
print(f"{'#'*60}\n")
|
||||
debug_log.log_queue_random_queue_state(processed_count, user_id, queue_sampler)
|
||||
|
||||
# Create agent pool
|
||||
agent_pool = AgentPool(log_path=config["log_path"])
|
||||
@@ -228,26 +229,26 @@ async def run_queue_random_experiment(
|
||||
)
|
||||
agent_pool.add_agent(agent)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Failed to create agents: {e}")
|
||||
log(f"[ERROR] Failed to create agents: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
if len(agent_pool.agents) == 0:
|
||||
print(f"[WARN] No agents created for sample {processed_count}")
|
||||
log(f"[WARN] No agents created for sample {processed_count}")
|
||||
continue
|
||||
|
||||
# Run parallel interpretation
|
||||
try:
|
||||
interpretation_result = await agent_pool.run_parallel_interpretation()
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Interpretation failed: {e}")
|
||||
log(f"[ERROR] Interpretation failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
if interpretation_result is None:
|
||||
print(f"[WARN] Interpretation failed for sample {processed_count}")
|
||||
log(f"[WARN] Interpretation failed for sample {processed_count}")
|
||||
continue
|
||||
|
||||
answer, queue_idcs, avg_confidence, consistency, responses = interpretation_result
|
||||
@@ -271,14 +272,18 @@ async def run_queue_random_experiment(
|
||||
all_ground_truths.append(ground_truth)
|
||||
|
||||
# Performance logging
|
||||
print(f"\n{'='*60}")
|
||||
print(f"[RESULT] Sample {processed_count} | QUEUE RANDOM Policy")
|
||||
print(f" Answer: {answer} | GT: {ground_truth} | {'✓ CORRECT' if is_correct else '✗ WRONG'}")
|
||||
print(f" Confidence: {avg_confidence:.4f} | Consistency: {consistency:.4f}")
|
||||
print(f" Cumulative Accuracy: {cumulative_accuracy:.4f} ({cumulative_correct}/{processed_count + 1})")
|
||||
print(f" Window Accuracy (last {len(recent_results)}): {window_accuracy:.4f}")
|
||||
print(f" [NOTE] Queue will be FULLY REFRESHED for next sample")
|
||||
print(f"{'='*60}\n")
|
||||
debug_log.log_queue_random_result(
|
||||
processed_count,
|
||||
answer,
|
||||
ground_truth,
|
||||
is_correct,
|
||||
avg_confidence,
|
||||
consistency,
|
||||
cumulative_accuracy,
|
||||
cumulative_correct,
|
||||
window_accuracy,
|
||||
recent_results,
|
||||
)
|
||||
|
||||
# NO QUEUE UPDATE based on scores - just fresh random next time
|
||||
# (This is the key difference from Confidence/Consistency policies)
|
||||
@@ -303,13 +308,7 @@ async def run_queue_random_experiment(
|
||||
|
||||
# Final statistics
|
||||
sampler_stats = queue_sampler.get_statistics()
|
||||
print(f"\n{'#'*60}")
|
||||
print(f"[FINAL] Queue Random Statistics")
|
||||
print(f" User: {user_id} | Seed: {shuffle_seed}")
|
||||
print(f" Total Steps: {sampler_stats['total_steps']}")
|
||||
print(f" Total Refreshed: {sampler_stats['total_refreshed']} example sets")
|
||||
print(f" Avg Refresh per Step: {sampler_stats['avg_refresh_per_step']} (always full)")
|
||||
print(f"{'#'*60}\n")
|
||||
debug_log.log_queue_random_stats(user_id, shuffle_seed, sampler_stats)
|
||||
|
||||
if results:
|
||||
results[-1]["queue_random_stats"] = sampler_stats
|
||||
@@ -444,19 +443,19 @@ def save_results(
|
||||
stats_path = os.path.join(output_dir, "statistics.json")
|
||||
with open(stats_path, "w", encoding="utf-8") as f:
|
||||
json.dump(stats, f, indent=2, ensure_ascii=False)
|
||||
print(f"[SAVE] Statistics: {stats_path}")
|
||||
log(f"[SAVE] Statistics: {stats_path}")
|
||||
|
||||
# Save results
|
||||
results_path = os.path.join(output_dir, "results.json")
|
||||
with open(results_path, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||
print(f"[SAVE] Results: {results_path}")
|
||||
log(f"[SAVE] Results: {results_path}")
|
||||
|
||||
# Save config
|
||||
config_path = os.path.join(output_dir, "config.yaml")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
|
||||
print(f"[SAVE] Config: {config_path}")
|
||||
log(f"[SAVE] Config: {config_path}")
|
||||
|
||||
|
||||
def main(
|
||||
@@ -481,7 +480,7 @@ def main(
|
||||
python -m sc.run_sc_queue_random --user_id=5 --shuffle_seed=42
|
||||
python -m sc.run_sc_queue_random --user_id=15 --shuffle_seed=123
|
||||
"""
|
||||
print(f"[MAIN] Loading config: {config_path}")
|
||||
log(f"[MAIN] Loading config: {config_path}")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = yaml.load(f, Loader=yaml.SafeLoader)
|
||||
|
||||
@@ -495,20 +494,18 @@ def main(
|
||||
os.makedirs(config["log_path"], exist_ok=True)
|
||||
|
||||
# Print experiment info
|
||||
print("=" * 80)
|
||||
print("QUEUE RANDOM BASELINE EXPERIMENT")
|
||||
print("(Queue structure maintained, but ALL elements refreshed every step)")
|
||||
print("=" * 80)
|
||||
print(f" User ID: {user_id}")
|
||||
print(f" Shuffle Seed: {shuffle_seed}")
|
||||
print(f" Queue Size: {config.get('queue_size', 5)}")
|
||||
print(f" SC Samples (Agents): {len(config.get('models', []))}")
|
||||
print(f" Log Path: {config['log_path']}")
|
||||
print(f" Policy: Queue Random (full refresh every step)")
|
||||
print("=" * 80)
|
||||
debug_log.log_policy_main_header(
|
||||
"QUEUE RANDOM BASELINE",
|
||||
user_id,
|
||||
shuffle_seed,
|
||||
config.get("queue_size", 5),
|
||||
len(config.get("models", [])),
|
||||
config["log_path"],
|
||||
policy_note=" Policy: Queue Random (full refresh every step)",
|
||||
)
|
||||
|
||||
# Load shuffled data
|
||||
print(f"\n[MAIN] Loading shuffled data for user {user_id}, seed {shuffle_seed}...")
|
||||
debug_log.log_policy_loading_data(user_id, shuffle_seed)
|
||||
dataloader = ShuffledDataLoader(
|
||||
data_path=config["data_path"],
|
||||
user_id=user_id,
|
||||
@@ -517,7 +514,7 @@ def main(
|
||||
)
|
||||
|
||||
# Load models
|
||||
print(f"\n[MAIN] Loading models...")
|
||||
debug_log.log_policy_loading_models()
|
||||
model_pool = load_models(
|
||||
config["models"],
|
||||
temperature=config.get("temperature", 0.0),
|
||||
@@ -525,7 +522,7 @@ def main(
|
||||
)
|
||||
|
||||
# Run experiment
|
||||
print(f"\n[MAIN] Starting Queue Random experiment...")
|
||||
debug_log.log_policy_start(label="Queue Random experiment")
|
||||
results = asyncio.run(run_queue_random_experiment(
|
||||
dataloader=dataloader,
|
||||
model_pool=model_pool,
|
||||
@@ -539,28 +536,21 @@ def main(
|
||||
stats = compute_statistics(results, stages)
|
||||
|
||||
# Print final summary
|
||||
print("\n" + "=" * 80)
|
||||
print("EXPERIMENT COMPLETE - QUEUE RANDOM BASELINE")
|
||||
print("=" * 80)
|
||||
print(f" User: {user_id} | Seed: {shuffle_seed}")
|
||||
print(f" Total Samples: {stats.get('total_samples', 0)}")
|
||||
print(f" Overall Accuracy: {stats.get('accuracy', 0):.4f}")
|
||||
print(f" Macro F1: {stats.get('macro_f1', 0):.4f}")
|
||||
print(f"\n Per-Stage Accuracy:")
|
||||
for stage, acc in stats.get("stage_accuracy", {}).items():
|
||||
count = stats.get("stage_sample_counts", {}).get(stage, 0)
|
||||
print(f" {stage}: {acc:.4f} (n={count})")
|
||||
print(f"\n Temporal Analysis:")
|
||||
temporal = stats.get("temporal_analysis", {})
|
||||
print(f" First Half: {temporal.get('first_half_accuracy', 0):.4f}")
|
||||
print(f" Second Half: {temporal.get('second_half_accuracy', 0):.4f}")
|
||||
print(f" Improvement: {temporal.get('improvement', 0):+.4f}")
|
||||
print(f"\n [EXPECTED] Improvement should be ~0 (no cumulative learning)")
|
||||
print("=" * 80)
|
||||
debug_log.log_policy_complete_summary(
|
||||
"QUEUE RANDOM BASELINE",
|
||||
user_id,
|
||||
shuffle_seed,
|
||||
stats,
|
||||
stats.get("stage_accuracy", {}),
|
||||
stats.get("stage_sample_counts", {}),
|
||||
temporal,
|
||||
expected_note="\n [EXPECTED] Improvement should be ~0 (no cumulative learning)",
|
||||
)
|
||||
|
||||
# Save results
|
||||
save_results(results, stats, config, user_id, shuffle_seed)
|
||||
print(f"\n[MAIN] Results saved to: {config['log_path']}")
|
||||
log(f"\n[MAIN] Results saved to: {config['log_path']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user