Added SBERT metadata embedding extraction and visualization for PPG-BP dataset

This commit is contained in:
Gautham Krishna Gudur
2026-02-03 20:02:59 -06:00
committed by GitHub
parent 60ae0a5941
commit 056327731e

View File

@@ -0,0 +1,748 @@
"""
SBERT Metadata-Based Embedding Extraction and Visualization Pipeline
This module provides functionality to:
1. Extract embeddings from user metadata (sex, age, height, weight, blood pressure, heart rate, BMI, hypertension) using SBERT (Sentence-BERT)
2. Visualize embeddings using dimensionality reduction (t-SNE) colored by age
SBERT Overview:
SBERT (Sentence-BERT) is a modification of the BERT model that uses siamese
and triplet network structures to derive semantically meaningful sentence
embeddings. It's designed for semantic similarity tasks and produces
fixed-size dense vector representations of text.
Key Features:
- Semantic understanding: Captures meaning rather than just word presence
- Fixed-size embeddings: Outputs consistent vector dimensions (384 for all-MiniLM-L6-v2)
- Efficient: Optimized for sentence-level tasks
Embedding Strategy:
Instead of using time series features, we create textual descriptions based on
user metadata (age and sex). This approach allows us to capture user-level
characteristics in the embedding space.
Processing Pipeline:
1. Load user metadata from XLS file (age, sex)
2. Textualization: Convert metadata to natural language description
3. SBERT encoding: Generate 384-dimensional embeddings
4. Visualization: t-SNE with continuous age coloring
Usage:
# Extract embeddings from metadata for all labels
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/all_labels
# Extract embeddings for a single label
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/REM --label REM
# Visualize with t-SNE from all label directories (colored by age)
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels
# Visualize from a single label directory
python gen_plot.py plot --emb_dir ./embeddings/REM --out_dir ./plots/REM
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from glob import glob
from typing import Dict, Any, List, Tuple, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_from_disk, Dataset, concatenate_datasets
from fire import Fire
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
# =============================================================================
# Constants
# =============================================================================
# Default location of the PPG-BP metadata spreadsheet. This is a relative path;
# it is used as the default `subject_path` argument throughout this module.
SUBJECT_PATH = "LLM_Health/tsllm_personalization_icl/analysis/user_similarity/sbert_metadata_ppgbp/PPGBP_metadata.xlsx"
# =============================================================================
# Metadata Loading
# =============================================================================
def load_subject_metadata(subject_path: str = SUBJECT_PATH) -> Dict[str, Dict[str, Any]]:
    """
    Load per-subject metadata from the PPG-BP Excel sheet.

    The sheet is read with ``header=1`` (the second spreadsheet row holds the
    column names). The following columns are read:
        - subject_ID
        - Sex(M/F)
        - Age(year)
        - Height(cm)
        - Weight(kg)
        - Systolic Blood Pressure(mmHg)
        - Diastolic Blood Pressure(mmHg)
        - Heart Rate(b/m)
        - BMI(kg/m^2)
        - Hypertension

    Empty cells are stored as None. Rows without a subject ID are skipped.

    Args:
        subject_path: Path to the metadata Excel file.

    Returns:
        Dictionary mapping subject ID strings to metadata dictionaries with the
        keys "sex" (str), "age", "height", "weight", "sbp", "dbp", "hr" (int),
        "bmi" (float) and "hypertension" (str). The heart rate is additionally
        exposed as "heart_rate", the key that the embedding extractor reads;
        "hr" is kept for backward compatibility.
    """
    df = pd.read_excel(subject_path, header=1)

    def cell(row, column: str, cast):
        # Convert a cell with `cast`, mapping missing values (NaN) to None.
        value = row[column]
        return cast(value) if pd.notna(value) else None

    subject_info: Dict[str, Dict[str, Any]] = {}
    for _, row in df.iterrows():
        raw_id = row["subject_ID"]
        if pd.isna(raw_id):
            # A row without an ID cannot be looked up later; previously it was
            # stored under the meaningless key "nan".
            continue
        # If pandas read the ID column as float (e.g. because of blank cells),
        # 12 becomes 12.0; normalise so the key is "12" rather than "12.0".
        if isinstance(raw_id, float) and raw_id.is_integer():
            raw_id = int(raw_id)
        subject_id = str(raw_id).strip()

        heart_rate = cell(row, "Heart Rate(b/m)", int)
        subject_info[subject_id] = {
            "sex": cell(row, "Sex(M/F)", lambda v: str(v).strip()),
            "age": cell(row, "Age(year)", int),
            "height": cell(row, "Height(cm)", int),
            "weight": cell(row, "Weight(kg)", int),
            "sbp": cell(row, "Systolic Blood Pressure(mmHg)", int),
            "dbp": cell(row, "Diastolic Blood Pressure(mmHg)", int),
            "hr": heart_rate,
            # Alias: extract_embeddings() looks the value up as "heart_rate".
            "heart_rate": heart_rate,
            "bmi": cell(row, "BMI(kg/m^2)", float),
            "hypertension": cell(row, "Hypertension", str),
        }
    return subject_info
# =============================================================================
# Embedding Extractor Class
# =============================================================================
class SBERT_Metadata:
    """
    Extracts fixed-dimensional embeddings from user metadata using SBERT.

    Uses Sentence-BERT to convert textualized metadata descriptions into dense
    vector representations. The textualization step turns sex, age, height,
    weight, sbp, dbp, heart rate, bmi and hypertension information into one
    natural-language sentence, which SBERT then encodes.

    Architecture:
        User Metadata -> Textualization -> SBERT Encoder -> Embedding

    Model: all-MiniLM-L6-v2

    Attributes:
        model: SentenceTransformer instance used to encode the sentences.
    """

    def __init__(self):
        """Load the pre-trained "all-MiniLM-L6-v2" SentenceTransformer model."""
        self.model = SentenceTransformer("all-MiniLM-L6-v2")

    @staticmethod
    def _sex_to_text(sex: Optional[Any]) -> str:
        """Map a sex value (numeric code or string) to "Female"/"Male"/"Unknown"."""
        if sex is None:
            return "Unknown"
        if isinstance(sex, str):
            # load_subject_metadata() stores the "Sex(M/F)" cell as a string, so
            # comparing against the integer 1 (as the code used to) never matched.
            token = sex.strip().lower()
            if token in ("f", "female", "1"):
                return "Female"
            if token in ("m", "male", "0"):
                return "Male"
            return "Unknown"
        if sex == 1:
            return "Female"
        if sex == 0:
            return "Male"
        return "Unknown"

    @staticmethod
    def _metadata_key(user_id: Any) -> str:
        """Turn a user ID into the string form used as a metadata key."""
        text = str(user_id).strip()
        try:
            # Numeric IDs are normalised ("007" -> "7").
            return str(int(text))
        except ValueError:
            # Non-numeric IDs are used verbatim instead of raising.
            return text

    @staticmethod
    def _discover_session_paths(data_root: str) -> List[Tuple[str, str]]:
        """
        Find saved HuggingFace datasets below ``data_root``.

        A directory counts as a dataset when it contains "dataset_info.json".
        Both ``data_root/<user_id>/`` and ``data_root/<user_id>/<session>/``
        are accepted.

        NOTE(review): the previous code iterated over an undefined
        ``session_paths`` variable, so the intended on-disk layout is not
        visible in this file. The layout handled here is an assumption;
        confirm it against the data preparation script.

        Returns:
            Sorted list of (user_id, dataset_path) tuples, where user_id is the
            name of the first-level directory.
        """
        marker = "dataset_info.json"
        found: List[Tuple[str, str]] = []
        for user_dir in sorted(glob(os.path.join(data_root, "*"))):
            if not os.path.isdir(user_dir):
                continue
            user_id = os.path.basename(user_dir)
            if os.path.exists(os.path.join(user_dir, marker)):
                found.append((user_id, user_dir))
                continue
            for session_dir in sorted(glob(os.path.join(user_dir, "*"))):
                if os.path.isdir(session_dir) and os.path.exists(
                    os.path.join(session_dir, marker)
                ):
                    found.append((user_id, session_dir))
        return found

    def textualize_metadata_ppg_bp(self,
                                   sex: Optional[Any],
                                   age: Optional[int],
                                   height: Optional[int],
                                   weight: Optional[int],
                                   sbp: Optional[int],
                                   dbp: Optional[int],
                                   heart_rate: Optional[int],
                                   bmi: Optional[float],
                                   hypertension: Optional[Any]) -> str:
        """
        Convert one user's metadata into a natural language description.

        SBERT expects text input, so each field is rendered as
        "<name>: <value> <unit>"; missing fields are rendered as "unknown"
        ("Unknown" for sex).

        Args:
            sex: Sex of the user, either a string ("F"/"Female"/"M"/"Male",
                case-insensitive) or a numeric code (1 = Female, 0 = Male).
                May be None.
            age: Age of the user (may be None).
            height: Height in centimeters (may be None).
            weight: Weight in kilograms (may be None).
            sbp: Systolic blood pressure in mmHg (may be None).
            dbp: Diastolic blood pressure in mmHg (may be None).
            heart_rate: Heart rate in beats per minute (may be None).
            bmi: Body mass index in kg/m^2 (may be None).
            hypertension: Hypertension status; inserted verbatim (may be None).

        Returns:
            Natural language string describing the user metadata.
        """
        def describe(value: Optional[Any], unit: str = "") -> str:
            # Render "value unit", or "unknown" when the value is missing.
            if value is None:
                return "unknown"
            return f"{value} {unit}" if unit else f"{value}"

        sex_text = SBERT_Metadata._sex_to_text(sex)
        age_text = describe(age)
        height_text = describe(height, "cm")
        weight_text = describe(weight, "kg")
        sbp_text = describe(sbp, "mmHg")
        dbp_text = describe(dbp, "mmHg")
        heart_rate_text = describe(heart_rate, "bpm")
        bmi_text = describe(bmi, "kg/m^2")
        hypertension_text = describe(hypertension)

        # The unit is already part of heart_rate_text; the old template appended
        # a second " bpm" ("72 bpm bpm" / "unknown bpm").
        return (
            f"This is the information of the user, sex: {sex_text}, "
            f"age: {age_text}, height: {height_text}, weight: {weight_text}, "
            f"sbp: {sbp_text}, dbp: {dbp_text}, heart rate: {heart_rate_text}, "
            f"bmi: {bmi_text}, hypertension: {hypertension_text}."
        )

    def compute_embedding_from_metadata(self,
                                        sexes: List[Optional[Any]],
                                        ages: List[Optional[int]],
                                        heights: List[Optional[float]],
                                        weights: List[Optional[float]],
                                        systolics: List[Optional[float]],
                                        diastolics: List[Optional[float]],
                                        heart_rates: List[Optional[float]],
                                        bmis: List[Optional[float]],
                                        hypertensions: List[Optional[Any]]) -> np.ndarray:
        """
        Generate embedding vectors from metadata using SBERT.

        Each position across the nine lists describes one sample; the sample is
        textualized and the resulting sentences are encoded together.

        Args:
            sexes: Sex values (may contain None).
            ages: Age values (may contain None).
            heights: Height values (may contain None).
            weights: Weight values (may contain None).
            systolics: Systolic blood pressure values (may contain None).
            diastolics: Diastolic blood pressure values (may contain None).
            heart_rates: Heart rate values (may contain None).
            bmis: Body mass index values (may contain None).
            hypertensions: Hypertension values (may contain None).

        Returns:
            The result of ``self.model.encode`` for the sentences, one row per
            sample.

        Raises:
            ValueError: If the input lists do not all have the same length.
        """
        fields = (sexes, ages, heights, weights, systolics,
                  diastolics, heart_rates, bmis, hypertensions)
        lengths = {len(field) for field in fields}
        if len(lengths) > 1:
            # zip() would silently truncate to the shortest list.
            raise ValueError(
                f"All metadata lists must have the same length, got lengths: {sorted(lengths)}"
            )

        text_samples = [
            self.textualize_metadata_ppg_bp(*sample) for sample in zip(*fields)
        ]
        return self.model.encode(text_samples)

    def extract_embeddings(
        self,
        data_root: str,
        subject_path: str = SUBJECT_PATH,
        batch_size: int = 32,
        label: Optional[str] = None
    ) -> Dataset:
        """
        Extract metadata-based embeddings for every sample under ``data_root``.

        For each discovered dataset the user's metadata is looked up, repeated
        for every sample of that dataset, textualized and encoded in batches.

        Args:
            data_root: Root directory containing the per-user datasets
                (see ``_discover_session_paths`` for the accepted layout).
            subject_path: Path to the metadata Excel file.
            batch_size: Number of samples encoded per call (must be >= 1).
            label: Label to keep. If None or "all", every label is processed.

        Returns:
            HuggingFace Dataset with columns user_id, idx, label, sex, age,
            height, weight, systolic, diastolic, heart_rate, bmi, hypertension
            and embedding.

        Raises:
            ValueError: If ``batch_size`` is < 1 or no dataset is found.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        print(f"[INFO] Loading subject metadata from: {subject_path}")
        subject_metadata = load_subject_metadata(subject_path)
        print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")

        session_paths = self._discover_session_paths(data_root)
        if not session_paths:
            raise ValueError(
                f"No valid HuggingFace dataset directories found in: {data_root}"
            )

        all_embeddings: List[List[float]] = []
        all_user_ids: List[str] = []
        all_idxs: List[int] = []
        all_labels: List[str] = []
        # One list per metadata column, in the order of the output columns.
        metadata_columns: Dict[str, List[Any]] = {
            "sex": [],
            "age": [],
            "height": [],
            "weight": [],
            "systolic": [],
            "diastolic": [],
            "heart_rate": [],
            "bmi": [],
            "hypertension": [],
        }

        filter_by_label = label is not None and label != "all"
        wanted_label = str(label)

        # Collect metadata for all samples first.
        for user_id, session_path in session_paths:
            dataset = load_from_disk(session_path)
            dataset = dataset.shuffle(seed=0)
            if filter_by_label:
                # Compare as strings so a label parsed as a number by the CLI
                # still matches, and vice versa.
                dataset = dataset.filter(lambda x: str(x["label"]) == wanted_label)
            num_samples = len(dataset)
            if num_samples == 0:
                continue
            session_name = os.path.basename(os.path.normpath(session_path))
            print(f"[INFO] Processing user={user_id}, session={session_name}, samples={num_samples}")

            user_id_str = self._metadata_key(user_id)
            meta = subject_metadata.get(user_id_str)
            if meta is None:
                meta = {}
                print(f"[WARN] No metadata found for user_id: {user_id_str}")
            values = {
                "sex": meta.get("sex"),
                "age": meta.get("age"),
                "height": meta.get("height"),
                "weight": meta.get("weight"),
                "systolic": meta.get("sbp"),
                "diastolic": meta.get("dbp"),
                # The loader historically stored the heart rate under "hr".
                "heart_rate": meta.get("heart_rate", meta.get("hr")),
                "bmi": meta.get("bmi"),
                "hypertension": meta.get("hypertension"),
            }

            # Read each column once instead of re-reading it for every sample.
            all_user_ids.extend(str(value) for value in dataset["user_id"])
            all_idxs.extend(int(value) for value in dataset["idx"])
            all_labels.extend(str(value) for value in dataset["label"])
            for name, value in values.items():
                metadata_columns[name].extend([value] * num_samples)

        # Generate embeddings from metadata in batches.
        total = len(all_idxs)
        print(f"[INFO] Generating embeddings from metadata for {total} samples...")
        for start in range(0, total, batch_size):
            stop = start + batch_size
            embeddings = self.compute_embedding_from_metadata(
                metadata_columns["sex"][start:stop],
                metadata_columns["age"][start:stop],
                metadata_columns["height"][start:stop],
                metadata_columns["weight"][start:stop],
                metadata_columns["systolic"][start:stop],
                metadata_columns["diastolic"][start:stop],
                metadata_columns["heart_rate"][start:stop],
                metadata_columns["bmi"][start:stop],
                metadata_columns["hypertension"][start:stop],
            )
            all_embeddings.extend(np.asarray(embeddings).tolist())

        result_dataset = Dataset.from_dict({
            "user_id": all_user_ids,
            "idx": all_idxs,
            "label": all_labels,
            **metadata_columns,
            "embedding": all_embeddings,
        })
        print(f"[INFO] Total samples: {len(result_dataset)}")

        if "label" in result_dataset.column_names:
            label_counts: Dict[str, int] = {}
            for lbl in result_dataset["label"]:
                label_counts[lbl] = label_counts.get(lbl, 0) + 1
            print(f"[INFO] Label distribution: {label_counts}")
        return result_dataset

    def save_embeddings(
        self,
        dataset: Dataset,
        output_dir: str
    ) -> None:
        """
        Save the embeddings dataset to disk in HuggingFace format.

        Args:
            dataset: HuggingFace Dataset containing embeddings and metadata.
            output_dir: Directory to save the dataset to (created if missing).
        """
        os.makedirs(output_dir, exist_ok=True)
        dataset.save_to_disk(output_dir)
        print(f"[DONE] Saved embeddings dataset: {output_dir}")
        num_samples = len(dataset)
        # An empty dataset has no first row to measure the embedding size from.
        embedding_dim = len(dataset[0]["embedding"]) if num_samples else "n/a"
        print(f"[DONE] Total samples: {num_samples}, Embedding dim: {embedding_dim}")

    @staticmethod
    def load_embeddings(embedding_dir: str) -> Dataset:
        """
        Load a saved embeddings dataset from disk.

        Args:
            embedding_dir: Directory containing the saved HuggingFace dataset.

        Returns:
            HuggingFace Dataset with embeddings and metadata.
        """
        dataset = load_from_disk(embedding_dir)
        print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
        return dataset
