Added SBERT metadata embedding extraction and visualization for PPG-BP dataset
This commit is contained in:
committed by
GitHub
parent
60ae0a5941
commit
056327731e
748
analysis/user_similarity/sbert_metadata_ppgbp/gen_plot.py
Normal file
748
analysis/user_similarity/sbert_metadata_ppgbp/gen_plot.py
Normal file
@@ -0,0 +1,748 @@
|
||||
"""
|
||||
SBERT Metadata-Based Embedding Extraction and Visualization Pipeline
|
||||
|
||||
This module provides functionality to:
|
||||
1. Extract embeddings from user metadata (sex, age, height, weight, blood pressure, heart rate, BMI, hypertension) using SBERT (Sentence-BERT)
|
||||
2. Visualize embeddings using dimensionality reduction (t-SNE) colored by age
|
||||
|
||||
SBERT Overview:
|
||||
SBERT (Sentence-BERT) is a modification of the BERT model that uses siamese
|
||||
and triplet network structures to derive semantically meaningful sentence
|
||||
embeddings. It's designed for semantic similarity tasks and produces
|
||||
fixed-size dense vector representations of text.
|
||||
|
||||
Key Features:
|
||||
- Semantic understanding: Captures meaning rather than just word presence
|
||||
- Fixed-size embeddings: Outputs consistent vector dimensions (384 for all-MiniLM-L6-v2)
|
||||
- Efficient: Optimized for sentence-level tasks
|
||||
|
||||
Embedding Strategy:
|
||||
Instead of using time series features, we create textual descriptions based on
|
||||
user metadata (demographics and vital signs). This approach allows us to capture user-level
|
||||
characteristics in the embedding space.
|
||||
|
||||
Processing Pipeline:
|
||||
1. Load user metadata from the XLSX file (sex, age, height, weight, SBP, DBP, heart rate, BMI, hypertension)
|
||||
2. Textualization: Convert metadata to natural language description
|
||||
3. SBERT encoding: Generate 384-dimensional embeddings
|
||||
4. Visualization: t-SNE with continuous age coloring
|
||||
|
||||
Usage:
|
||||
# Extract embeddings from metadata for all labels
|
||||
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/all_labels
|
||||
|
||||
# Extract embeddings for a single label
|
||||
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/REM --label REM
|
||||
|
||||
# Visualize with t-SNE from all label directories (colored by age)
|
||||
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels
|
||||
|
||||
# Visualize from a single label directory
|
||||
python gen_plot.py plot --emb_dir ./embeddings/REM --out_dir ./plots/REM
|
||||
|
||||
Author: Sumin Im (NMSL Undergraduate Researcher)
|
||||
Date: 2026-01-09
|
||||
"""
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
from typing import Dict, Any, List, Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from datasets import load_from_disk, Dataset, concatenate_datasets
|
||||
from fire import Fire
|
||||
from sklearn.manifold import TSNE
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# =============================================================================
|
||||
# Constants
|
||||
# =============================================================================
|
||||
|
||||
# Default location of the subject metadata spreadsheet (PPGBP_metadata.xlsx).
# This is a relative path (no leading "/"); it is used as the default value of
# the `subject_path` parameters below, so callers can override it.
SUBJECT_PATH = "LLM_Health/tsllm_personalization_icl/analysis/user_similarity/sbert_metadata_ppgbp/PPGBP_metadata.xlsx"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Metadata Loading
|
||||
# =============================================================================
|
||||
|
||||
def _optional_cell(row: Any, column: str, cast: Any) -> Any:
    """Return ``cast(row[column])``, or ``None`` when the cell is empty (NaN)."""
    value = row[column]
    return cast(value) if pd.notna(value) else None


def _strip_text(value: Any) -> str:
    """Render a cell as text without surrounding whitespace."""
    return str(value).strip()


