implemented SBERT-based analysis

This commit is contained in:
Hyungjun Yoon
2026-01-19 15:10:17 +09:00
parent eed807a2a2
commit 53c3cc9d24
4 changed files with 1191 additions and 7 deletions

10
.gitignore vendored
View File

@@ -1,8 +1,8 @@
.pdf *.pdf
.txt *.txt
.csv *.csv
.arrow *.arrow
.json *.json
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/

View File

@@ -416,9 +416,9 @@ class CLI:
def extract( def extract(
self, self,
data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF", data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF",
out_dir: str = "./embeddings/W", out_dir: str = "./embeddings/REM",
batch_size: int = 32, batch_size: int = 32,
label: str = "W" label: str = "REM"
) -> None: ) -> None:
""" """
Extract embeddings from time series data. Extract embeddings from time series data.

View File

@@ -0,0 +1,529 @@
"""
SBERT Embedding Extraction and Visualization with Metadata (Age and Sex)
This module provides functionality to:
1. Extract SBERT embeddings from time series features
2. Visualize embeddings colored by subject metadata (age and sex) instead of user IDs
Features:
1. Extract embeddings from time series features using SBERT
2. Load embeddings from HuggingFace dataset
3. Load subject metadata from XLS file
4. Map user IDs to age and sex information
5. Visualize embeddings with t-SNE colored by age (continuous) and sex (categorical)
Usage:
# Extract embeddings for all labels
python gen_plot_metadata.py extract --data_root /path/to/data --out_dir ./embeddings/all_labels
# Extract embeddings for a single label
python gen_plot_metadata.py extract --data_root /path/to/data --out_dir ./embeddings/W --label W
# Visualize by age from single label directory
python gen_plot_metadata.py plot --emb_dir ./embeddings/W --out_dir ./plots/W --color_by age
# Visualize by age from all label directories (W, REM, N1, N2, N3)
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by age
# Visualize by sex from all labels
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by sex
# Create both age and sex plots from all labels
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by both
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from glob import glob
from typing import Dict, Any, List, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_from_disk, Dataset, concatenate_datasets
from fire import Fire
from sklearn.manifold import TSNE
from gen_plot import SBERT, reduce_to_2d_tsne
# =============================================================================
# Constants
# =============================================================================
# Default location of the "SC-subjects.xls" subject metadata spreadsheet.
# NOTE: this is an absolute, machine-specific path; pass `subject_path`
# explicitly when running on another host.
SUBJECT_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/SC-subjects.xls"
# =============================================================================
# Metadata Loading
# =============================================================================
def load_subject_metadata(subject_path: str = SUBJECT_PATH) -> Dict[str, Dict[str, Any]]:
    """
    Load per-subject age and sex from the subject spreadsheet.

    The spreadsheet must provide the columns ``subject``, ``age`` and
    ``sex (F=1)`` (1 = Female, 0 = Male).

    Keys of the returned mapping are the stripped string form of the
    ``subject`` cell. Integer-valued IDs are always rendered without a
    fractional part ("12", never "12.0") because the callers in this file
    look subjects up with ``str(int(user_id))``.

    Args:
        subject_path: Path to the SC-subjects.xls file.

    Returns:
        Mapping ``{subject_id: {"age": int or None, "sex": int or None}}``.
        When a subject appears on several rows, the last row wins.
        Rows without a subject ID are ignored.
    """
    df = pd.read_excel(subject_path)

    def _optional_int(value: Any) -> Optional[int]:
        # Empty spreadsheet cells arrive as NaN/NA; expose them as None.
        return None if pd.isna(value) else int(value)

    subject_info: Dict[str, Dict[str, Any]] = {}
    # Walk the three columns in lockstep rather than DataFrame.iterrows():
    # iterrows() packs each row into a single Series, which can upcast
    # integer cells to float and would turn subject 12 into the key "12.0".
    for subject, age, sex in zip(df["subject"], df["age"], df["sex (F=1)"]):
        if isinstance(subject, float):
            if pd.isna(subject):
                # No usable ID on this row; a "nan" key would never be looked up.
                continue
            if subject.is_integer():
                # The column itself may be float-typed (e.g. because of blanks).
                subject = int(subject)
        subject_info[str(subject).strip()] = {
            "age": _optional_int(age),
            "sex": _optional_int(sex),
        }
    return subject_info
def map_user_ids_to_metadata(
    user_ids: np.ndarray,
    subject_metadata: Dict[str, Dict[str, Any]]
) -> tuple:
    """
    Look up age and sex for every user ID.

    Each user ID must already be in the same form as the keys of
    ``subject_metadata``; no conversion is performed here.

    Args:
        user_ids: Sequence of user IDs, one entry per sample.
        subject_metadata: Mapping of subject ID to a dict holding the keys
            "age" and "sex" (as produced by ``load_subject_metadata``).

    Returns:
        Tuple ``(ages, sexes)`` of numpy arrays aligned with ``user_ids``.
        Entries are None for users that have no metadata record or whose
        record stores None (the arrays then have dtype object).
    """
    ages = []
    sexes = []
    unknown_ids = set()
    for user_id in user_ids:
        record = subject_metadata.get(user_id)
        if record is None:
            # A missing subject must not abort the run: the plotting code
            # already filters out None values.
            unknown_ids.add(user_id)
            ages.append(None)
            sexes.append(None)
            continue
        ages.append(record["age"])
        sexes.append(record["sex"])
    if unknown_ids:
        # Report once per ID instead of once per sample.
        print(f"[WARN] No metadata found for user IDs: {sorted(map(str, unknown_ids))}")
    return np.array(ages), np.array(sexes)
# =============================================================================
# Visualization Functions
# =============================================================================
def create_scatter_plot_by_age(
    coordinates: np.ndarray,
    ages: np.ndarray,
    title: str,
    output_path: str
) -> None:
    """
    Draw a 2D scatter plot whose points are colored by age and save it as PDF.

    Points whose age is None are left out. If no point has an age, a warning
    is printed and no file is written.

    Args:
        coordinates: 2D array of shape (num_points, 2)
        ages: One age per point (None marks a missing value)
        title: Plot title
        output_path: Destination path of the PDF file
    """
    # Represent missing ages as NaN so they can be masked out numerically.
    numeric_ages = np.array(
        [np.nan if age is None else float(age) for age in ages]
    )
    has_age = ~np.isnan(numeric_ages)
    points = coordinates[has_age]
    known_ages = numeric_ages[has_age]

    if known_ages.size == 0:
        print(f"[WARN] No valid age data found. Skipping plot: {output_path}")
        return

    fig, ax = plt.subplots(figsize=(10, 8))

    # Continuous colormap: one color per age value along the viridis gradient.
    x_values = points[:, 0]
    y_values = points[:, 1]
    scatter = ax.scatter(
        x_values,
        y_values,
        c=known_ages,
        cmap='viridis',
        s=15,
        alpha=0.7,
        edgecolors='none',
    )

    colorbar = fig.colorbar(scatter, ax=ax)
    colorbar.set_label('Age (years)', rotation=270, labelpad=20)

    ax.set_title(title)
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")

    # Write the figure out as a vector PDF and release it.
    fig.tight_layout()
    fig.savefig(output_path, format='pdf', bbox_inches='tight')
    plt.close(fig)

    youngest = known_ages.min()
    oldest = known_ages.max()
    print(f"[DONE] Saved plot: {output_path}")
    print(f"[INFO] Age range: {youngest:.0f} - {oldest:.0f} years")
    print(f"[INFO] Points with valid age: {len(known_ages)}/{len(ages)}")
def create_scatter_plot_by_sex(
    coordinates: np.ndarray,
    sexes: np.ndarray,
    title: str,
    output_path: str
) -> None:
    """
    Draw a 2D scatter plot with one color per sex category and save it as PDF.

    Sex encoding: 1 = Female, 0 = Male. Points whose sex is None are left
    out. If no point has a sex value, a warning is printed and no file is
    written.

    Args:
        coordinates: 2D array of shape (num_points, 2)
        sexes: One sex code per point (None marks a missing value)
        title: Plot title
        output_path: Destination path of the PDF file
    """
    # Represent missing codes as NaN so they can be masked out numerically.
    sexes_float = np.array([float(s) if s is not None else np.nan for s in sexes])
    valid_mask = ~np.isnan(sexes_float)
    valid_coords = coordinates[valid_mask]
    valid_sexes = sexes_float[valid_mask].astype(int)

    if len(valid_sexes) == 0:
        print(f"[WARN] No valid sex data found. Skipping plot: {output_path}")
        return

    sex_labels = {0: "Male", 1: "Female"}
    colors = ['steelblue', 'coral']  # index 0 = Male, index 1 = Female

    # np.unique returns the sorted distinct codes; tolist() yields plain
    # Python ints so labels and log output are free of numpy scalar reprs.
    unique_sexes = np.unique(valid_sexes).tolist()

    fig, ax = plt.subplots(figsize=(10, 8))

    # One scatter call per category so each gets its own legend entry.
    for sex_code in unique_sexes:
        mask = valid_sexes == sex_code
        ax.scatter(
            valid_coords[mask, 0],
            valid_coords[mask, 1],
            c=colors[sex_code % len(colors)],
            s=15,
            label=sex_labels.get(sex_code, f"Sex {sex_code}"),
            alpha=0.7,
        )

    ax.set_title(title)
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")
    ax.legend(loc='best', markerscale=2)

    # Save figure as vector PDF
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', bbox_inches='tight')
    plt.close()

    print(f"[DONE] Saved plot: {output_path}")
    sex_counts = {
        sex_labels.get(code, f"Sex {code}"): int((valid_sexes == code).sum())
        for code in unique_sexes
    }
    print(f"[INFO] Sex distribution: {sex_counts}")
    print(f"[INFO] Points with valid sex: {len(valid_sexes)}/{len(sexes)}")
# =============================================================================
# Embedding Extraction Utilities
# =============================================================================
def extract_embeddings_for_all_labels(
    data_root: str,
    out_dir: str,
    batch_size: int = 32
) -> Dataset:
    """
    Extract embeddings for every sleep stage label and merge the results.

    Runs the SBERT extractor once per label (W, REM, N1, N2, N3),
    concatenates the per-label datasets, saves the combined dataset to
    ``out_dir`` and returns it.

    Args:
        data_root: Root directory containing user/session data folders
        out_dir: Output directory for the combined HuggingFace dataset
        batch_size: Batch size for inference (default: 32)

    Returns:
        Combined HuggingFace Dataset covering all labels
    """
    embedder = SBERT()
    per_label_datasets = []
    for stage in ("W", "REM", "N1", "N2", "N3"):
        print(f"\n[INFO] Extracting embeddings for label: {stage}")
        stage_dataset = embedder.extract_embeddings(data_root, batch_size, stage)
        print(f"[INFO] Extracted {len(stage_dataset)} samples for label {stage}")
        per_label_datasets.append(stage_dataset)

    # A single dataset needs no concatenation.
    if len(per_label_datasets) != 1:
        combined_dataset = concatenate_datasets(per_label_datasets)
    else:
        combined_dataset = per_label_datasets[0]
    print(f"\n[INFO] Combined dataset: {len(combined_dataset)} total samples")

    # Report how many samples each label contributes.
    if "label" in combined_dataset.column_names:
        label_counts = {}
        for value in combined_dataset["label"]:
            label_counts.setdefault(value, 0)
            label_counts[value] += 1
        print(f"[INFO] Label distribution: {label_counts}")

    embedder.save_embeddings(combined_dataset, out_dir)
    return combined_dataset
# =============================================================================
# Data Loading Utilities
# =============================================================================
def load_embeddings_from_all_labels(embeddings_root: str) -> Dataset:
    """
    Load every saved dataset found directly under ``embeddings_root`` and
    concatenate them.

    A subdirectory counts as a dataset when it contains a
    ``dataset_info.json`` file. Subdirectories are processed in sorted
    name order.

    Args:
        embeddings_root: Root directory containing label subdirectories
                        (e.g., "./embeddings" containing W/, REM/, N1/, etc.)

    Returns:
        Concatenated HuggingFace Dataset with all labels combined

    Raises:
        ValueError: If no subdirectory holds a ``dataset_info.json`` file.
    """
    label_dirs = []
    for entry in os.listdir(embeddings_root):
        entry_path = os.path.join(embeddings_root, entry)
        if not os.path.isdir(entry_path):
            continue
        # Only directories carrying the dataset marker file are accepted.
        if not os.path.exists(os.path.join(entry_path, "dataset_info.json")):
            continue
        label_dirs.append((entry, entry_path))

    if not label_dirs:
        raise ValueError(
            f"No valid HuggingFace dataset directories found in: {embeddings_root}"
        )

    label_dirs = sorted(label_dirs)
    label_names = [name for name, _ in label_dirs]
    print(f"[INFO] Discovered {len(label_dirs)} label directories: {label_names}")

    loaded = []
    for name, path in label_dirs:
        print(f"[INFO] Loading embeddings from: {path}")
        part = load_from_disk(path)
        print(f"[INFO] Label: {name}, Samples: {len(part)}")
        loaded.append(part)

    # A single dataset needs no concatenation.
    combined_dataset = loaded[0] if len(loaded) == 1 else concatenate_datasets(loaded)
    print(f"[INFO] Combined dataset: {len(combined_dataset)} total samples")

    # Report how many samples each label contributes.
    if "label" in combined_dataset.column_names:
        label_counts = {}
        for value in combined_dataset["label"]:
            label_counts.setdefault(value, 0)
            label_counts[value] += 1
        print(f"[INFO] Label distribution: {label_counts}")

    return combined_dataset
# =============================================================================
# Command Line Interface
# =============================================================================
class CLI:
    """
    Command-line interface for SBERT embedding extraction and visualization with metadata.

    Provides:
    - extract: Generate embeddings from time series features
    - plot: Visualize embeddings colored by age or sex instead of user ID
    """

    @staticmethod
    def _split_csv(value: Any) -> List[str]:
        """
        Turn a comma-separated CLI value into a list of stripped strings.

        python-fire may parse an argument such as ``W,N1`` or ``40,41`` into
        a tuple (or a single number) before the command sees it, so both the
        raw string form and already-split sequences are accepted.
        """
        if isinstance(value, (list, tuple)):
            return [str(item).strip() for item in value]
        return [part.strip() for part in str(value).split(",")]

    @staticmethod
    def _metadata_key(user_id: Any) -> str:
        """
        Convert a dataset user ID into the key format used for metadata lookup.

        Numeric IDs are reduced to their plain integer string ("01" -> "1");
        anything that is not an integer is passed through stripped instead
        of aborting the whole run.
        """
        text = str(user_id)
        try:
            return str(int(text))
        except ValueError:
            return text.strip()

    def extract(
        self,
        data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_full",
        out_dir: str = "./embeddings/all_labels",
        batch_size: int = 32,
        label: Optional[str] = None
    ) -> None:
        """
        Extract embeddings from time series features using SBERT.

        Can extract for a single label or all labels at once.

        Args:
            data_root: Root directory containing user/session data folders
            out_dir: Output directory for HuggingFace dataset
            batch_size: Batch size for inference (default: 32)
            label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM').
                If None or 'all', extracts for all labels (default: None for all labels)
        """
        if label is None or label == "all":
            # The helper builds its own SBERT instance, so none is created here.
            print(f"[INFO] Extracting embeddings for all labels")
            extract_embeddings_for_all_labels(data_root, out_dir, batch_size)
            return

        print(f"[INFO] Extracting embeddings for label: {label}")
        embedder = SBERT()
        dataset = embedder.extract_embeddings(data_root, batch_size, label)
        embedder.save_embeddings(dataset, out_dir)

    def plot(
        self,
        emb_dir: Optional[str] = None,
        embeddings_root: str = "./embeddings",
        out_dir: str = "./plots/all_labels",
        subject_path: str = SUBJECT_PATH,
        perplexity: float = 10.0,
        color_by: str = "age",
        users: Optional[str] = None,
        num_users: int = 0,
        labels: Optional[str] = None,
    ) -> None:
        """
        Visualize embeddings with t-SNE, colored by age or sex.

        Can load from either a single label directory or all label directories.

        Args:
            emb_dir: Single directory containing the HuggingFace embeddings dataset
                    (e.g., "./embeddings/W"). If provided, only this directory is used.
            embeddings_root: Root directory containing label subdirectories
                            (e.g., "./embeddings" containing W/, REM/, N1/, etc.).
                            Used only if emb_dir is not provided.
            out_dir: Output directory for visualization plots (PDF)
            subject_path: Path to SC-subjects.xls file with metadata
            perplexity: t-SNE perplexity parameter (default: 10.0)
            color_by: What to color by - 'age', 'sex', or 'both' (default: 'age').
                    'all' is accepted as an alias of 'both'.
            users: Comma-separated user IDs to include (e.g., '00,01,02')
            num_users: Include only first N users, 0 = all (default: 0)
            labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
                    This filters the already-loaded data, not which directories to load.

        Raises:
            ValueError: If color_by is not a supported value, or if no
                samples remain after filtering.
        """
        if color_by not in ["age", "sex", "both", "all"]:
            raise ValueError(f"Invalid color_by: {color_by}. Use 'age', 'sex', or 'both'.")

        os.makedirs(out_dir, exist_ok=True)

        # Load embeddings: either from single directory or all label directories
        if emb_dir is not None:
            print(f"[INFO] Loading embeddings from single directory: {emb_dir}")
            dataset = SBERT.load_embeddings(emb_dir)
        else:
            print(f"[INFO] Loading embeddings from all label directories in: {embeddings_root}")
            dataset = load_embeddings_from_all_labels(embeddings_root)

        # Apply user filtering (an explicit list takes precedence over num_users)
        if users:
            user_list = self._split_csv(users)
            dataset = dataset.filter(lambda x: x["user_id"] in user_list)
            print(f"[INFO] Filtered to users: {user_list}")
        elif num_users > 0:
            all_users = sorted(set(dataset["user_id"]))
            selected_users = all_users[:num_users]
            dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
            print(f"[INFO] Selected first {num_users} users: {selected_users}")

        # Filter by sleep stage labels
        if labels:
            label_list = self._split_csv(labels)
            dataset = dataset.filter(lambda x: x["label"] in label_list)
            print(f"[INFO] Filtered to labels: {label_list}")

        print(f"[INFO] Total samples: {len(dataset)}")
        if len(dataset) == 0:
            # Fail with a clear message instead of an obscure t-SNE error.
            raise ValueError("No samples left after filtering; nothing to plot.")

        # Extract embeddings as numpy array for t-SNE
        embeddings = np.array(dataset["embedding"])

        print(f"[INFO] Loading subject metadata from: {subject_path}")
        subject_metadata = load_subject_metadata(subject_path)
        print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")

        # Metadata is looked up by the integer form of the user ID.
        metadata_keys = [self._metadata_key(uid) for uid in dataset["user_id"]]
        ages, sexes = map_user_ids_to_metadata(metadata_keys, subject_metadata)

        coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)

        # 'both' and 'all' produce the age plot followed by the sex plot.
        if color_by in ("age", "both", "all"):
            create_scatter_plot_by_age(
                coordinates_2d,
                ages,
                "t-SNE Visualization (Colored by Age)",
                os.path.join(out_dir, "tsne_by_age.pdf")
            )
        if color_by in ("sex", "both", "all"):
            create_scatter_plot_by_sex(
                coordinates_2d,
                sexes,
                "t-SNE Visualization (Colored by Sex)",
                os.path.join(out_dir, "tsne_by_sex.pdf")
            )
# Script entry point: hand the CLI class to python-fire so that its public
# methods (`extract`, `plot`) become command-line subcommands.
if __name__ == "__main__":
Fire(CLI)

View File

@@ -0,0 +1,655 @@
"""
SBERT Metadata-Based Embedding Extraction and Visualization Pipeline
This module provides functionality to:
1. Extract embeddings from user metadata (age and sex) using SBERT (Sentence-BERT)
2. Visualize embeddings using dimensionality reduction (t-SNE) colored by age
SBERT Overview:
SBERT (Sentence-BERT) is a modification of the BERT model that uses siamese
and triplet network structures to derive semantically meaningful sentence
embeddings. It's designed for semantic similarity tasks and produces
fixed-size dense vector representations of text.
Key Features:
- Semantic understanding: Captures meaning rather than just word presence
- Fixed-size embeddings: Outputs consistent vector dimensions (384 for all-MiniLM-L6-v2)
- Efficient: Optimized for sentence-level tasks
Embedding Strategy:
Instead of using time series features, we create textual descriptions based on
user metadata (age and sex). This approach allows us to capture user-level
characteristics in the embedding space.
Processing Pipeline:
1. Load user metadata from XLS file (age, sex)
2. Textualization: Convert metadata to natural language description
3. SBERT encoding: Generate 384-dimensional embeddings
4. Visualization: t-SNE with continuous age coloring
Usage:
# Extract embeddings from metadata for all labels
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/all_labels
# Extract embeddings for a single label
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/REM --label REM
# Visualize with t-SNE from all label directories (colored by age)
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels
# Visualize from a single label directory
python gen_plot.py plot --emb_dir ./embeddings/REM --out_dir ./plots/REM
Author: Sumin Im (NMSL Undergraduate Researcher)
Date: 2026-01-09
"""
import os
from glob import glob
from typing import Dict, Any, List, Tuple, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_from_disk, Dataset, concatenate_datasets
from fire import Fire
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
# =============================================================================
# Constants
# =============================================================================
# Default location of the "SC-subjects.xls" subject metadata spreadsheet.
# NOTE: this is an absolute, machine-specific path; pass `subject_path`
# explicitly when running on another host.
SUBJECT_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/SC-subjects.xls"
# =============================================================================
# Metadata Loading
# =============================================================================
def load_subject_metadata(subject_path: str = SUBJECT_PATH) -> Dict[str, Dict[str, Any]]:
    """
    Load per-subject age and sex from the subject spreadsheet.

    The spreadsheet must provide the columns ``subject``, ``age`` and
    ``sex (F=1)`` (1 = Female, 0 = Male).

    Keys of the returned mapping are the stripped string form of the
    ``subject`` cell. Integer-valued IDs are always rendered without a
    fractional part ("12", never "12.0") because the extractor in this file
    looks subjects up with ``str(int(user_id))``.

    Args:
        subject_path: Path to the SC-subjects.xls file.

    Returns:
        Mapping ``{subject_id: {"age": int or None, "sex": int or None}}``.
        When a subject appears on several rows, the last row wins.
        Rows without a subject ID are ignored.
    """
    df = pd.read_excel(subject_path)

    def _optional_int(value: Any) -> Optional[int]:
        # Empty spreadsheet cells arrive as NaN/NA; expose them as None.
        return None if pd.isna(value) else int(value)

    subject_info: Dict[str, Dict[str, Any]] = {}
    # Walk the three columns in lockstep rather than DataFrame.iterrows():
    # iterrows() packs each row into a single Series, which can upcast
    # integer cells to float and would turn subject 12 into the key "12.0".
    for subject, age, sex in zip(df["subject"], df["age"], df["sex (F=1)"]):
        if isinstance(subject, float):
            if pd.isna(subject):
                # No usable ID on this row; a "nan" key would never be looked up.
                continue
            if subject.is_integer():
                # The column itself may be float-typed (e.g. because of blanks).
                subject = int(subject)
        subject_info[str(subject).strip()] = {
            "age": _optional_int(age),
            "sex": _optional_int(sex),
        }
    return subject_info
# =============================================================================
# Embedding Extractor Class
# =============================================================================
class SBERT_Metadata:
    """
    Extracts fixed-dimensional embeddings from user metadata using SBERT.

    Age and sex are first written out as a short English sentence, which is
    then encoded by a SentenceTransformer model.

    Pipeline:
        User metadata -> textualization -> SBERT encoder -> embedding

    Model: all-MiniLM-L6-v2

    Attributes:
        model: SentenceTransformer instance used to encode the sentences
    """

    def __init__(self):
        """Load the pre-trained "all-MiniLM-L6-v2" SentenceTransformer model."""
        self.model = SentenceTransformer("all-MiniLM-L6-v2")

    @staticmethod
    def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
        """
        Discover all user/session directories under data_root.

        Expected structure: data_root/user_id/session_id/

        Args:
            data_root: Root directory containing user/session subfolders

        Returns:
            List of (user_id, session_id, session_path) tuples, sorted by path
        """
        discovered_paths = []
        # Two wildcard levels: <data_root>/<user_id>/<session_id>
        for session_path in sorted(glob(os.path.join(data_root, "*", "*"))):
            if not os.path.isdir(session_path):
                continue
            session_id = os.path.basename(session_path)
            user_id = os.path.basename(os.path.dirname(session_path))
            discovered_paths.append((user_id, session_id, session_path))
        return discovered_paths

    def textualize_metadata(self, age: Optional[int], sex: Optional[int]) -> str:
        """
        Convert user metadata (age and sex) into a natural language sentence.

        Args:
            age: Age of the user (None when unknown)
            sex: Sex code of the user; 1 is rendered as "Female", any other
                non-None value as "Male", None as "Unknown"

        Returns:
            Sentence describing the user metadata
        """
        if sex is None:
            sex_text = "Unknown"
        else:
            sex_text = "Female" if sex == 1 else "Male"

        if age is None:
            return f"This is the information of the user, age: unknown, sex: {sex_text}."
        return f"This is the information of the user, age: {age}, sex: {sex_text}."

    def compute_embedding_from_metadata(
        self,
        ages: List[Optional[int]],
        sexes: List[Optional[int]]
    ) -> np.ndarray:
        """
        Generate embedding vectors from metadata using SBERT.

        Args:
            ages: List of age values (may contain None)
            sexes: List of sex values (0 = Male, 1 = Female, may contain None)

        Returns:
            Output of ``self.model.encode`` for the textualized metadata,
            one row per (age, sex) pair
        """
        text_samples = [
            self.textualize_metadata(age, sex) for age, sex in zip(ages, sexes)
        ]
        return self.model.encode(text_samples)

    def extract_embeddings(
        self,
        data_root: str,
        subject_path: str = SUBJECT_PATH,
        batch_size: int = 32,
        label: Optional[str] = None
    ) -> Dataset:
        """
        Extract metadata-based embeddings for every sample of every session.

        For each user/session dataset found under ``data_root`` the user's
        age and sex are looked up, and every sample of that session receives
        the embedding of the user's metadata sentence.

        Args:
            data_root: Root directory containing user/session data folders
            subject_path: Path to SC-subjects.xls file with metadata
            batch_size: Number of metadata sentences encoded per call
            label: Sleep stage label to filter (e.g., "W", "N1", "N2", "N3", "REM").
                If None or "all", processes all labels.

        Returns:
            HuggingFace Dataset with columns user_id, session_id, idx, label,
            age, sex and embedding
        """
        print(f"[INFO] Loading subject metadata from: {subject_path}")
        subject_metadata = load_subject_metadata(subject_path)
        print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")

        session_paths = self.discover_session_paths(data_root)
        print(f"[INFO] Discovered {len(session_paths)} sessions")

        all_embeddings = []
        all_user_ids = []
        all_session_ids = []
        all_idxs = []
        all_labels = []
        all_ages = []
        all_sexes = []

        # Collect metadata for all samples first
        for user_id, session_id, session_path in session_paths:
            dataset = load_from_disk(session_path)
            dataset = dataset.shuffle(seed=0)

            if label is not None and label != "all":
                dataset = dataset.filter(lambda x: x["label"] == label)

            num_samples = len(dataset)
            if num_samples == 0:
                continue
            print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")

            # Metadata keys use the plain integer form of the user directory
            # name; a non-numeric name is looked up as-is instead of crashing.
            try:
                user_id_str = str(int(user_id))
            except ValueError:
                user_id_str = str(user_id).strip()

            record = subject_metadata.get(user_id_str)
            if record is None:
                age = None
                sex = None
                print(f"[WARN] No metadata found for user_id: {user_id_str}")
            else:
                age = record["age"]
                sex = record["sex"]

            # Read every column once per session. Indexing dataset["col"][i]
            # inside a per-sample loop would fetch the whole column again for
            # each sample.
            all_user_ids.extend(str(value) for value in dataset["user_id"])
            all_session_ids.extend(str(value) for value in dataset["session_id"])
            all_idxs.extend(int(value) for value in dataset["idx"])
            all_labels.extend(str(value) for value in dataset["label"])
            all_ages.extend([age] * num_samples)
            all_sexes.extend([sex] * num_samples)

        # Generate embeddings from metadata in batches
        print(f"[INFO] Generating embeddings from metadata for {len(all_ages)} samples...")
        for batch_start in range(0, len(all_ages), batch_size):
            batch_end = batch_start + batch_size  # slicing clamps at the end
            embeddings = self.compute_embedding_from_metadata(
                all_ages[batch_start:batch_end],
                all_sexes[batch_start:batch_end],
            )
            all_embeddings.extend(row.tolist() for row in embeddings)

        result_dataset = Dataset.from_dict({
            "user_id": all_user_ids,
            "session_id": all_session_ids,
            "idx": all_idxs,
            "label": all_labels,
            "age": all_ages,
            "sex": all_sexes,
            "embedding": all_embeddings,
        })
        print(f"[INFO] Total samples: {len(result_dataset)}")

        # Print label distribution
        if "label" in result_dataset.column_names:
            label_counts = {}
            for lbl in result_dataset["label"]:
                label_counts[lbl] = label_counts.get(lbl, 0) + 1
            print(f"[INFO] Label distribution: {label_counts}")

        return result_dataset

    def save_embeddings(
        self,
        dataset: Dataset,
        output_dir: str
    ) -> None:
        """
        Save embeddings dataset to disk in HuggingFace format.

        Args:
            dataset: HuggingFace Dataset containing embeddings and metadata
            output_dir: Directory path to save the dataset (created if missing)
        """
        os.makedirs(output_dir, exist_ok=True)
        dataset.save_to_disk(output_dir)
        print(f"[DONE] Saved embeddings dataset: {output_dir}")
        if len(dataset) > 0:
            print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
        else:
            # dataset[0] does not exist for an empty dataset; report the
            # count without an embedding dimension.
            print("[DONE] Total samples: 0")

    @staticmethod
    def load_embeddings(embedding_dir: str) -> Dataset:
        """
        Load saved embeddings dataset from disk.

        Args:
            embedding_dir: Directory path containing the saved HuggingFace dataset

        Returns:
            HuggingFace Dataset with embeddings and metadata
        """
        dataset = load_from_disk(embedding_dir)
        print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
        return dataset
# =============================================================================
# Data Loading Utilities
# =============================================================================
def load_embeddings_from_all_labels(embeddings_root: str) -> Dataset:
    """
    Load every saved dataset found directly under ``embeddings_root`` and
    concatenate them.

    A subdirectory counts as a dataset when it contains a
    ``dataset_info.json`` file. Subdirectories are processed in sorted
    name order.

    Args:
        embeddings_root: Root directory containing label subdirectories
                        (e.g., "./embeddings" containing W/, REM/, N1/, etc.)

    Returns:
        Concatenated HuggingFace Dataset with all labels combined

    Raises:
        ValueError: If no subdirectory holds a ``dataset_info.json`` file.
    """
    label_dirs = []
    for entry in os.listdir(embeddings_root):
        entry_path = os.path.join(embeddings_root, entry)
        if not os.path.isdir(entry_path):
            continue
        # Only directories carrying the dataset marker file are accepted.
        if not os.path.exists(os.path.join(entry_path, "dataset_info.json")):
            continue
        label_dirs.append((entry, entry_path))

    if not label_dirs:
        raise ValueError(
            f"No valid HuggingFace dataset directories found in: {embeddings_root}"
        )

    label_dirs = sorted(label_dirs)
    label_names = [name for name, _ in label_dirs]
    print(f"[INFO] Discovered {len(label_dirs)} label directories: {label_names}")

    loaded = []
    for name, path in label_dirs:
        print(f"[INFO] Loading embeddings from: {path}")
        part = load_from_disk(path)
        print(f"[INFO] Label: {name}, Samples: {len(part)}")
        loaded.append(part)

    # A single dataset needs no concatenation.
    combined_dataset = loaded[0] if len(loaded) == 1 else concatenate_datasets(loaded)
    print(f"[INFO] Combined dataset: {len(combined_dataset)} total samples")

    # Report how many samples each label contributes.
    if "label" in combined_dataset.column_names:
        label_counts = {}
        for value in combined_dataset["label"]:
            label_counts.setdefault(value, 0)
            label_counts[value] += 1
        print(f"[INFO] Label distribution: {label_counts}")

    return combined_dataset
# =============================================================================
# Visualization Functions
# =============================================================================
def reduce_to_2d_tsne(
    embeddings: np.ndarray,
    perplexity: float = 30.0
) -> np.ndarray:
    """
    Reduce high-dimensional embeddings to 2D using t-SNE.

    t-SNE (t-distributed Stochastic Neighbor Embedding) is a nonlinear
    dimensionality reduction technique that preserves local structure.
    Points that are similar in high dimensions stay close in 2D.

    Args:
        embeddings: High-dimensional array of shape (num_samples, embedding_dim)
        perplexity: t-SNE perplexity parameter (typically 5-50).
            Higher values consider more neighbors, creating smoother layouts.
            If it is not smaller than the number of samples it is lowered to
            ``num_samples - 1`` (with a warning) so that small, filtered
            subsets can still be plotted.

    Returns:
        2D coordinates of shape (num_samples, 2)
    """
    # scikit-learn requires perplexity < n_samples and raises otherwise.
    # Clamp instead of crashing; with fewer than 2 samples there is no valid
    # perplexity, so the value is passed through and scikit-learn reports it.
    n_samples = len(embeddings)
    if n_samples > 1 and perplexity >= n_samples:
        clamped = float(n_samples - 1)
        print(
            f"[WARN] perplexity={perplexity} is not smaller than the number of "
            f"samples ({n_samples}); using perplexity={clamped} instead."
        )
        perplexity = clamped

    print(f"[INFO] Running t-SNE with perplexity={perplexity}...")
    tsne = TSNE(
        n_components=2,
        random_state=0,  # Fixed seed for reproducibility
        perplexity=perplexity,
        max_iter=1000,
        init='random',
        learning_rate='auto',  # Let sklearn choose the learning rate
    )
    return tsne.fit_transform(embeddings)
def create_scatter_plot_by_age(
    coordinates: np.ndarray,
    ages: np.ndarray,
    title: str,
    output_path: str
) -> None:
    """
    Draw a 2D scatter plot whose points are colored by age and save it as PDF.

    Points whose age is missing (None or NaN) are left out. The remaining
    ages are mapped onto the continuous 'viridis' colormap and a colorbar
    labelled in years is attached. If no point has a valid age, a warning is
    printed and nothing is saved.

    Args:
        coordinates: 2D array of shape (num_points, 2)
        ages: Age value for each point (may contain None values)
        title: Plot title
        output_path: File path the PDF figure is written to
    """
    # Turn None into NaN so missing ages can be masked with numpy.
    numeric_ages = np.array(
        [np.nan if age is None else float(age) for age in ages]
    )
    has_age = ~np.isnan(numeric_ages)
    points = coordinates[has_age]
    point_ages = numeric_ages[has_age]

    # Nothing to draw when every age is missing.
    if len(point_ages) == 0:
        print(f"[WARN] No valid age data found. Skipping plot: {output_path}")
        return

    fig, ax = plt.subplots(figsize=(10, 8))

    # Marker styling: continuous colormap keyed on age, small translucent dots.
    marker_style = {
        "c": point_ages,
        "cmap": 'viridis',
        "s": 15,
        "alpha": 0.7,
        "edgecolors": 'none',
    }
    drawn = ax.scatter(points[:, 0], points[:, 1], **marker_style)

    # Colorbar acts as the age legend.
    age_bar = plt.colorbar(drawn, ax=ax)
    age_bar.set_label('Age (years)', rotation=270, labelpad=20)

    ax.set_title(title)
    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")

    # Write the figure out as a vector PDF, then release it.
    plt.tight_layout()
    plt.savefig(output_path, format='pdf', bbox_inches='tight')
    plt.close()

    youngest = point_ages.min()
    oldest = point_ages.max()
    print(f"[DONE] Saved plot: {output_path}")
    print(f"[INFO] Age range: {youngest:.0f} - {oldest:.0f} years")
    print(f"[INFO] Points with valid age: {len(point_ages)}/{len(ages)}")
# =============================================================================
# Command Line Interface
# =============================================================================
class CLI:
    """
    Command-line interface for SBERT metadata-based embedding extraction and visualization.

    Provides two main commands:
        - extract: Generate embeddings from user metadata (age, sex)
        - plot: Visualize embeddings with t-SNE colored by age
    """

    @staticmethod
    def _as_str_list(value) -> list:
        """
        Normalize a comma-separated CLI argument into a list of stripped strings.

        Python Fire evaluates command-line values as Python literals, so
        ``--labels W,N1,N2`` arrives as the tuple ``('W', 'N1', 'N2')`` and
        ``--users 12`` as the int ``12`` rather than as strings. Calling
        ``.split(",")`` on those raises AttributeError, so every accepted
        shape is converted here.

        Args:
            value: None, a comma-separated string, a list/tuple/set of items,
                or a single scalar.

        Returns:
            List of strings; empty when ``value`` is None or an empty string.
        """
        if value is None or value == "":
            return []
        if isinstance(value, str):
            items = value.split(",")
        elif isinstance(value, (list, tuple, set)):
            items = value
        else:
            items = [value]
        return [str(item).strip() for item in items]

    def extract(
        self,
        data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_full",
        subject_path: str = SUBJECT_PATH,
        out_dir: str = "./embeddings/all_labels",
        batch_size: int = 32,
        label: str = None
    ) -> None:
        """
        Extract embeddings from user metadata and save them to disk.

        Args:
            data_root: Root directory containing user/session data folders
            subject_path: Path to SC-subjects.xls file with metadata
            out_dir: Output directory for HuggingFace dataset
            batch_size: Batch size for inference (default: 32)
            label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM').
                If None or 'all', processes all labels (default: None for all labels)
        """
        embedder = SBERT_Metadata()
        dataset = embedder.extract_embeddings(data_root, subject_path, batch_size, label)
        embedder.save_embeddings(dataset, out_dir)

    def plot(
        self,
        emb_dir: str = None,
        embeddings_root: str = "./embeddings",
        out_dir: str = "./plots/all_labels",
        perplexity: float = 10.0,
        users: str = None,
        num_users: int = 0,
        labels: str = None,
    ) -> None:
        """
        Visualize embeddings with t-SNE, colored by age.

        Can load from either a single label directory or all label directories.

        Args:
            emb_dir: Single directory containing the HuggingFace embeddings dataset
                (e.g., "./embeddings/REM"). If provided, only this directory is used.
            embeddings_root: Root directory containing label subdirectories
                (e.g., "./embeddings" containing W/, REM/, N1/, etc.).
                Used only if emb_dir is not provided. Every subdirectory that
                holds a saved dataset is loaded, so a combined export stored
                next to per-label exports would be counted twice.
            out_dir: Output directory for visualization plots (PDF)
            perplexity: t-SNE perplexity parameter (default: 10.0)
            users: Comma-separated user IDs to include (e.g., '00,01,02').
                NOTE(review): Fire turns an all-digit value such as ``00``
                into the int 0; quote such IDs (``--users '"00"'``) if the
                stored IDs keep leading zeros.
            num_users: Include only first N users, 0 = all (default: 0).
                Ignored when ``users`` is given.
            labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
                This filters the already-loaded data, not which directories to load.

        Raises:
            ValueError: If no samples remain after filtering.
        """
        os.makedirs(out_dir, exist_ok=True)

        # Load embeddings: either from single directory or all label directories
        if emb_dir is not None:
            print(f"[INFO] Loading embeddings from single directory: {emb_dir}")
            dataset = SBERT_Metadata.load_embeddings(emb_dir)
        else:
            print(f"[INFO] Loading embeddings from all label directories in: {embeddings_root}")
            dataset = load_embeddings_from_all_labels(embeddings_root)

        # Apply user filtering. Fire may hand over a tuple/int instead of a
        # string, so normalize first; sets give O(1) membership per row.
        user_list = self._as_str_list(users)
        if user_list:
            allowed_users = set(user_list)
            dataset = dataset.filter(lambda x: x["user_id"] in allowed_users)
            print(f"[INFO] Filtered to users: {user_list}")
        elif num_users > 0:
            all_users = sorted(set(dataset["user_id"]))
            selected_users = all_users[:num_users]
            selected_set = set(selected_users)
            dataset = dataset.filter(lambda x: x["user_id"] in selected_set)
            print(f"[INFO] Selected first {num_users} users: {selected_users}")

        # Filter by sleep stage labels
        label_list = self._as_str_list(labels)
        if label_list:
            allowed_labels = set(label_list)
            dataset = dataset.filter(lambda x: x["label"] in allowed_labels)
            print(f"[INFO] Filtered to labels: {label_list}")

        print(f"[INFO] Total samples: {len(dataset)}")

        # Fail with an actionable message instead of an opaque numpy/t-SNE
        # error when the filters removed every row.
        if len(dataset) == 0:
            raise ValueError(
                "No samples left to plot after filtering; "
                "check the --users/--num_users/--labels values."
            )

        # Extract embeddings as numpy array for t-SNE
        embeddings = np.array(dataset["embedding"])
        # Extract ages for coloring
        ages = np.array(dataset["age"])

        # Reduce to 2D with t-SNE
        coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)

        # Generate visualization colored by age
        create_scatter_plot_by_age(
            coordinates_2d,
            ages,
            "t-SNE Visualization (Colored by Age)",
            os.path.join(out_dir, "tsne_by_age.pdf")
        )
# Script entry point: hand the CLI class to Fire only when this file is
# executed directly, not when it is imported. Fire is imported outside this
# chunk (presumably python-fire) and exposes CLI's methods as sub-commands.
if __name__ == "__main__":
    Fire(CLI)