# =============================================================================
# Data Loading Utilities
# =============================================================================
def load_embeddings_from_all_labels(embeddings_root: str) -> Dataset:
    """
    Load every label-level embeddings dataset under a root and merge them.

    Each immediate subdirectory of ``embeddings_root`` that contains a
    "dataset_info.json" file is treated as a saved HuggingFace dataset. The
    datasets are loaded in sorted order and concatenated.

    Args:
        embeddings_root: Directory whose subdirectories hold the saved datasets.

    Returns:
        The single dataset found, or the concatenation of all datasets found.

    Raises:
        ValueError: If no subdirectory contains "dataset_info.json".
    """
    # Collect (name, path) pairs for subdirectories that look like saved datasets.
    found = []
    for entry in os.listdir(embeddings_root):
        candidate = os.path.join(embeddings_root, entry)
        if not os.path.isdir(candidate):
            continue
        if os.path.exists(os.path.join(candidate, "dataset_info.json")):
            found.append((entry, candidate))

    if not found:
        raise ValueError(
            f"No valid HuggingFace dataset directories found in: {embeddings_root}"
        )

    # Sort so the concatenation order does not depend on os.listdir ordering.
    found.sort()
    print(f"[INFO] Discovered {len(found)} label directories: {[name for name, _ in found]}")

    loaded = []
    for name, path in found:
        print(f"[INFO] Loading embeddings from: {path}")
        part = load_from_disk(path)
        print(f"[INFO] Label: {name}, Samples: {len(part)}")
        loaded.append(part)

    # A single dataset is returned as-is; several are concatenated.
    merged = loaded[0] if len(loaded) == 1 else concatenate_datasets(loaded)
    print(f"[INFO] Combined dataset: {len(merged)} total samples")

    if "label" in merged.column_names:
        tally = {}
        for value in merged["label"]:
            if value in tally:
                tally[value] += 1
            else:
                tally[value] = 1
        print(f"[INFO] Label distribution: {tally}")
    return merged