def _normalize_subject_id(raw_id: Any) -> str:
    """Render a subject-ID cell as text, dropping a spurious ``.0`` suffix."""
    # A numeric ID that pandas delivered as a float (e.g. 2.0) must become "2",
    # because consumers look subjects up with ``str(int(user_id))``.
    if isinstance(raw_id, float) and raw_id.is_integer():
        raw_id = int(raw_id)
    return str(raw_id).strip()


def load_subject_metadata(subject_path: str = SUBJECT_PATH) -> Dict[str, Dict[str, Any]]:
    """
    Load per-subject metadata from the metadata spreadsheet.

    The sheet is read with ``header=1`` (the second spreadsheet row holds the
    column names). The following columns are read by name:
    ``subject_ID``, ``Sex(M/F)``, ``Age(year)``, ``Height(cm)``,
    ``Weight(kg)``, ``Systolic Blood Pressure(mmHg)``,
    ``Diastolic Blood Pressure(mmHg)``, ``Heart Rate(b/m)``, ``BMI(kg/m^2)``
    and ``Hypertension``.

    Args:
        subject_path: Path to the Excel file with the subject metadata.

    Returns:
        Dictionary mapping the subject ID (as a string) to a metadata
        dictionary with the keys ``sex`` (str), ``age``, ``height``,
        ``weight``, ``sbp``, ``dbp``, ``heart_rate`` (int), ``bmi`` (float)
        and ``hypertension`` (str). Empty cells are stored as ``None``.
        ``hr`` is kept as an alias of ``heart_rate`` for backward
        compatibility.

    Raises:
        KeyError: If one of the expected columns is missing from the sheet.
    """
    df = pd.read_excel(subject_path, header=1)
    subject_info: Dict[str, Dict[str, Any]] = {}

    for _, row in df.iterrows():
        # Rows without an ID cannot be looked up later; previously they were
        # all collapsed onto the meaningless key "nan".
        if pd.isna(row["subject_ID"]):
            continue

        subject_id = _normalize_subject_id(row["subject_ID"])
        heart_rate = _optional_cell(row, "Heart Rate(b/m)", int)

        subject_info[subject_id] = {
            "sex": _optional_cell(row, "Sex(M/F)", _strip_text),
            "age": _optional_cell(row, "Age(year)", int),
            "height": _optional_cell(row, "Height(cm)", int),
            "weight": _optional_cell(row, "Weight(kg)", int),
            "sbp": _optional_cell(row, "Systolic Blood Pressure(mmHg)", int),
            "dbp": _optional_cell(row, "Diastolic Blood Pressure(mmHg)", int),
            # The embedding extractor reads "heart_rate"; storing the value
            # only under "hr" made every lookup fail with KeyError.
            "heart_rate": heart_rate,
            "hr": heart_rate,
            "bmi": _optional_cell(row, "BMI(kg/m^2)", float),
            "hypertension": _optional_cell(row, "Hypertension", str),
        }

    return subject_info
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Embedding Extractor Class
|
||||
# =============================================================================
|
||||
|
||||
class SBERT_Metadata:
    """
    Extract fixed-size SBERT embeddings from textualized subject metadata.

    Pipeline:
        subject metadata -> natural-language sentence -> SBERT encoder -> vector

    The sentence covers sex, age, height, weight, systolic/diastolic blood
    pressure, heart rate, BMI and hypertension status. Every sample of a
    subject receives the embedding of that subject's sentence.

    Model: "all-MiniLM-L6-v2" (documented by its authors as producing
    384-dimensional sentence embeddings).

    Attributes:
        model: SentenceTransformer instance used to encode the sentences.
    """

    # Spellings accepted for each sex, compared case-insensitively. Both the
    # numeric codes and the textual spreadsheet values are supported.
    _FEMALE_TOKENS = frozenset({"1", "1.0", "f", "female"})
    _MALE_TOKENS = frozenset({"0", "0.0", "m", "male"})

    # Files that mark a directory as a dataset written by `save_to_disk`.
    _DATASET_MARKERS = ("dataset_info.json", "state.json")

    def __init__(self):
        """Load the pre-trained "all-MiniLM-L6-v2" sentence encoder."""
        self.model = SentenceTransformer("all-MiniLM-L6-v2")

    # ------------------------------------------------------------------
    # Textualization
    # ------------------------------------------------------------------

    @classmethod
    def _sex_to_text(cls, sex: Any) -> str:
        """Map a sex value (code or text) to "Female", "Male" or "Unknown"."""
        if sex is None:
            return "Unknown"
        token = str(sex).strip().lower()
        if token in cls._FEMALE_TOKENS:
            return "Female"
        if token in cls._MALE_TOKENS:
            return "Male"
        return "Unknown"

    @staticmethod
    def _value_to_text(value: Any, unit: str = "") -> str:
        """Format a value with an optional unit; ``None`` becomes "unknown"."""
        if value is None:
            return "unknown"
        return f"{value} {unit}" if unit else f"{value}"

    def textualize_metadata_ppg_bp(self,
                                   sex: Optional[str],
                                   age: Optional[int],
                                   height: Optional[int],
                                   weight: Optional[int],
                                   sbp: Optional[int],
                                   dbp: Optional[int],
                                   heart_rate: Optional[int],
                                   bmi: Optional[float],
                                   hypertension: Optional[str]) -> str:
        """
        Convert one subject's metadata into a natural-language sentence.

        SBERT expects text input, so every field is rendered as
        ``"<name>: <value> <unit>"``; missing values are rendered as
        "unknown" ("Unknown" for sex).

        Args:
            sex: Sex of the user. Accepts 1/"F"/"Female" and 0/"M"/"Male"
                (case-insensitive); anything else is rendered as "Unknown".
            age: Age in years (may be None).
            height: Height in centimeters (may be None).
            weight: Weight in kilograms (may be None).
            sbp: Systolic blood pressure in mmHg (may be None).
            dbp: Diastolic blood pressure in mmHg (may be None).
            heart_rate: Heart rate in beats per minute (may be None).
            bmi: Body mass index in kg/m^2 (may be None).
            hypertension: Hypertension status, emitted verbatim (may be None).

        Returns:
            Natural-language string describing the user metadata.
        """
        sex_text = self._sex_to_text(sex)
        age_text = self._value_to_text(age)
        height_text = self._value_to_text(height, "cm")
        weight_text = self._value_to_text(weight, "kg")
        sbp_text = self._value_to_text(sbp, "mmHg")
        dbp_text = self._value_to_text(dbp, "mmHg")
        heart_rate_text = self._value_to_text(heart_rate, "bpm")
        bmi_text = self._value_to_text(bmi, "kg/m^2")
        hypertension_text = self._value_to_text(hypertension)

        # The unit is already part of heart_rate_text; the template must not
        # append a second "bpm".
        return (
            "This is the information of the user, "
            f"sex: {sex_text}, age: {age_text}, height: {height_text}, "
            f"weight: {weight_text}, sbp: {sbp_text}, dbp: {dbp_text}, "
            f"heart rate: {heart_rate_text}, bmi: {bmi_text}, "
            f"hypertension: {hypertension_text}."
        )

    # ------------------------------------------------------------------
    # Embedding
    # ------------------------------------------------------------------

    def compute_embedding_from_metadata(self,
                                        sexes: List[Optional[Any]],
                                        ages: List[Optional[int]],
                                        heights: List[Optional[float]],
                                        weights: List[Optional[float]],
                                        systolics: List[Optional[float]],
                                        diastolics: List[Optional[float]],
                                        heart_rates: List[Optional[float]],
                                        bmis: List[Optional[float]],
                                        hypertensions: List[Optional[Any]]) -> np.ndarray:
        """
        Encode a batch of metadata records with SBERT.

        Each position across the nine parallel lists describes one sample; the
        record is textualized and the resulting sentences are encoded together.

        Args:
            sexes: Sex values (codes or text, may contain None).
            ages: Age values (may contain None).
            heights: Height values in cm (may contain None).
            weights: Weight values in kg (may contain None).
            systolics: Systolic blood pressure values (may contain None).
            diastolics: Diastolic blood pressure values (may contain None).
            heart_rates: Heart rate values (may contain None).
            bmis: Body mass index values (may contain None).
            hypertensions: Hypertension status values (may contain None).

        Returns:
            The result of ``self.model.encode`` for the batch, one embedding
            row per sample.

        Raises:
            ValueError: If the input lists do not all have the same length
                (``zip`` would otherwise silently drop samples).
        """
        columns = (sexes, ages, heights, weights, systolics,
                   diastolics, heart_rates, bmis, hypertensions)
        if len({len(column) for column in columns}) > 1:
            raise ValueError("All metadata lists must have the same length")

        text_samples = [
            self.textualize_metadata_ppg_bp(*record) for record in zip(*columns)
        ]
        return self.model.encode(text_samples)

    @classmethod
    def _is_dataset_dir(cls, path: str) -> bool:
        """Return True if `path` is a directory containing a dataset marker file."""
        return os.path.isdir(path) and any(
            os.path.exists(os.path.join(path, marker))
            for marker in cls._DATASET_MARKERS
        )

    @classmethod
    def _discover_session_paths(cls, data_root: str) -> List[Tuple[str, str, str]]:
        """
        Find saved datasets below `data_root`.

        NOTE(review): the layout is assumed to be ``data_root/<user_id>`` (the
        user directory is itself a dataset) or
        ``data_root/<user_id>/<session_id>`` -- confirm against the export.

        Returns:
            Sorted list of ``(user_id, session_id, dataset_path)`` tuples;
            ``session_id`` is "-" when the user directory itself is the dataset.
        """
        found: List[Tuple[str, str, str]] = []
        for user_dir in sorted(glob(os.path.join(data_root, "*"))):
            if not os.path.isdir(user_dir):
                continue
            user_id = os.path.basename(user_dir)
            if cls._is_dataset_dir(user_dir):
                found.append((user_id, "-", user_dir))
                continue
            for session_dir in sorted(glob(os.path.join(user_dir, "*"))):
                if cls._is_dataset_dir(session_dir):
                    found.append((user_id, os.path.basename(session_dir), session_dir))
        return found

    @staticmethod
    def _metadata_key(user_id: Any) -> str:
        """Normalize a user ID to the key format of the metadata table."""
        text = str(user_id).strip()
        try:
            # "007" -> "7", matching IDs that were read as integers.
            return str(int(text))
        except ValueError:
            return text

    def extract_embeddings(
        self,
        data_root: str,
        subject_path: str = SUBJECT_PATH,
        batch_size: int = 32,
        label: Optional[str] = None
    ) -> Dataset:
        """
        Build a dataset of metadata embeddings for every stored sample.

        For each dataset found under `data_root`, the samples are loaded
        (optionally restricted to one label), paired with the metadata of the
        owning user, and embedded in batches.

        Args:
            data_root: Root directory containing the per-user data folders.
            subject_path: Path to the Excel file with the subject metadata.
            batch_size: Number of samples encoded per SBERT call (must be > 0).
            label: Keep only samples whose "label" column equals this value.
                If None or "all", all samples are kept.

        Returns:
            HuggingFace Dataset with the columns user_id, idx, label, sex,
            age, height, weight, systolic, diastolic, heart_rate, bmi,
            hypertension and embedding.

        Raises:
            ValueError: If `batch_size` is not positive or no dataset
                directory is found under `data_root`.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        print(f"[INFO] Loading subject metadata from: {subject_path}")
        subject_metadata = load_subject_metadata(subject_path)
        print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")

        session_paths = self._discover_session_paths(data_root)
        if not session_paths:
            raise ValueError(f"No HuggingFace dataset directories found under: {data_root}")

        all_user_ids: List[str] = []
        all_idxs: List[int] = []
        all_labels: List[str] = []
        all_sexes: List[Any] = []
        all_ages: List[Any] = []
        all_heights: List[Any] = []
        all_weights: List[Any] = []
        all_systolics: List[Any] = []
        all_diastolics: List[Any] = []
        all_heart_rates: List[Any] = []
        all_bmis: List[Any] = []
        all_hypertensions: List[Any] = []

        # Collect identifiers and metadata for all samples first.
        for user_id, session_id, session_path in session_paths:
            dataset = load_from_disk(session_path)
            dataset = dataset.shuffle(seed=0)  # fixed seed: reproducible order
            if label is not None and label != "all":
                dataset = dataset.filter(lambda x: x["label"] == label)

            num_samples = len(dataset)
            if num_samples == 0:
                continue

            print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")

            user_key = self._metadata_key(user_id)
            record = subject_metadata.get(user_key)
            if record is None:
                print(f"[WARN] No metadata found for user_id: {user_key}")
                record = {}

            # Read each column once instead of once per sample.
            all_user_ids.extend(str(value) for value in dataset["user_id"])
            all_idxs.extend(int(value) for value in dataset["idx"])
            all_labels.extend(str(value) for value in dataset["label"])

            # Every sample of a user shares that user's metadata.
            all_sexes.extend([record.get("sex")] * num_samples)
            all_ages.extend([record.get("age")] * num_samples)
            all_heights.extend([record.get("height")] * num_samples)
            all_weights.extend([record.get("weight")] * num_samples)
            all_systolics.extend([record.get("sbp")] * num_samples)
            all_diastolics.extend([record.get("dbp")] * num_samples)
            # Accept both spellings of the heart-rate key used by the loader.
            all_heart_rates.extend(
                [record.get("heart_rate", record.get("hr"))] * num_samples
            )
            all_bmis.extend([record.get("bmi")] * num_samples)
            all_hypertensions.extend([record.get("hypertension")] * num_samples)

        # Generate embeddings from metadata in batches.
        total = len(all_user_ids)
        print(f"[INFO] Generating embeddings from metadata for {total} samples...")
        all_embeddings: List[List[float]] = []
        for start in range(0, total, batch_size):
            stop = start + batch_size
            embeddings = self.compute_embedding_from_metadata(
                sexes=all_sexes[start:stop],
                ages=all_ages[start:stop],
                heights=all_heights[start:stop],
                weights=all_weights[start:stop],
                systolics=all_systolics[start:stop],
                diastolics=all_diastolics[start:stop],
                heart_rates=all_heart_rates[start:stop],
                bmis=all_bmis[start:stop],
                hypertensions=all_hypertensions[start:stop],
            )
            all_embeddings.extend(np.asarray(embeddings).tolist())

        result_dataset = Dataset.from_dict({
            "user_id": all_user_ids,
            "idx": all_idxs,
            "label": all_labels,
            "sex": all_sexes,
            "age": all_ages,
            "height": all_heights,
            "weight": all_weights,
            "systolic": all_systolics,
            "diastolic": all_diastolics,
            "heart_rate": all_heart_rates,
            "bmi": all_bmis,
            "hypertension": all_hypertensions,
            "embedding": all_embeddings,
        })

        print(f"[INFO] Total samples: {len(result_dataset)}")

        label_counts: Dict[str, int] = {}
        for lbl in all_labels:
            label_counts[lbl] = label_counts.get(lbl, 0) + 1
        print(f"[INFO] Label distribution: {label_counts}")

        return result_dataset

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save_embeddings(
        self,
        dataset: Dataset,
        output_dir: str
    ) -> None:
        """
        Save the embeddings dataset to disk in HuggingFace format.

        Args:
            dataset: HuggingFace Dataset containing embeddings and metadata.
            output_dir: Directory to write the dataset to (created if missing).
        """
        os.makedirs(output_dir, exist_ok=True)
        dataset.save_to_disk(output_dir)

        # An empty dataset has no first row to measure.
        embedding_dim = len(dataset[0]["embedding"]) if len(dataset) > 0 else 0

        print(f"[DONE] Saved embeddings dataset: {output_dir}")
        print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {embedding_dim}")

    @staticmethod
    def load_embeddings(embedding_dir: str) -> Dataset:
        """
        Load a saved embeddings dataset from disk.

        Args:
            embedding_dir: Directory containing the saved HuggingFace dataset.

        Returns:
            HuggingFace Dataset with embeddings and metadata.
        """
        dataset = load_from_disk(embedding_dir)
        print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
        return dataset
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Data Loading Utilities
|
||||
# =============================================================================
|
||||
|
||||
def load_embeddings_from_all_labels(embeddings_root: str) -> Dataset:
    """
    Merge the embedding datasets stored in the subdirectories of a root folder.

    A subdirectory counts as a dataset when it contains a
    ``dataset_info.json`` file. The matching subdirectories are loaded in
    sorted name order and concatenated.

    Args:
        embeddings_root: Directory whose immediate subdirectories each hold
            one saved HuggingFace dataset (e.g. one per label).

    Returns:
        A single HuggingFace Dataset containing the rows of every
        discovered subdirectory.

    Raises:
        ValueError: If no subdirectory contains ``dataset_info.json``.
    """
    # Pair every entry name with its full path, keep only dataset directories.
    candidates = (
        (entry, os.path.join(embeddings_root, entry))
        for entry in os.listdir(embeddings_root)
    )
    label_dirs = sorted(
        (name, path)
        for name, path in candidates
        if os.path.isdir(path)
        and os.path.exists(os.path.join(path, "dataset_info.json"))
    )

    if not label_dirs:
        raise ValueError(
            f"No valid HuggingFace dataset directories found in: {embeddings_root}"
        )

    discovered_names = [name for name, _ in label_dirs]
    print(f"[INFO] Discovered {len(label_dirs)} label directories: {discovered_names}")

    # Load each directory in turn, reporting its size.
    loaded = []
    for label_name, label_path in label_dirs:
        print(f"[INFO] Loading embeddings from: {label_path}")
        part = load_from_disk(label_path)
        print(f"[INFO] Label: {label_name}, Samples: {len(part)}")
        loaded.append(part)

    # A single dataset is returned as-is; several are concatenated.
    combined_dataset = loaded[0] if len(loaded) == 1 else concatenate_datasets(loaded)

    print(f"[INFO] Combined dataset: {len(combined_dataset)} total samples")

    # Report how many rows carry each label value.
    if "label" in combined_dataset.column_names:
        histogram = {}
        for value in combined_dataset["label"]:
            histogram.setdefault(value, 0)
            histogram[value] += 1
        print(f"[INFO] Label distribution: {histogram}")

    return combined_dataset
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Visualization Functions
|
||||
# =============================================================================
|
||||
|
||||
def reduce_to_2d_tsne(
    embeddings: np.ndarray,
    perplexity: float = 30.0
) -> np.ndarray:
    """
    Reduce high-dimensional embeddings to 2D using t-SNE.

    t-SNE (t-distributed Stochastic Neighbor Embedding) is a nonlinear
    dimensionality reduction technique that preserves local structure.

    Args:
        embeddings: Array of shape (num_samples, embedding_dim).
        perplexity: Requested t-SNE perplexity. scikit-learn requires it to
            be smaller than the number of samples, so it is lowered to
            ``num_samples - 1`` when the request is too large.

    Returns:
        2D coordinates of shape (num_samples, 2).

    Raises:
        ValueError: If `embeddings` is not a 2D array with at least 2 rows.
    """
    embeddings = np.asarray(embeddings)
    if embeddings.ndim != 2 or embeddings.shape[0] < 2:
        raise ValueError(
            "t-SNE needs a 2D array with at least 2 samples, "
            f"got shape {embeddings.shape}"
        )

    num_samples = embeddings.shape[0]
    if perplexity >= num_samples:
        # Small selections (few users/labels) would otherwise crash inside
        # scikit-learn with "perplexity must be less than n_samples".
        clamped = float(num_samples - 1)
        print(
            f"[WARN] perplexity={perplexity} is too large for "
            f"{num_samples} samples; using {clamped} instead"
        )
        perplexity = clamped

    print(f"[INFO] Running t-SNE with perplexity={perplexity}...")

    tsne = TSNE(
        n_components=2,
        random_state=0,  # For reproducibility
        perplexity=perplexity,
        max_iter=1000,
        init='random',
        learning_rate='auto',  # Let sklearn choose the learning rate
    )

    return tsne.fit_transform(embeddings)
|
||||
|
||||
|
||||
def create_scatter_plot_by_age(
    coordinates: np.ndarray,
    ages: np.ndarray,
    title: str,
    output_path: str
) -> None:
    """
    Draw a 2D scatter plot whose point colors encode age, and save it as PDF.

    Points without an age (``None``/NaN) are left out. If no point has an
    age, a warning is printed and nothing is written.

    Args:
        coordinates: 2D array of shape (num_points, 2).
        ages: One age per point; entries may be None.
        title: Plot title.
        output_path: File path the PDF is written to.
    """
    # None cannot live in a float array, so it is represented as NaN.
    age_values = np.array([np.nan if a is None else float(a) for a in ages])
    has_age = np.logical_not(np.isnan(age_values))
    valid_ages = age_values[has_age]

    if len(valid_ages) == 0:
        print(f"[WARN] No valid age data found. Skipping plot: {output_path}")
        return

    valid_coords = coordinates[has_age]
    xs, ys = valid_coords[:, 0], valid_coords[:, 1]

    fig, ax = plt.subplots(figsize=(10, 8))

    # Continuous colormap: age is mapped onto the viridis gradient.
    points = ax.scatter(
        xs,
        ys,
        c=valid_ages,
        cmap='viridis',
        s=15,
        alpha=0.7,
        edgecolors='none',
    )

    colorbar = fig.colorbar(points, ax=ax)
    colorbar.set_label('Age (years)', rotation=270, labelpad=20)

    ax.set_title(title)
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")

    # Write the figure as a vector PDF, then release it.
    fig.tight_layout()
    fig.savefig(output_path, format='pdf', bbox_inches='tight')
    plt.close(fig)

    youngest, oldest = valid_ages.min(), valid_ages.max()
    print(f"[DONE] Saved plot: {output_path}")
    print(f"[INFO] Age range: {youngest:.0f} - {oldest:.0f} years")
    print(f"[INFO] Points with valid age: {len(valid_ages)}/{len(ages)}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Command Line Interface
|
||||
# =============================================================================
|
||||
|
||||
def _split_csv_arg(value: Any) -> List[str]:
    """Normalize a comma-separated CLI argument to a list of non-empty strings."""
    # Fire converts "1,2" into a tuple and "5" into an int before the command
    # sees it, so str, sequence and scalar inputs must all be accepted.
    if value is None:
        return []
    if isinstance(value, str):
        parts = value.split(",")
    elif isinstance(value, (list, tuple, set)):
        parts = list(value)
    else:
        parts = [value]
    cleaned = (str(part).strip() for part in parts)
    return [part for part in cleaned if part]


class CLI:
    """
    Command-line interface for SBERT metadata-based embedding extraction and visualization.

    Provides two main commands:
    - extract: Generate embeddings from user metadata
    - plot: Visualize embeddings with t-SNE colored by age
    """

    def extract(
        self,
        data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_full",
        subject_path: str = SUBJECT_PATH,
        out_dir: str = "./embeddings/all_labels",
        batch_size: int = 32,
        label: Optional[str] = None
    ) -> None:
        """
        Extract embeddings from user metadata and save them to `out_dir`.

        Args:
            data_root: Root directory containing the per-user data folders.
                NOTE(review): the default points at a "SleepEDF_full"
                directory although this module targets PPG-BP -- confirm.
            subject_path: Path to the Excel file with the subject metadata.
            out_dir: Output directory for the HuggingFace dataset.
            batch_size: Batch size for inference (default: 32).
            label: Label value to keep. If None or 'all', all samples are
                processed (default: None).
        """
        embedder = SBERT_Metadata()
        dataset = embedder.extract_embeddings(data_root, subject_path, batch_size, label)
        embedder.save_embeddings(dataset, out_dir)

    def plot(
        self,
        emb_dir: str = None,
        embeddings_root: str = "./embeddings",
        out_dir: str = "./plots/all_labels",
        perplexity: float = 10.0,
        users: str = None,
        num_users: int = 0,
        labels: str = None,
    ) -> None:
        """
        Visualize embeddings with t-SNE, colored by age.

        Writes ``tsne_by_age.pdf`` into `out_dir`.

        Args:
            emb_dir: Single directory containing a saved embeddings dataset.
                If provided, only this directory is used.
            embeddings_root: Root directory containing one dataset
                subdirectory per label. Used only if `emb_dir` is not given.
            out_dir: Output directory for the plot (created if missing).
            perplexity: t-SNE perplexity parameter (default: 10.0).
            users: Comma-separated user IDs to include (e.g. '00,01,02').
            num_users: Include only the first N user IDs in sorted order,
                0 = all (default: 0). Ignored when `users` is given.
            labels: Comma-separated label values to include. This filters the
                already-loaded data, not which directories are loaded.

        Raises:
            ValueError: If no samples remain after filtering.
        """
        os.makedirs(out_dir, exist_ok=True)

        # Load embeddings: either from a single directory or from all labels.
        if emb_dir is not None:
            print(f"[INFO] Loading embeddings from single directory: {emb_dir}")
            dataset = SBERT_Metadata.load_embeddings(emb_dir)
        else:
            print(f"[INFO] Loading embeddings from all label directories in: {embeddings_root}")
            dataset = load_embeddings_from_all_labels(embeddings_root)

        # Apply user filtering. The normalized list (not the raw argument) is
        # tested, so a user ID that Fire parsed as the integer 0 still counts.
        user_list = _split_csv_arg(users)
        if user_list:
            wanted_users = set(user_list)
            dataset = dataset.filter(lambda x: str(x["user_id"]) in wanted_users)
            print(f"[INFO] Filtered to users: {user_list}")
        elif num_users > 0:
            all_users = sorted(set(dataset["user_id"]))
            selected_users = all_users[:num_users]
            wanted_users = set(selected_users)
            dataset = dataset.filter(lambda x: x["user_id"] in wanted_users)
            print(f"[INFO] Selected first {num_users} users: {selected_users}")

        # Filter by label values.
        label_list = _split_csv_arg(labels)
        if label_list:
            wanted_labels = set(label_list)
            dataset = dataset.filter(lambda x: str(x["label"]) in wanted_labels)
            print(f"[INFO] Filtered to labels: {label_list}")

        print(f"[INFO] Total samples: {len(dataset)}")
        if len(dataset) == 0:
            raise ValueError("No samples remain after filtering; nothing to plot.")

        # Extract embeddings and ages as numpy arrays.
        embeddings = np.array(dataset["embedding"])
        ages = np.array(dataset["age"])

        # Reduce to 2D with t-SNE.
        coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)

        # Generate the visualization colored by age.
        create_scatter_plot_by_age(
            coordinates_2d,
            ages,
            "t-SNE Visualization (Colored by Age)",
            os.path.join(out_dir, "tsne_by_age.pdf")
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Expose the CLI class through Fire: its methods (`extract`, `plot`)
    # become the available sub-commands.
    Fire(component=CLI)
|
||||
Reference in New Issue
Block a user