260119:chronos_2_based_embedding_running

This commit is contained in:
ssum21
2026-01-19 20:17:09 +09:00
parent b0e656038b
commit 01c49e2f1a

281
core/embedding_index.py Normal file
View File

@@ -0,0 +1,281 @@
"""
Embedding Index for Similarity-based Example Selection
This module provides functionality to:
1. Load pre-computed Chronos-2 embeddings
2. Build an index for fast nearest neighbor search
3. Find similar examples based on embedding distance
The embedding index enables ICL (In-Context Learning) example selection
based on semantic similarity rather than random sampling.
Similarity Strategies:
- out_similar: Find similar examples from OTHER users (cross-user transfer)
- in_similar: Find similar examples from SAME user (personalization)
Usage:
index = EmbeddingIndex(embedding_path)
similar_indices = index.find_similar(query_embedding, k=5, exclude_user="01")
Author: NMSL Research Team
Date: 2026-01-16
"""
import os
import numpy as np
from typing import List, Tuple, Optional, Dict, Any
from datasets import load_from_disk, Dataset
class EmbeddingIndex:
"""
Index for fast similarity search over pre-computed embeddings.
Uses cosine similarity for finding nearest neighbors in embedding space.
Supports filtering by user_id and session_id for controlled experiments.
Attributes:
embeddings: numpy array of shape (num_samples, embedding_dim)
user_ids: list of user identifiers for each sample
session_ids: list of session identifiers for each sample
labels: list of sleep stage labels for each sample
indices: list of original indices in the dataset
"""
def __init__(self, embedding_path: str):
"""
Initialize the embedding index from a HuggingFace dataset.
Args:
embedding_path: Path to directory containing saved embeddings dataset
(output from gen_plot.py extract command)
"""
if not os.path.isdir(embedding_path):
raise FileNotFoundError(
f"Embedding directory not found: {embedding_path}. "
"Run 'python gen_plot.py extract' first to generate embeddings."
)
print(f"[EmbeddingIndex] Loading embeddings from: {embedding_path}")
self.dataset = load_from_disk(embedding_path)
# Extract arrays for fast access
self.embeddings = np.array(self.dataset["embedding"], dtype=np.float32)
self.user_ids = np.array([str(uid) for uid in self.dataset["user_id"]])
self.session_ids = np.array([str(sid) for sid in self.dataset["session_id"]])
self.labels = np.array([str(label) for label in self.dataset["label"]])
self.indices = np.array(self.dataset["idx"])
# Normalize embeddings for cosine similarity (dot product of unit vectors)
self.embeddings_normalized = self._normalize(self.embeddings)
# Build lookup indices for fast filtering
self._build_lookup_indices()
print(f"[EmbeddingIndex] Loaded {len(self.embeddings)} samples")
print(f"[EmbeddingIndex] Embedding dimension: {self.embeddings.shape[1]}")
print(f"[EmbeddingIndex] Unique users: {len(np.unique(self.user_ids))}")
print(f"[EmbeddingIndex] Unique sessions: {len(np.unique(self.session_ids))}")
def _normalize(self, vectors: np.ndarray) -> np.ndarray:
"""L2 normalize vectors for cosine similarity computation."""
norms = np.linalg.norm(vectors, axis=1, keepdims=True)
norms = np.where(norms == 0, 1, norms) # Avoid division by zero
return vectors / norms
def _build_lookup_indices(self):
"""Build dictionaries for fast filtering by user/session/label."""
self.user_to_indices: Dict[str, np.ndarray] = {}
self.session_to_indices: Dict[str, np.ndarray] = {}
self.label_to_indices: Dict[str, np.ndarray] = {}
for user_id in np.unique(self.user_ids):
self.user_to_indices[user_id] = np.where(self.user_ids == user_id)[0]
for session_id in np.unique(self.session_ids):
self.session_to_indices[session_id] = np.where(self.session_ids == session_id)[0]
for label in np.unique(self.labels):
self.label_to_indices[label] = np.where(self.labels == label)[0]
def get_embedding_by_key(
self,
user_id: str,
session_id: str,
idx: int
) -> Optional[np.ndarray]:
"""
Get embedding for a specific sample identified by (user_id, session_id, idx).
Args:
user_id: User identifier
session_id: Session identifier (1 or 2)
idx: Sample index within the session
Returns:
Embedding vector or None if not found
"""
mask = (
(self.user_ids == str(user_id)) &
(self.session_ids == str(session_id)) &
(self.indices == idx)
)
matches = np.where(mask)[0]
if len(matches) == 0:
return None
return self.embeddings[matches[0]]
def cosine_similarity(
self,
query: np.ndarray,
candidates: np.ndarray
) -> np.ndarray:
"""
Compute cosine similarity between query and candidate vectors.
Args:
query: Query vector of shape (embedding_dim,)
candidates: Candidate matrix of shape (num_candidates, embedding_dim)
Returns:
Similarity scores of shape (num_candidates,)
"""
query_normalized = query / (np.linalg.norm(query) + 1e-8)
return candidates @ query_normalized
def find_similar(
self,
query_embedding: np.ndarray,
k: int = 5,
exclude_user: Optional[str] = None,
include_user: Optional[str] = None,
filter_session: Optional[str] = None,
filter_label: Optional[str] = None,
) -> List[Tuple[int, float, Dict[str, Any]]]:
"""
Find k most similar samples to the query embedding.
Args:
query_embedding: Query vector of shape (embedding_dim,)
k: Number of nearest neighbors to return
exclude_user: User ID to exclude from search (for out_similar)
include_user: User ID to include only (for in_similar)
filter_session: Only search in this session (e.g., "1" for train set)
filter_label: Only return samples with this label
Returns:
List of (index, similarity, metadata) tuples sorted by similarity (descending)
metadata contains: user_id, session_id, idx, label
"""
# Build candidate mask based on filters
candidate_mask = np.ones(len(self.embeddings), dtype=bool)
if exclude_user is not None:
candidate_mask &= (self.user_ids != str(exclude_user))
if include_user is not None:
candidate_mask &= (self.user_ids == str(include_user))
if filter_session is not None:
candidate_mask &= (self.session_ids == str(filter_session))
if filter_label is not None:
candidate_mask &= (self.labels == str(filter_label))
candidate_indices = np.where(candidate_mask)[0]
if len(candidate_indices) == 0:
return []
# Compute similarities
candidate_embeddings = self.embeddings_normalized[candidate_indices]
similarities = self.cosine_similarity(query_embedding, candidate_embeddings)
# Get top-k
k = min(k, len(similarities))
top_k_local = np.argsort(similarities)[::-1][:k]
results = []
for local_idx in top_k_local:
global_idx = candidate_indices[local_idx]
sim = similarities[local_idx]
metadata = {
"user_id": self.user_ids[global_idx],
"session_id": self.session_ids[global_idx],
"idx": int(self.indices[global_idx]),
"label": self.labels[global_idx],
}
results.append((global_idx, float(sim), metadata))
return results
def find_similar_per_class(
self,
query_embedding: np.ndarray,
classes: List[str],
k_per_class: int = 1,
exclude_user: Optional[str] = None,
include_user: Optional[str] = None,
filter_session: Optional[str] = "1",
) -> Dict[str, List[Tuple[int, float, Dict[str, Any]]]]:
"""
Find k most similar samples for each class.
This ensures balanced class representation in ICL examples,
similar to the original random sampling approach.
Args:
query_embedding: Query vector
classes: List of class labels to search
k_per_class: Number of examples per class
exclude_user: User to exclude (out_similar mode)
include_user: User to include only (in_similar mode)
filter_session: Session to search in (default: "1" = train set)
Returns:
Dictionary mapping class label to list of similar samples
"""
results = {}
for cls in classes:
similar = self.find_similar(
query_embedding=query_embedding,
k=k_per_class,
exclude_user=exclude_user,
include_user=include_user,
filter_session=filter_session,
filter_label=cls,
)
results[cls] = similar
return results
def get_sample_metadata(self, global_idx: int) -> Dict[str, Any]:
"""Get metadata for a sample by its global index."""
return {
"user_id": self.user_ids[global_idx],
"session_id": self.session_ids[global_idx],
"idx": int(self.indices[global_idx]),
"label": self.labels[global_idx],
"embedding": self.embeddings[global_idx],
}
def create_embedding_index(embedding_path: str) -> Optional[EmbeddingIndex]:
"""
Factory function to create an EmbeddingIndex, returning None if path doesn't exist.
Args:
embedding_path: Path to embeddings directory
Returns:
EmbeddingIndex instance or None
"""
if not os.path.isdir(embedding_path):
print(f"[WARNING] Embedding path not found: {embedding_path}")
return None
try:
return EmbeddingIndex(embedding_path)
except Exception as e:
print(f"[ERROR] Failed to load embeddings: {e}")
return None