# =============================================================================
# Visualization Functions
# =============================================================================
def reduce_to_2d_tsne(
    embeddings: np.ndarray,
    perplexity: float = 30.0
) -> np.ndarray:
    """
    Reduce high-dimensional embeddings to 2D using t-SNE.

    The requested perplexity is lowered to ``num_samples - 1`` when it is not
    smaller than the number of samples, because t-SNE requires the perplexity
    to be less than the sample count. Without this, small (e.g. heavily
    filtered) datasets fail inside scikit-learn.

    Args:
        embeddings: Array of shape (num_samples, embedding_dim).
        perplexity: Requested t-SNE perplexity.

    Returns:
        2D coordinates of shape (num_samples, 2).

    Raises:
        ValueError: If fewer than two samples are given.
    """
    embeddings = np.asarray(embeddings)
    num_samples = embeddings.shape[0] if embeddings.ndim > 0 else 0
    if num_samples < 2:
        raise ValueError(
            f"t-SNE needs at least 2 samples, got {num_samples}"
        )

    if perplexity >= num_samples:
        clamped = float(num_samples - 1)
        print(
            f"[WARN] perplexity={perplexity} is not smaller than the number of "
            f"samples ({num_samples}); using perplexity={clamped} instead"
        )
        perplexity = clamped

    print(f"[INFO] Running t-SNE with perplexity={perplexity}...")
    tsne = TSNE(
        n_components=2,
        random_state=0,  # Fixed seed for reproducibility
        perplexity=perplexity,
        max_iter=1000,
        init='random',
        learning_rate='auto',
    )
    return tsne.fit_transform(embeddings)
