260119:chronos_2_based_embedding_running
This commit is contained in:
281
core/embedding_index.py
Normal file
281
core/embedding_index.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
Embedding Index for Similarity-based Example Selection
|
||||
|
||||
This module provides functionality to:
|
||||
1. Load pre-computed Chronos-2 embeddings
|
||||
2. Build an index for fast nearest neighbor search
|
||||
3. Find similar examples based on embedding distance
|
||||
|
||||
The embedding index enables ICL (In-Context Learning) example selection
|
||||
based on semantic similarity rather than random sampling.
|
||||
|
||||
Similarity Strategies:
|
||||
- out_similar: Find similar examples from OTHER users (cross-user transfer)
|
||||
- in_similar: Find similar examples from SAME user (personalization)
|
||||
|
||||
Usage:
|
||||
index = EmbeddingIndex(embedding_path)
|
||||
similar_indices = index.find_similar(query_embedding, k=5, exclude_user="01")
|
||||
|
||||
Author: NMSL Research Team
|
||||
Date: 2026-01-16
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
from datasets import load_from_disk, Dataset
|
||||
|
||||
|
||||
class EmbeddingIndex:
    """
    Index for similarity search over pre-computed embeddings.

    Uses cosine similarity for finding nearest neighbors in embedding space.
    Supports filtering by user_id, session_id and label for controlled
    experiments.

    Attributes:
        dataset: the object returned by ``load_from_disk(embedding_path)``
        embeddings: float32 numpy array of shape (num_samples, embedding_dim)
        embeddings_normalized: row-wise L2-normalized copy of ``embeddings``
        user_ids: array of user identifiers (as strings), one per sample
        session_ids: array of session identifiers (as strings), one per sample
        labels: array of labels (as strings), one per sample
        indices: array holding the dataset's "idx" column, one per sample
        user_to_indices: user id -> row positions of that user's samples
        session_to_indices: session id -> row positions of that session
        label_to_indices: label -> row positions carrying that label
    """

    def __init__(self, embedding_path: str):
        """
        Initialize the embedding index from a dataset saved on disk.

        Args:
            embedding_path: Path to directory containing saved embeddings dataset
                            (output from gen_plot.py extract command)

        Raises:
            FileNotFoundError: if ``embedding_path`` is not a directory.
            ValueError: if the "embedding" column does not form a non-empty
                2-D (num_samples, embedding_dim) array.
        """
        if not os.path.isdir(embedding_path):
            raise FileNotFoundError(
                f"Embedding directory not found: {embedding_path}. "
                "Run 'python gen_plot.py extract' first to generate embeddings."
            )

        print(f"[EmbeddingIndex] Loading embeddings from: {embedding_path}")
        self.dataset = load_from_disk(embedding_path)

        # Extract plain arrays once so later lookups never touch the dataset.
        self.embeddings = np.array(self.dataset["embedding"], dtype=np.float32)
        if self.embeddings.ndim != 2 or self.embeddings.shape[0] == 0:
            # Fail here with a readable message instead of an opaque axis
            # error inside _normalize further down.
            raise ValueError(
                "Expected a non-empty 2-D embedding array, got shape "
                f"{self.embeddings.shape} from: {embedding_path}"
            )
        self.user_ids = np.array([str(uid) for uid in self.dataset["user_id"]])
        self.session_ids = np.array([str(sid) for sid in self.dataset["session_id"]])
        self.labels = np.array([str(label) for label in self.dataset["label"]])
        self.indices = np.array(self.dataset["idx"])

        # Unit-length rows turn cosine similarity into a plain dot product.
        self.embeddings_normalized = self._normalize(self.embeddings)

        # Build lookup indices for fast filtering
        self._build_lookup_indices()

        print(f"[EmbeddingIndex] Loaded {len(self.embeddings)} samples")
        print(f"[EmbeddingIndex] Embedding dimension: {self.embeddings.shape[1]}")
        print(f"[EmbeddingIndex] Unique users: {len(np.unique(self.user_ids))}")
        print(f"[EmbeddingIndex] Unique sessions: {len(np.unique(self.session_ids))}")

    def _normalize(self, vectors: np.ndarray) -> np.ndarray:
        """
        L2-normalize vectors along their last axis.

        All-zero vectors are returned unchanged (their norm is replaced by 1
        to avoid a division by zero).
        """
        # axis=-1 is identical to axis=1 for the usual 2-D case and also
        # accepts a single 1-D vector.
        norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
        norms = np.where(norms == 0, 1, norms)  # Avoid division by zero
        return vectors / norms

    @staticmethod
    def _unit_query(query: np.ndarray) -> np.ndarray:
        """Flatten ``query`` to 1-D and scale it to (approximately) unit length."""
        # reshape(-1) lets a (1, embedding_dim) row be used like (embedding_dim,).
        flat = np.asarray(query).reshape(-1)
        # The epsilon keeps an all-zero query from dividing by zero.
        return flat / (np.linalg.norm(flat) + 1e-8)

    @staticmethod
    def _group_positions(values: np.ndarray) -> Dict[str, np.ndarray]:
        """Map each distinct value to the ascending row positions where it occurs."""
        return {str(value): np.flatnonzero(values == value) for value in np.unique(values)}

    def _build_lookup_indices(self):
        """Build dictionaries for fast filtering by user/session/label."""
        self.user_to_indices: Dict[str, np.ndarray] = self._group_positions(self.user_ids)
        self.session_to_indices: Dict[str, np.ndarray] = self._group_positions(self.session_ids)
        self.label_to_indices: Dict[str, np.ndarray] = self._group_positions(self.labels)

    def _metadata(self, global_idx: int) -> Dict[str, Any]:
        """Return the identifying fields of one row as plain Python values."""
        return {
            "user_id": str(self.user_ids[global_idx]),
            "session_id": str(self.session_ids[global_idx]),
            "idx": int(self.indices[global_idx]),
            "label": str(self.labels[global_idx]),
        }

    def get_embedding_by_key(
        self,
        user_id: str,
        session_id: str,
        idx: int
    ) -> Optional[np.ndarray]:
        """
        Get embedding for a specific sample identified by (user_id, session_id, idx).

        Args:
            user_id: User identifier (compared after ``str()`` conversion)
            session_id: Session identifier (compared after ``str()`` conversion)
            idx: Value to match against the stored "idx" column

        Returns:
            The raw (un-normalized) embedding of the first matching row,
            or None if not found
        """
        mask = (
            (self.user_ids == str(user_id)) &
            (self.session_ids == str(session_id)) &
            (self.indices == idx)
        )
        matches = np.flatnonzero(mask)

        if matches.size == 0:
            return None
        return self.embeddings[matches[0]]

    def cosine_similarity(
        self,
        query: np.ndarray,
        candidates: np.ndarray
    ) -> np.ndarray:
        """
        Compute cosine similarity between query and candidate vectors.

        Both sides are normalized here, so the candidates do not need to be
        unit-length already. All-zero candidates score 0.

        Args:
            query: Query vector of shape (embedding_dim,); a (1, embedding_dim)
                row is flattened
            candidates: Candidate matrix of shape (num_candidates, embedding_dim)

        Returns:
            Similarity scores of shape (num_candidates,)
        """
        return self._normalize(np.asarray(candidates)) @ self._unit_query(query)

    def find_similar(
        self,
        query_embedding: np.ndarray,
        k: int = 5,
        exclude_user: Optional[str] = None,
        include_user: Optional[str] = None,
        filter_session: Optional[str] = None,
        filter_label: Optional[str] = None,
    ) -> List[Tuple[int, float, Dict[str, Any]]]:
        """
        Find k most similar samples to the query embedding.

        Args:
            query_embedding: Query vector of shape (embedding_dim,)
            k: Number of nearest neighbors to return; values <= 0 yield []
            exclude_user: User ID to exclude from search (for out_similar)
            include_user: User ID to include only (for in_similar)
            filter_session: Only search in this session (e.g., "1" for train set)
            filter_label: Only return samples with this label

        Returns:
            List of (index, similarity, metadata) tuples sorted by similarity (descending)
            index is the row position inside this index (a plain int)
            metadata contains: user_id, session_id, idx, label
        """
        # A negative k would otherwise slice "[:k]" and silently return
        # everything except the last |k| rows.
        if k <= 0:
            return []

        # Build candidate mask based on filters
        candidate_mask = np.ones(len(self.embeddings), dtype=bool)

        if exclude_user is not None:
            candidate_mask &= (self.user_ids != str(exclude_user))

        if include_user is not None:
            candidate_mask &= (self.user_ids == str(include_user))

        if filter_session is not None:
            candidate_mask &= (self.session_ids == str(filter_session))

        if filter_label is not None:
            candidate_mask &= (self.labels == str(filter_label))

        candidate_indices = np.flatnonzero(candidate_mask)

        if candidate_indices.size == 0:
            return []

        # Rows are already unit-length, so a dot product with the unit query
        # is the cosine similarity; no need to re-normalize the candidates.
        similarities = self.embeddings_normalized[candidate_indices] @ self._unit_query(query_embedding)

        # Get top-k (highest similarity first)
        k = min(k, len(similarities))
        top_k_local = np.argsort(similarities)[::-1][:k]

        results = []
        for local_idx in top_k_local:
            # int() so callers get a plain Python int rather than np.int64.
            global_idx = int(candidate_indices[local_idx])
            results.append(
                (global_idx, float(similarities[local_idx]), self._metadata(global_idx))
            )

        return results

    def find_similar_per_class(
        self,
        query_embedding: np.ndarray,
        classes: List[str],
        k_per_class: int = 1,
        exclude_user: Optional[str] = None,
        include_user: Optional[str] = None,
        filter_session: Optional[str] = "1",
    ) -> Dict[str, List[Tuple[int, float, Dict[str, Any]]]]:
        """
        Find k most similar samples for each class.

        Runs one ``find_similar`` search per entry of ``classes`` so every
        requested class gets its own result list (possibly empty).

        Args:
            query_embedding: Query vector
            classes: List of class labels to search
            k_per_class: Number of examples per class
            exclude_user: User to exclude (out_similar mode)
            include_user: User to include only (in_similar mode)
            filter_session: Session to search in (default: "1" = train set)

        Returns:
            Dictionary mapping class label to list of similar samples
        """
        return {
            cls: self.find_similar(
                query_embedding=query_embedding,
                k=k_per_class,
                exclude_user=exclude_user,
                include_user=include_user,
                filter_session=filter_session,
                filter_label=cls,
            )
            for cls in classes
        }

    def get_sample_metadata(self, global_idx: int) -> Dict[str, Any]:
        """
        Get metadata for a sample by its global index.

        Returns:
            Dict with user_id, session_id, idx, label and the raw
            (un-normalized) embedding of that row.
        """
        return {**self._metadata(global_idx), "embedding": self.embeddings[global_idx]}
|
||||
|
||||
|
||||
def create_embedding_index(embedding_path: Optional[str]) -> Optional["EmbeddingIndex"]:
    """
    Factory function to create an EmbeddingIndex, returning None on failure.

    This is the tolerant entry point: instead of raising, it prints a
    message and returns None when the path is unset, is not a directory,
    or when loading the embeddings fails.

    Args:
        embedding_path: Path to embeddings directory (None or "" is treated
            as "not found")

    Returns:
        EmbeddingIndex instance or None
    """
    # os.path.isdir(None) raises TypeError, so test for an unset path first.
    if not embedding_path or not os.path.isdir(embedding_path):
        print(f"[WARNING] Embedding path not found: {embedding_path}")
        return None

    try:
        return EmbeddingIndex(embedding_path)
    except Exception as e:
        # Deliberately broad: any load failure is reported and turned into
        # None so the caller can carry on without an index.
        print(f"[ERROR] Failed to load embeddings: {e}")
        return None
|
||||
Reference in New Issue
Block a user