260119:chronos_2_based_embedding_running
This commit is contained in:
281
core/embedding_index.py
Normal file
281
core/embedding_index.py
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
"""
|
||||||
|
Embedding Index for Similarity-based Example Selection
|
||||||
|
|
||||||
|
This module provides functionality to:
|
||||||
|
1. Load pre-computed Chronos-2 embeddings
|
||||||
|
2. Build an index for fast nearest neighbor search
|
||||||
|
3. Find similar examples based on embedding distance
|
||||||
|
|
||||||
|
The embedding index enables ICL (In-Context Learning) example selection
|
||||||
|
based on semantic similarity rather than random sampling.
|
||||||
|
|
||||||
|
Similarity Strategies:
|
||||||
|
- out_similar: Find similar examples from OTHER users (cross-user transfer)
|
||||||
|
- in_similar: Find similar examples from SAME user (personalization)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
index = EmbeddingIndex(embedding_path)
|
||||||
|
similar_indices = index.find_similar(query_embedding, k=5, exclude_user="01")
|
||||||
|
|
||||||
|
Author: NMSL Research Team
|
||||||
|
Date: 2026-01-16
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Tuple, Optional, Dict, Any
|
||||||
|
from datasets import load_from_disk, Dataset
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingIndex:
|
||||||
|
"""
|
||||||
|
Index for fast similarity search over pre-computed embeddings.
|
||||||
|
|
||||||
|
Uses cosine similarity for finding nearest neighbors in embedding space.
|
||||||
|
Supports filtering by user_id and session_id for controlled experiments.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
embeddings: numpy array of shape (num_samples, embedding_dim)
|
||||||
|
user_ids: list of user identifiers for each sample
|
||||||
|
session_ids: list of session identifiers for each sample
|
||||||
|
labels: list of sleep stage labels for each sample
|
||||||
|
indices: list of original indices in the dataset
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, embedding_path: str):
|
||||||
|
"""
|
||||||
|
Initialize the embedding index from a HuggingFace dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_path: Path to directory containing saved embeddings dataset
|
||||||
|
(output from gen_plot.py extract command)
|
||||||
|
"""
|
||||||
|
if not os.path.isdir(embedding_path):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Embedding directory not found: {embedding_path}. "
|
||||||
|
"Run 'python gen_plot.py extract' first to generate embeddings."
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"[EmbeddingIndex] Loading embeddings from: {embedding_path}")
|
||||||
|
self.dataset = load_from_disk(embedding_path)
|
||||||
|
|
||||||
|
# Extract arrays for fast access
|
||||||
|
self.embeddings = np.array(self.dataset["embedding"], dtype=np.float32)
|
||||||
|
self.user_ids = np.array([str(uid) for uid in self.dataset["user_id"]])
|
||||||
|
self.session_ids = np.array([str(sid) for sid in self.dataset["session_id"]])
|
||||||
|
self.labels = np.array([str(label) for label in self.dataset["label"]])
|
||||||
|
self.indices = np.array(self.dataset["idx"])
|
||||||
|
|
||||||
|
# Normalize embeddings for cosine similarity (dot product of unit vectors)
|
||||||
|
self.embeddings_normalized = self._normalize(self.embeddings)
|
||||||
|
|
||||||
|
# Build lookup indices for fast filtering
|
||||||
|
self._build_lookup_indices()
|
||||||
|
|
||||||
|
print(f"[EmbeddingIndex] Loaded {len(self.embeddings)} samples")
|
||||||
|
print(f"[EmbeddingIndex] Embedding dimension: {self.embeddings.shape[1]}")
|
||||||
|
print(f"[EmbeddingIndex] Unique users: {len(np.unique(self.user_ids))}")
|
||||||
|
print(f"[EmbeddingIndex] Unique sessions: {len(np.unique(self.session_ids))}")
|
||||||
|
|
||||||
|
def _normalize(self, vectors: np.ndarray) -> np.ndarray:
|
||||||
|
"""L2 normalize vectors for cosine similarity computation."""
|
||||||
|
norms = np.linalg.norm(vectors, axis=1, keepdims=True)
|
||||||
|
norms = np.where(norms == 0, 1, norms) # Avoid division by zero
|
||||||
|
return vectors / norms
|
||||||
|
|
||||||
|
def _build_lookup_indices(self):
|
||||||
|
"""Build dictionaries for fast filtering by user/session/label."""
|
||||||
|
self.user_to_indices: Dict[str, np.ndarray] = {}
|
||||||
|
self.session_to_indices: Dict[str, np.ndarray] = {}
|
||||||
|
self.label_to_indices: Dict[str, np.ndarray] = {}
|
||||||
|
|
||||||
|
for user_id in np.unique(self.user_ids):
|
||||||
|
self.user_to_indices[user_id] = np.where(self.user_ids == user_id)[0]
|
||||||
|
|
||||||
|
for session_id in np.unique(self.session_ids):
|
||||||
|
self.session_to_indices[session_id] = np.where(self.session_ids == session_id)[0]
|
||||||
|
|
||||||
|
for label in np.unique(self.labels):
|
||||||
|
self.label_to_indices[label] = np.where(self.labels == label)[0]
|
||||||
|
|
||||||
|
def get_embedding_by_key(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
idx: int
|
||||||
|
) -> Optional[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Get embedding for a specific sample identified by (user_id, session_id, idx).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User identifier
|
||||||
|
session_id: Session identifier (1 or 2)
|
||||||
|
idx: Sample index within the session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embedding vector or None if not found
|
||||||
|
"""
|
||||||
|
mask = (
|
||||||
|
(self.user_ids == str(user_id)) &
|
||||||
|
(self.session_ids == str(session_id)) &
|
||||||
|
(self.indices == idx)
|
||||||
|
)
|
||||||
|
matches = np.where(mask)[0]
|
||||||
|
|
||||||
|
if len(matches) == 0:
|
||||||
|
return None
|
||||||
|
return self.embeddings[matches[0]]
|
||||||
|
|
||||||
|
def cosine_similarity(
|
||||||
|
self,
|
||||||
|
query: np.ndarray,
|
||||||
|
candidates: np.ndarray
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Compute cosine similarity between query and candidate vectors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query vector of shape (embedding_dim,)
|
||||||
|
candidates: Candidate matrix of shape (num_candidates, embedding_dim)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Similarity scores of shape (num_candidates,)
|
||||||
|
"""
|
||||||
|
query_normalized = query / (np.linalg.norm(query) + 1e-8)
|
||||||
|
return candidates @ query_normalized
|
||||||
|
|
||||||
|
def find_similar(
|
||||||
|
self,
|
||||||
|
query_embedding: np.ndarray,
|
||||||
|
k: int = 5,
|
||||||
|
exclude_user: Optional[str] = None,
|
||||||
|
include_user: Optional[str] = None,
|
||||||
|
filter_session: Optional[str] = None,
|
||||||
|
filter_label: Optional[str] = None,
|
||||||
|
) -> List[Tuple[int, float, Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Find k most similar samples to the query embedding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_embedding: Query vector of shape (embedding_dim,)
|
||||||
|
k: Number of nearest neighbors to return
|
||||||
|
exclude_user: User ID to exclude from search (for out_similar)
|
||||||
|
include_user: User ID to include only (for in_similar)
|
||||||
|
filter_session: Only search in this session (e.g., "1" for train set)
|
||||||
|
filter_label: Only return samples with this label
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (index, similarity, metadata) tuples sorted by similarity (descending)
|
||||||
|
metadata contains: user_id, session_id, idx, label
|
||||||
|
"""
|
||||||
|
# Build candidate mask based on filters
|
||||||
|
candidate_mask = np.ones(len(self.embeddings), dtype=bool)
|
||||||
|
|
||||||
|
if exclude_user is not None:
|
||||||
|
candidate_mask &= (self.user_ids != str(exclude_user))
|
||||||
|
|
||||||
|
if include_user is not None:
|
||||||
|
candidate_mask &= (self.user_ids == str(include_user))
|
||||||
|
|
||||||
|
if filter_session is not None:
|
||||||
|
candidate_mask &= (self.session_ids == str(filter_session))
|
||||||
|
|
||||||
|
if filter_label is not None:
|
||||||
|
candidate_mask &= (self.labels == str(filter_label))
|
||||||
|
|
||||||
|
candidate_indices = np.where(candidate_mask)[0]
|
||||||
|
|
||||||
|
if len(candidate_indices) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Compute similarities
|
||||||
|
candidate_embeddings = self.embeddings_normalized[candidate_indices]
|
||||||
|
similarities = self.cosine_similarity(query_embedding, candidate_embeddings)
|
||||||
|
|
||||||
|
# Get top-k
|
||||||
|
k = min(k, len(similarities))
|
||||||
|
top_k_local = np.argsort(similarities)[::-1][:k]
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for local_idx in top_k_local:
|
||||||
|
global_idx = candidate_indices[local_idx]
|
||||||
|
sim = similarities[local_idx]
|
||||||
|
metadata = {
|
||||||
|
"user_id": self.user_ids[global_idx],
|
||||||
|
"session_id": self.session_ids[global_idx],
|
||||||
|
"idx": int(self.indices[global_idx]),
|
||||||
|
"label": self.labels[global_idx],
|
||||||
|
}
|
||||||
|
results.append((global_idx, float(sim), metadata))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def find_similar_per_class(
|
||||||
|
self,
|
||||||
|
query_embedding: np.ndarray,
|
||||||
|
classes: List[str],
|
||||||
|
k_per_class: int = 1,
|
||||||
|
exclude_user: Optional[str] = None,
|
||||||
|
include_user: Optional[str] = None,
|
||||||
|
filter_session: Optional[str] = "1",
|
||||||
|
) -> Dict[str, List[Tuple[int, float, Dict[str, Any]]]]:
|
||||||
|
"""
|
||||||
|
Find k most similar samples for each class.
|
||||||
|
|
||||||
|
This ensures balanced class representation in ICL examples,
|
||||||
|
similar to the original random sampling approach.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_embedding: Query vector
|
||||||
|
classes: List of class labels to search
|
||||||
|
k_per_class: Number of examples per class
|
||||||
|
exclude_user: User to exclude (out_similar mode)
|
||||||
|
include_user: User to include only (in_similar mode)
|
||||||
|
filter_session: Session to search in (default: "1" = train set)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping class label to list of similar samples
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
for cls in classes:
|
||||||
|
similar = self.find_similar(
|
||||||
|
query_embedding=query_embedding,
|
||||||
|
k=k_per_class,
|
||||||
|
exclude_user=exclude_user,
|
||||||
|
include_user=include_user,
|
||||||
|
filter_session=filter_session,
|
||||||
|
filter_label=cls,
|
||||||
|
)
|
||||||
|
results[cls] = similar
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_sample_metadata(self, global_idx: int) -> Dict[str, Any]:
|
||||||
|
"""Get metadata for a sample by its global index."""
|
||||||
|
return {
|
||||||
|
"user_id": self.user_ids[global_idx],
|
||||||
|
"session_id": self.session_ids[global_idx],
|
||||||
|
"idx": int(self.indices[global_idx]),
|
||||||
|
"label": self.labels[global_idx],
|
||||||
|
"embedding": self.embeddings[global_idx],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding_index(embedding_path: str) -> Optional[EmbeddingIndex]:
|
||||||
|
"""
|
||||||
|
Factory function to create an EmbeddingIndex, returning None if path doesn't exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_path: Path to embeddings directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingIndex instance or None
|
||||||
|
"""
|
||||||
|
if not os.path.isdir(embedding_path):
|
||||||
|
print(f"[WARNING] Embedding path not found: {embedding_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return EmbeddingIndex(embedding_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Failed to load embeddings: {e}")
|
||||||
|
return None
|
||||||
Reference in New Issue
Block a user