def create_scatter_plot_by_age(
    coordinates: np.ndarray,
    ages: np.ndarray,
    title: str,
    output_path: str
) -> None:
    """
    Draw a 2D scatter plot coloured by age and save it as a PDF.

    Points whose age is None or NaN are dropped. If no point has a valid age,
    a warning is printed and nothing is written.

    Args:
        coordinates: 2D array of shape (num_points, 2).
        ages: Age value for each point (may contain None).
        title: Plot title.
        output_path: File path the PDF is written to.
    """
    # None cannot live in a float array, so map it to NaN and mask NaNs out.
    numeric_ages = np.array(
        [np.nan if age is None else float(age) for age in ages]
    )
    has_age = np.logical_not(np.isnan(numeric_ages))
    kept_coords = coordinates[has_age]
    kept_ages = numeric_ages[has_age]

    if len(kept_ages) == 0:
        print(f"[WARN] No valid age data found. Skipping plot: {output_path}")
        return

    fig, ax = plt.subplots(figsize=(10, 8))

    # Continuous colormap so that age shows up as a gradient.
    xs = kept_coords[:, 0]
    ys = kept_coords[:, 1]
    points = ax.scatter(
        xs,
        ys,
        c=kept_ages,
        cmap='viridis',
        s=15,
        alpha=0.7,
        edgecolors='none',
    )

    colorbar = fig.colorbar(points, ax=ax)
    colorbar.set_label('Age (years)', rotation=270, labelpad=20)

    ax.set_title(title)
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")

    # Save as a vector PDF, then release the figure.
    fig.tight_layout()
    fig.savefig(output_path, format='pdf', bbox_inches='tight')
    plt.close(fig)

    youngest = kept_ages.min()
    oldest = kept_ages.max()
    print(f"[DONE] Saved plot: {output_path}")
    print(f"[INFO] Age range: {youngest:.0f} - {oldest:.0f} years")
    print(f"[INFO] Points with valid age: {len(kept_ages)}/{len(ages)}")
# =============================================================================
# Command Line Interface
# =============================================================================
def _split_csv_arg(value: Any) -> List[str]:
    """
    Normalize a comma-separated CLI value into a list of non-empty strings.

    Fire may convert argument values that look like Python literals before
    they reach the command, so "1,2" can arrive as the tuple (1, 2) and "3" as
    the int 3 rather than as strings. Strings are split on commas; tuples,
    lists and sets are taken element-wise; any other value is treated as a
    single item. Every item is converted with str() and stripped.
    """
    if value is None:
        return []
    if isinstance(value, str):
        parts = value.split(",")
    elif isinstance(value, (list, tuple, set)):
        parts = list(value)
    else:
        parts = [value]
    cleaned = (str(part).strip() for part in parts)
    return [part for part in cleaned if part]


class CLI:
    """
    Command-line interface for SBERT metadata-based embedding extraction and visualization.

    Provides two main commands:
        - extract: Generate embeddings from user metadata
        - plot: Visualize embeddings with t-SNE colored by age
    """

    def extract(
        self,
        data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_full",
        subject_path: str = SUBJECT_PATH,
        out_dir: str = "./embeddings/all_labels",
        batch_size: int = 32,
        label: Optional[str] = None
    ) -> None:
        """
        Extract embeddings from user metadata and save them to ``out_dir``.

        Args:
            data_root: Root directory containing the per-user data folders.
            subject_path: Path to the metadata Excel file.
            out_dir: Output directory for the HuggingFace dataset.
            batch_size: Batch size for encoding (default: 32).
            label: Label to filter on. If None or 'all', all labels are
                processed (default: None).
        """
        embedder = SBERT_Metadata()
        dataset = embedder.extract_embeddings(data_root, subject_path, batch_size, label)
        embedder.save_embeddings(dataset, out_dir)

    def plot(
        self,
        emb_dir: Optional[str] = None,
        embeddings_root: str = "./embeddings",
        out_dir: str = "./plots/all_labels",
        perplexity: float = 10.0,
        users: Optional[str] = None,
        num_users: int = 0,
        labels: Optional[str] = None,
    ) -> None:
        """
        Visualize embeddings with t-SNE, colored by age.

        Loads either a single embeddings dataset (``emb_dir``) or all datasets
        below ``embeddings_root``, applies the optional user/label filters and
        writes "tsne_by_age.pdf" into ``out_dir``.

        Args:
            emb_dir: Directory containing one saved embeddings dataset. If
                given, only this directory is used.
            embeddings_root: Root directory containing one subdirectory per
                saved dataset. Used only if ``emb_dir`` is not given.
            out_dir: Output directory for the plot (created if missing).
            perplexity: t-SNE perplexity parameter (default: 10.0).
            users: Comma-separated user IDs to include (e.g. '00,01,02').
            num_users: Include only the first N users in sorted order;
                0 = all (default: 0). Ignored when ``users`` is given.
            labels: Comma-separated labels to include. This filters the
                already-loaded data, not which directories are loaded.

        Raises:
            ValueError: If no samples remain after filtering.
        """
        os.makedirs(out_dir, exist_ok=True)

        if emb_dir is not None:
            print(f"[INFO] Loading embeddings from single directory: {emb_dir}")
            dataset = SBERT_Metadata.load_embeddings(emb_dir)
        else:
            print(f"[INFO] Loading embeddings from all label directories in: {embeddings_root}")
            dataset = load_embeddings_from_all_labels(embeddings_root)

        # User filtering: an explicit list wins over "first N users".
        user_list = _split_csv_arg(users)
        if user_list:
            wanted_users = set(user_list)
            dataset = dataset.filter(lambda x: x["user_id"] in wanted_users)
            print(f"[INFO] Filtered to users: {user_list}")
        elif num_users > 0:
            all_users = sorted(set(dataset["user_id"]))
            selected_users = all_users[:num_users]
            wanted_users = set(selected_users)
            dataset = dataset.filter(lambda x: x["user_id"] in wanted_users)
            print(f"[INFO] Selected first {num_users} users: {selected_users}")

        label_list = _split_csv_arg(labels)
        if label_list:
            wanted_labels = set(label_list)
            dataset = dataset.filter(lambda x: x["label"] in wanted_labels)
            print(f"[INFO] Filtered to labels: {label_list}")

        print(f"[INFO] Total samples: {len(dataset)}")
        if len(dataset) == 0:
            # Fail with an actionable message instead of an opaque t-SNE error.
            raise ValueError("No samples left to plot after applying the user/label filters")

        embeddings = np.array(dataset["embedding"])
        ages = np.array(dataset["age"])

        coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)

        create_scatter_plot_by_age(
            coordinates_2d,
            ages,
            "t-SNE Visualization (Colored by Age)",
            os.path.join(out_dir, "tsne_by_age.pdf")
        )
# Script entry point: hand the CLI class to Fire when the file is run directly.
if __name__ == "__main__":
    Fire(CLI)