implemented SBERT-based analysis
This commit is contained in:
10
.gitignore
vendored
10
.gitignore
vendored
@@ -1,8 +1,8 @@
|
|||||||
.pdf
|
*.pdf
|
||||||
.txt
|
*.txt
|
||||||
.csv
|
*.csv
|
||||||
.arrow
|
*.arrow
|
||||||
.json
|
*.json
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|||||||
@@ -416,9 +416,9 @@ class CLI:
|
|||||||
def extract(
|
def extract(
|
||||||
self,
|
self,
|
||||||
data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF",
|
data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF",
|
||||||
out_dir: str = "./embeddings/W",
|
out_dir: str = "./embeddings/REM",
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
label: str = "W"
|
label: str = "REM"
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Extract embeddings from time series data.
|
Extract embeddings from time series data.
|
||||||
|
|||||||
529
analysis/user_similarity/sbert/gen_plot_metadata.py
Normal file
529
analysis/user_similarity/sbert/gen_plot_metadata.py
Normal file
@@ -0,0 +1,529 @@
|
|||||||
|
"""
|
||||||
|
SBERT Embedding Extraction and Visualization with Metadata (Age and Sex)
|
||||||
|
|
||||||
|
This module provides functionality to:
|
||||||
|
1. Extract SBERT embeddings from time series features
|
||||||
|
2. Visualize embeddings colored by subject metadata (age and sex) instead of user IDs
|
||||||
|
|
||||||
|
Features:
|
||||||
|
1. Extract embeddings from time series features using SBERT
|
||||||
|
2. Load embeddings from HuggingFace dataset
|
||||||
|
3. Load subject metadata from XLS file
|
||||||
|
4. Map user IDs to age and sex information
|
||||||
|
5. Visualize embeddings with t-SNE colored by age (continuous) and sex (categorical)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Extract embeddings for all labels
|
||||||
|
python gen_plot_metadata.py extract --data_root /path/to/data --out_dir ./embeddings/all_labels
|
||||||
|
|
||||||
|
# Extract embeddings for a single label
|
||||||
|
python gen_plot_metadata.py extract --data_root /path/to/data --out_dir ./embeddings/W --label W
|
||||||
|
|
||||||
|
# Visualize by age from single label directory
|
||||||
|
python gen_plot_metadata.py plot --emb_dir ./embeddings/W --out_dir ./plots/W --color_by age
|
||||||
|
|
||||||
|
# Visualize by age from all label directories (W, REM, N1, N2, N3)
|
||||||
|
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by age
|
||||||
|
|
||||||
|
# Visualize by sex from all labels
|
||||||
|
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by sex
|
||||||
|
|
||||||
|
# Create both age and sex plots from all labels
|
||||||
|
python gen_plot_metadata.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels --color_by both
|
||||||
|
|
||||||
|
Author: Sumin Im (NMSL Undergraduate Researcher)
|
||||||
|
Date: 2026-01-09
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from glob import glob
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from datasets import load_from_disk, Dataset, concatenate_datasets
|
||||||
|
from fire import Fire
|
||||||
|
from sklearn.manifold import TSNE
|
||||||
|
|
||||||
|
from gen_plot import SBERT, reduce_to_2d_tsne
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Constants
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
SUBJECT_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/SC-subjects.xls"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Metadata Loading
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def load_subject_metadata(subject_path: str = SUBJECT_PATH) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Load subject metadata from XLS file.
|
||||||
|
|
||||||
|
The XLS file contains subject information including:
|
||||||
|
- subject: Subject ID (e.g., "SC4001", "SC4002")
|
||||||
|
- age: Age of the subject
|
||||||
|
- sex (F=1): Sex (1 = Female, 0 = Male)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subject_path: Path to the SC-subjects.xls file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping subject IDs to metadata dictionaries
|
||||||
|
Format: {"SC4001": {"age": 25, "sex": 1}, ...}
|
||||||
|
"""
|
||||||
|
df = pd.read_excel(subject_path)
|
||||||
|
subject_info = {}
|
||||||
|
|
||||||
|
for index, row in df.iterrows():
|
||||||
|
subject_id = str(row["subject"]).strip()
|
||||||
|
subject_info[subject_id] = {
|
||||||
|
"age": int(row["age"]) if pd.notna(row["age"]) else None,
|
||||||
|
"sex": int(row["sex (F=1)"]) if pd.notna(row["sex (F=1)"]) else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
return subject_info
|
||||||
|
|
||||||
|
|
||||||
|
def map_user_ids_to_metadata(
|
||||||
|
user_ids: np.ndarray,
|
||||||
|
subject_metadata: Dict[str, Dict[str, Any]]
|
||||||
|
) -> tuple:
|
||||||
|
"""
|
||||||
|
Map user IDs from dataset to age and sex metadata.
|
||||||
|
|
||||||
|
User IDs in the dataset are typically 2-digit codes (e.g., "40", "41").
|
||||||
|
Subject IDs in the metadata file are typically 4-character codes (e.g., "SC40", "SC41").
|
||||||
|
We need to match them appropriately.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_ids: Array of user IDs from the dataset
|
||||||
|
subject_metadata: Dictionary of subject metadata loaded from XLS file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (ages, sexes) as numpy arrays
|
||||||
|
Missing values are set to None
|
||||||
|
"""
|
||||||
|
ages = []
|
||||||
|
sexes = []
|
||||||
|
|
||||||
|
for user_id in user_ids:
|
||||||
|
age = subject_metadata[user_id]["age"]
|
||||||
|
sex = subject_metadata[user_id]["sex"]
|
||||||
|
ages.append(age)
|
||||||
|
sexes.append(sex)
|
||||||
|
|
||||||
|
return np.array(ages), np.array(sexes)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Visualization Functions
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def create_scatter_plot_by_age(
|
||||||
|
coordinates: np.ndarray,
|
||||||
|
ages: np.ndarray,
|
||||||
|
title: str,
|
||||||
|
output_path: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Create and save a 2D scatter plot colored by age (continuous colormap).
|
||||||
|
|
||||||
|
Uses a continuous colormap (viridis) to show age distribution.
|
||||||
|
Ages are mapped to colors on a gradient scale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
coordinates: 2D array of shape (num_points, 2)
|
||||||
|
ages: Age values for each point (array, may contain None values)
|
||||||
|
title: Plot title
|
||||||
|
output_path: File path to save the plot (PDF vector format recommended)
|
||||||
|
"""
|
||||||
|
# Filter out points with missing age data
|
||||||
|
# Convert None values to NaN for proper numpy handling
|
||||||
|
ages_float = np.array([float(a) if a is not None else np.nan for a in ages])
|
||||||
|
valid_mask = ~np.isnan(ages_float)
|
||||||
|
valid_coords = coordinates[valid_mask]
|
||||||
|
valid_ages = ages_float[valid_mask]
|
||||||
|
|
||||||
|
if len(valid_ages) == 0:
|
||||||
|
print(f"[WARN] No valid age data found. Skipping plot: {output_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create figure
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 8))
|
||||||
|
|
||||||
|
# Create scatter plot with continuous colormap
|
||||||
|
scatter = ax.scatter(
|
||||||
|
valid_coords[:, 0],
|
||||||
|
valid_coords[:, 1],
|
||||||
|
c=valid_ages,
|
||||||
|
cmap='viridis', # Continuous colormap for granular age visualization
|
||||||
|
s=15,
|
||||||
|
alpha=0.7,
|
||||||
|
edgecolors='none',
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add colorbar
|
||||||
|
cbar = plt.colorbar(scatter, ax=ax)
|
||||||
|
cbar.set_label('Age (years)', rotation=270, labelpad=20)
|
||||||
|
|
||||||
|
# Configure plot appearance
|
||||||
|
ax.set_title(title)
|
||||||
|
ax.set_xlabel("Dimension 1")
|
||||||
|
ax.set_ylabel("Dimension 2")
|
||||||
|
|
||||||
|
# Save figure as vector PDF
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(output_path, format='pdf', bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
print(f"[DONE] Saved plot: {output_path}")
|
||||||
|
print(f"[INFO] Age range: {valid_ages.min():.0f} - {valid_ages.max():.0f} years")
|
||||||
|
print(f"[INFO] Points with valid age: {len(valid_ages)}/{len(ages)}")
|
||||||
|
|
||||||
|
|
||||||
|
def create_scatter_plot_by_sex(
|
||||||
|
coordinates: np.ndarray,
|
||||||
|
sexes: np.ndarray,
|
||||||
|
title: str,
|
||||||
|
output_path: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Create and save a 2D scatter plot colored by sex (categorical).
|
||||||
|
|
||||||
|
Uses discrete colors for different sex categories.
|
||||||
|
Sex encoding: 1 = Female, 0 = Male
|
||||||
|
|
||||||
|
Args:
|
||||||
|
coordinates: 2D array of shape (num_points, 2)
|
||||||
|
sexes: Sex values for each point (array, may contain None values)
|
||||||
|
title: Plot title
|
||||||
|
output_path: File path to save the plot (PDF vector format recommended)
|
||||||
|
"""
|
||||||
|
# Filter out points with missing sex data
|
||||||
|
# Convert None values to NaN for proper numpy handling
|
||||||
|
sexes_float = np.array([float(s) if s is not None else np.nan for s in sexes])
|
||||||
|
valid_mask = ~np.isnan(sexes_float)
|
||||||
|
valid_coords = coordinates[valid_mask]
|
||||||
|
valid_sexes = sexes_float[valid_mask].astype(int)
|
||||||
|
|
||||||
|
if len(valid_sexes) == 0:
|
||||||
|
print(f"[WARN] No valid sex data found. Skipping plot: {output_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Map sex codes to labels
|
||||||
|
sex_labels = {0: "Male", 1: "Female"}
|
||||||
|
unique_sexes = sorted(set(valid_sexes))
|
||||||
|
|
||||||
|
# Create figure
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 8))
|
||||||
|
|
||||||
|
# Plot each category separately for proper legend
|
||||||
|
colors = ['steelblue', 'coral'] # Blue for Male, Coral for Female
|
||||||
|
for idx, sex_code in enumerate(unique_sexes):
|
||||||
|
mask = valid_sexes == sex_code
|
||||||
|
ax.scatter(
|
||||||
|
valid_coords[mask, 0],
|
||||||
|
valid_coords[mask, 1],
|
||||||
|
c=colors[sex_code % len(colors)],
|
||||||
|
s=15,
|
||||||
|
label=sex_labels.get(sex_code, f"Sex {sex_code}"),
|
||||||
|
alpha=0.7,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure plot appearance
|
||||||
|
ax.set_title(title)
|
||||||
|
ax.set_xlabel("Dimension 1")
|
||||||
|
ax.set_ylabel("Dimension 2")
|
||||||
|
ax.legend(loc='best', markerscale=2)
|
||||||
|
|
||||||
|
# Save figure as vector PDF
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(output_path, format='pdf', bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
print(f"[DONE] Saved plot: {output_path}")
|
||||||
|
sex_counts = {sex_labels.get(s, f"Sex {s}"): (valid_sexes == s).sum() for s in unique_sexes}
|
||||||
|
print(f"[INFO] Sex distribution: {sex_counts}")
|
||||||
|
print(f"[INFO] Points with valid sex: {len(valid_sexes)}/{len(sexes)}")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Embedding Extraction Utilities
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def extract_embeddings_for_all_labels(
|
||||||
|
data_root: str,
|
||||||
|
out_dir: str,
|
||||||
|
batch_size: int = 32
|
||||||
|
) -> Dataset:
|
||||||
|
"""
|
||||||
|
Extract embeddings for all sleep stage labels and combine them.
|
||||||
|
|
||||||
|
Extracts embeddings for each label (W, REM, N1, N2, N3) separately,
|
||||||
|
then concatenates them into a single dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_root: Root directory containing user/session data folders
|
||||||
|
out_dir: Output directory for the combined HuggingFace dataset
|
||||||
|
batch_size: Batch size for inference (default: 32)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined HuggingFace Dataset with all labels
|
||||||
|
"""
|
||||||
|
embedder = SBERT()
|
||||||
|
all_labels = ["W", "REM", "N1", "N2", "N3"]
|
||||||
|
|
||||||
|
datasets = []
|
||||||
|
for label in all_labels:
|
||||||
|
print(f"\n[INFO] Extracting embeddings for label: {label}")
|
||||||
|
dataset = embedder.extract_embeddings(data_root, batch_size, label)
|
||||||
|
print(f"[INFO] Extracted {len(dataset)} samples for label {label}")
|
||||||
|
datasets.append(dataset)
|
||||||
|
|
||||||
|
# Concatenate all datasets
|
||||||
|
if len(datasets) == 1:
|
||||||
|
combined_dataset = datasets[0]
|
||||||
|
else:
|
||||||
|
combined_dataset = concatenate_datasets(datasets)
|
||||||
|
|
||||||
|
print(f"\n[INFO] Combined dataset: {len(combined_dataset)} total samples")
|
||||||
|
|
||||||
|
# Print label distribution
|
||||||
|
if "label" in combined_dataset.column_names:
|
||||||
|
label_counts = {}
|
||||||
|
for label in combined_dataset["label"]:
|
||||||
|
label_counts[label] = label_counts.get(label, 0) + 1
|
||||||
|
print(f"[INFO] Label distribution: {label_counts}")
|
||||||
|
|
||||||
|
# Save combined dataset
|
||||||
|
embedder.save_embeddings(combined_dataset, out_dir)
|
||||||
|
|
||||||
|
return combined_dataset
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Data Loading Utilities
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def load_embeddings_from_all_labels(embeddings_root: str) -> Dataset:
|
||||||
|
"""
|
||||||
|
Load embeddings from all label subdirectories and concatenate them.
|
||||||
|
|
||||||
|
Discovers all subdirectories in embeddings_root (e.g., W, REM, N1, N2, N3)
|
||||||
|
and loads embeddings from each, then concatenates them into a single dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings_root: Root directory containing label subdirectories
|
||||||
|
(e.g., "./embeddings" containing W/, REM/, N1/, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Concatenated HuggingFace Dataset with all labels combined
|
||||||
|
"""
|
||||||
|
# Discover all label subdirectories
|
||||||
|
label_dirs = []
|
||||||
|
for item in os.listdir(embeddings_root):
|
||||||
|
item_path = os.path.join(embeddings_root, item)
|
||||||
|
if os.path.isdir(item_path):
|
||||||
|
# Check if it's a valid HuggingFace dataset directory
|
||||||
|
if os.path.exists(os.path.join(item_path, "dataset_info.json")):
|
||||||
|
label_dirs.append((item, item_path))
|
||||||
|
|
||||||
|
if len(label_dirs) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No valid HuggingFace dataset directories found in: {embeddings_root}"
|
||||||
|
)
|
||||||
|
|
||||||
|
label_dirs.sort() # Sort for consistent ordering
|
||||||
|
print(f"[INFO] Discovered {len(label_dirs)} label directories: {[ld[0] for ld in label_dirs]}")
|
||||||
|
|
||||||
|
# Load datasets from each label directory
|
||||||
|
datasets = []
|
||||||
|
for label_name, label_path in label_dirs:
|
||||||
|
print(f"[INFO] Loading embeddings from: {label_path}")
|
||||||
|
dataset = load_from_disk(label_path)
|
||||||
|
print(f"[INFO] Label: {label_name}, Samples: {len(dataset)}")
|
||||||
|
datasets.append(dataset)
|
||||||
|
|
||||||
|
# Concatenate all datasets
|
||||||
|
if len(datasets) == 1:
|
||||||
|
combined_dataset = datasets[0]
|
||||||
|
else:
|
||||||
|
combined_dataset = concatenate_datasets(datasets)
|
||||||
|
|
||||||
|
print(f"[INFO] Combined dataset: {len(combined_dataset)} total samples")
|
||||||
|
|
||||||
|
# Print label distribution
|
||||||
|
if "label" in combined_dataset.column_names:
|
||||||
|
label_counts = {}
|
||||||
|
for label in combined_dataset["label"]:
|
||||||
|
label_counts[label] = label_counts.get(label, 0) + 1
|
||||||
|
print(f"[INFO] Label distribution: {label_counts}")
|
||||||
|
|
||||||
|
return combined_dataset
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Command Line Interface
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class CLI:
|
||||||
|
"""
|
||||||
|
Command-line interface for SBERT embedding extraction and visualization with metadata.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- extract: Generate embeddings from time series features
|
||||||
|
- plot: Visualize embeddings colored by age or sex instead of user ID
|
||||||
|
"""
|
||||||
|
|
||||||
|
def extract(
|
||||||
|
self,
|
||||||
|
data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_full",
|
||||||
|
out_dir: str = "./embeddings/all_labels",
|
||||||
|
batch_size: int = 32,
|
||||||
|
label: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Extract embeddings from time series features using SBERT.
|
||||||
|
|
||||||
|
Can extract for a single label or all labels at once.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_root: Root directory containing user/session data folders
|
||||||
|
out_dir: Output directory for HuggingFace dataset
|
||||||
|
batch_size: Batch size for inference (default: 32)
|
||||||
|
label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM').
|
||||||
|
If None or 'all', extracts for all labels (default: None for all labels)
|
||||||
|
"""
|
||||||
|
embedder = SBERT()
|
||||||
|
|
||||||
|
if label is None or label == "all":
|
||||||
|
# Extract for all labels
|
||||||
|
print(f"[INFO] Extracting embeddings for all labels")
|
||||||
|
extract_embeddings_for_all_labels(data_root, out_dir, batch_size)
|
||||||
|
else:
|
||||||
|
# Extract for single label
|
||||||
|
print(f"[INFO] Extracting embeddings for label: {label}")
|
||||||
|
dataset = embedder.extract_embeddings(data_root, batch_size, label)
|
||||||
|
embedder.save_embeddings(dataset, out_dir)
|
||||||
|
|
||||||
|
def plot(
|
||||||
|
self,
|
||||||
|
emb_dir: str = None,
|
||||||
|
embeddings_root: str = "./embeddings",
|
||||||
|
out_dir: str = "./plots/all_labels",
|
||||||
|
subject_path: str = SUBJECT_PATH,
|
||||||
|
perplexity: float = 10.0,
|
||||||
|
color_by: str = "age",
|
||||||
|
users: str = None,
|
||||||
|
num_users: int = 0,
|
||||||
|
labels: str = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Visualize embeddings with t-SNE, colored by age or sex.
|
||||||
|
|
||||||
|
Can load from either a single label directory or all label directories.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
emb_dir: Single directory containing the HuggingFace embeddings dataset
|
||||||
|
(e.g., "./embeddings/W"). If provided, only this directory is used.
|
||||||
|
embeddings_root: Root directory containing label subdirectories
|
||||||
|
(e.g., "./embeddings" containing W/, REM/, N1/, etc.).
|
||||||
|
Used only if emb_dir is not provided.
|
||||||
|
out_dir: Output directory for visualization plots (PDF)
|
||||||
|
subject_path: Path to SC-subjects.xls file with metadata
|
||||||
|
perplexity: t-SNE perplexity parameter (default: 10.0)
|
||||||
|
color_by: What to color by - 'age', 'sex', or 'both' (default: 'age')
|
||||||
|
users: Comma-separated user IDs to include (e.g., '00,01,02')
|
||||||
|
num_users: Include only first N users, 0 = all (default: 0)
|
||||||
|
labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
|
||||||
|
This filters the already-loaded data, not which directories to load.
|
||||||
|
"""
|
||||||
|
# Validate color_by argument
|
||||||
|
if color_by not in ["age", "sex", "both", "all"]:
|
||||||
|
raise ValueError(f"Invalid color_by: {color_by}. Use 'age', 'sex', or 'both'.")
|
||||||
|
|
||||||
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Load embeddings: either from single directory or all label directories
|
||||||
|
if emb_dir is not None:
|
||||||
|
# Load from single directory
|
||||||
|
print(f"[INFO] Loading embeddings from single directory: {emb_dir}")
|
||||||
|
dataset = SBERT.load_embeddings(emb_dir)
|
||||||
|
else:
|
||||||
|
# Load from all label directories
|
||||||
|
print(f"[INFO] Loading embeddings from all label directories in: {embeddings_root}")
|
||||||
|
dataset = load_embeddings_from_all_labels(embeddings_root)
|
||||||
|
|
||||||
|
# Apply user filtering
|
||||||
|
if users:
|
||||||
|
user_list = [u.strip() for u in users.split(",")]
|
||||||
|
dataset = dataset.filter(lambda x: x["user_id"] in user_list)
|
||||||
|
print(f"[INFO] Filtered to users: {user_list}")
|
||||||
|
|
||||||
|
elif num_users > 0:
|
||||||
|
all_users = sorted(set(dataset["user_id"]))
|
||||||
|
selected_users = all_users[:num_users]
|
||||||
|
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
|
||||||
|
print(f"[INFO] Selected first {num_users} users: {selected_users}")
|
||||||
|
|
||||||
|
# Filter by sleep stage labels
|
||||||
|
if labels:
|
||||||
|
label_list = [l.strip() for l in labels.split(",")]
|
||||||
|
dataset = dataset.filter(lambda x: x["label"] in label_list)
|
||||||
|
print(f"[INFO] Filtered to labels: {label_list}")
|
||||||
|
|
||||||
|
print(f"[INFO] Total samples: {len(dataset)}")
|
||||||
|
|
||||||
|
# Extract embeddings as numpy array for t-SNE
|
||||||
|
embeddings = np.array(dataset["embedding"])
|
||||||
|
|
||||||
|
# Load subject metadata
|
||||||
|
print(f"[INFO] Loading subject metadata from: {subject_path}")
|
||||||
|
subject_metadata = load_subject_metadata(subject_path)
|
||||||
|
print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")
|
||||||
|
|
||||||
|
# Map user IDs to metadata
|
||||||
|
user_ids = np.array([str(uid) for uid in dataset["user_id"]])
|
||||||
|
user_ids_ = [str(int(uid)) for uid in user_ids]
|
||||||
|
ages, sexes = map_user_ids_to_metadata(user_ids_, subject_metadata)
|
||||||
|
|
||||||
|
# Reduce to 2D with t-SNE
|
||||||
|
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
|
||||||
|
|
||||||
|
# Generate visualizations based on color_by parameter
|
||||||
|
if color_by == "age":
|
||||||
|
create_scatter_plot_by_age(
|
||||||
|
coordinates_2d,
|
||||||
|
ages,
|
||||||
|
"t-SNE Visualization (Colored by Age)",
|
||||||
|
os.path.join(out_dir, "tsne_by_age.pdf")
|
||||||
|
)
|
||||||
|
elif color_by == "sex":
|
||||||
|
create_scatter_plot_by_sex(
|
||||||
|
coordinates_2d,
|
||||||
|
sexes,
|
||||||
|
"t-SNE Visualization (Colored by Sex)",
|
||||||
|
os.path.join(out_dir, "tsne_by_sex.pdf")
|
||||||
|
)
|
||||||
|
elif color_by == "both" or color_by == "all":
|
||||||
|
# Create both plots
|
||||||
|
create_scatter_plot_by_age(
|
||||||
|
coordinates_2d,
|
||||||
|
ages,
|
||||||
|
"t-SNE Visualization (Colored by Age)",
|
||||||
|
os.path.join(out_dir, "tsne_by_age.pdf")
|
||||||
|
)
|
||||||
|
create_scatter_plot_by_sex(
|
||||||
|
coordinates_2d,
|
||||||
|
sexes,
|
||||||
|
"t-SNE Visualization (Colored by Sex)",
|
||||||
|
os.path.join(out_dir, "tsne_by_sex.pdf")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Fire(CLI)
|
||||||
655
analysis/user_similarity/sbert_metadata/gen_plot.py
Normal file
655
analysis/user_similarity/sbert_metadata/gen_plot.py
Normal file
@@ -0,0 +1,655 @@
|
|||||||
|
"""
|
||||||
|
SBERT Metadata-Based Embedding Extraction and Visualization Pipeline
|
||||||
|
|
||||||
|
This module provides functionality to
|
||||||
|
1. Extract embeddings from user metadata (age and sex) using SBERT (Sentence-BERT)
|
||||||
|
2. Visualize embeddings using dimensionality reduction (t-SNE) colored by age
|
||||||
|
|
||||||
|
SBERT Overview:
|
||||||
|
SBERT (Sentence-BERT) is a modification of the BERT model that uses siamese
|
||||||
|
and triplet network structures to derive semantically meaningful sentence
|
||||||
|
embeddings. It's designed for semantic similarity tasks and produces
|
||||||
|
fixed-size dense vector representations of text.
|
||||||
|
|
||||||
|
Key Features:
|
||||||
|
- Semantic understanding: Captures meaning rather than just word presence
|
||||||
|
- Fixed-size embeddings: Outputs consistent vector dimensions (384 for all-MiniLM-L6-v2)
|
||||||
|
- Efficient: Optimized for sentence-level tasks
|
||||||
|
|
||||||
|
Embedding Strategy:
|
||||||
|
Instead of using time series features, we create textual descriptions based on
|
||||||
|
user metadata (age and sex). This approach allows us to capture user-level
|
||||||
|
characteristics in the embedding space.
|
||||||
|
|
||||||
|
Processing Pipeline:
|
||||||
|
1. Load user metadata from XLS file (age, sex)
|
||||||
|
2. Textualization: Convert metadata to natural language description
|
||||||
|
3. SBERT encoding: Generate 384-dimensional embeddings
|
||||||
|
4. Visualization: t-SNE with continuous age coloring
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Extract embeddings from metadata for all labels
|
||||||
|
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/all_labels
|
||||||
|
|
||||||
|
# Extract embeddings for a single label
|
||||||
|
python gen_plot.py extract --data_root /path/to/data --subject_path /path/to/subjects.xls --out_dir ./embeddings/REM --label REM
|
||||||
|
|
||||||
|
# Visualize with t-SNE from all label directories (colored by age)
|
||||||
|
python gen_plot.py plot --embeddings_root ./embeddings --out_dir ./plots/all_labels
|
||||||
|
|
||||||
|
# Visualize from a single label directory
|
||||||
|
python gen_plot.py plot --emb_dir ./embeddings/REM --out_dir ./plots/REM
|
||||||
|
|
||||||
|
Author: Sumin Im (NMSL Undergraduate Researcher)
|
||||||
|
Date: 2026-01-09
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from glob import glob
|
||||||
|
from typing import Dict, Any, List, Tuple, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from datasets import load_from_disk, Dataset, concatenate_datasets
|
||||||
|
from fire import Fire
|
||||||
|
from sklearn.manifold import TSNE
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Constants
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
SUBJECT_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/SC-subjects.xls"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Metadata Loading
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def load_subject_metadata(subject_path: str = SUBJECT_PATH) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Load subject metadata from XLS file.
|
||||||
|
|
||||||
|
The XLS file contains subject information including:
|
||||||
|
- subject: Subject ID (e.g., "SC4001", "SC4002")
|
||||||
|
- age: Age of the subject
|
||||||
|
- sex (F=1): Sex (1 = Female, 0 = Male)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subject_path: Path to the SC-subjects.xls file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping subject IDs to metadata dictionaries
|
||||||
|
Format: {"SC4001": {"age": 25, "sex": 1}, ...}
|
||||||
|
"""
|
||||||
|
df = pd.read_excel(subject_path)
|
||||||
|
subject_info = {}
|
||||||
|
|
||||||
|
for index, row in df.iterrows():
|
||||||
|
subject_id = str(row["subject"]).strip()
|
||||||
|
subject_info[subject_id] = {
|
||||||
|
"age": int(row["age"]) if pd.notna(row["age"]) else None,
|
||||||
|
"sex": int(row["sex (F=1)"]) if pd.notna(row["sex (F=1)"]) else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
return subject_info
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Embedding Extractor Class
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class SBERT_Metadata:
|
||||||
|
"""
|
||||||
|
Extracts fixed-dimensional embeddings from user metadata using SBERT.
|
||||||
|
|
||||||
|
Uses Sentence-BERT to convert textualized metadata descriptions into dense
|
||||||
|
vector representations. The textualization process converts age and sex
|
||||||
|
information into natural language, which SBERT then encodes semantically.
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
User Metadata → Textualization → SBERT Encoder → Embedding
|
||||||
|
|
||||||
|
Processing Pipeline:
|
||||||
|
1. Load user metadata (age, sex) from XLS file
|
||||||
|
2. Textualize: Convert metadata to natural language description
|
||||||
|
3. SBERT encoding: Generate 384-dimensional semantic embeddings
|
||||||
|
4. Output: Fixed-size embedding vector per user
|
||||||
|
|
||||||
|
Model: all-MiniLM-L6-v2
|
||||||
|
- Lightweight BERT variant optimized for sentence embeddings
|
||||||
|
- 384-dimensional output embeddings
|
||||||
|
- Fast inference with good semantic understanding
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model: SentenceTransformer instance for encoding textualized metadata
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
Initialize the SBERT embedder with pre-trained model.
|
||||||
|
|
||||||
|
Uses "all-MiniLM-L6-v2" which is a lightweight, fast model
|
||||||
|
optimized for sentence similarity tasks.
|
||||||
|
"""
|
||||||
|
self.model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def discover_session_paths(data_root: str) -> List[Tuple[str, str, str]]:
|
||||||
|
"""
|
||||||
|
Discover all user/session directories under data_root.
|
||||||
|
|
||||||
|
Uses glob pattern matching for cleaner directory traversal.
|
||||||
|
Expected structure: data_root/user_id/session_id/
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_root: Root directory containing user/session subfolders
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (user_id, session_id, session_path) tuples
|
||||||
|
"""
|
||||||
|
discovered_paths = []
|
||||||
|
|
||||||
|
# Use glob to find all session directories (2 levels deep)
|
||||||
|
for session_path in sorted(glob(os.path.join(data_root, "*", "*"))):
|
||||||
|
if not os.path.isdir(session_path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract user_id and session_id from path
|
||||||
|
session_id = os.path.basename(session_path)
|
||||||
|
user_id = os.path.basename(os.path.dirname(session_path))
|
||||||
|
|
||||||
|
discovered_paths.append((user_id, session_id, session_path))
|
||||||
|
|
||||||
|
return discovered_paths
|
||||||
|
|
||||||
|
def textualize_metadata(self, age: Optional[int], sex: Optional[int]) -> str:
|
||||||
|
"""
|
||||||
|
Convert user metadata (age and sex) into a natural language description.
|
||||||
|
|
||||||
|
This textualization step is crucial for SBERT, which expects text input.
|
||||||
|
The description provides user demographic information in a structured format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
age: Age of the user (integer, may be None)
|
||||||
|
sex: Sex of the user (0 = Male, 1 = Female, may be None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Natural language string describing the user metadata
|
||||||
|
"""
|
||||||
|
# Map sex code to text
|
||||||
|
if sex is not None:
|
||||||
|
sex_text = "Female" if sex == 1 else "Male"
|
||||||
|
else:
|
||||||
|
sex_text = "Unknown"
|
||||||
|
|
||||||
|
# Create sentence from metadata
|
||||||
|
if age is not None:
|
||||||
|
sentence = f"This is the information of the user, age: {age}, sex: {sex_text}."
|
||||||
|
else:
|
||||||
|
sentence = f"This is the information of the user, age: unknown, sex: {sex_text}."
|
||||||
|
|
||||||
|
return sentence
|
||||||
|
|
||||||
|
def compute_embedding_from_metadata(
|
||||||
|
self,
|
||||||
|
ages: List[Optional[int]],
|
||||||
|
sexes: List[Optional[int]]
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Generate embedding vectors from metadata using SBERT.
|
||||||
|
|
||||||
|
Processing Pipeline:
|
||||||
|
1. Textualize each user's metadata into natural language
|
||||||
|
2. Encode textual descriptions using SBERT
|
||||||
|
3. Return fixed-size embedding vectors
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ages: List of age values (may contain None)
|
||||||
|
sexes: List of sex values (0 = Male, 1 = Female, may contain None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embedding array of shape (batch_size, embedding_dim)
|
||||||
|
For all-MiniLM-L6-v2: (batch_size, 384)
|
||||||
|
"""
|
||||||
|
# Convert metadata to text sentences
|
||||||
|
text_samples = []
|
||||||
|
for age, sex in zip(ages, sexes):
|
||||||
|
text_samples.append(self.textualize_metadata(age, sex))
|
||||||
|
|
||||||
|
# Encode text descriptions using SBERT
|
||||||
|
# Returns numpy array of shape (batch_size, 384)
|
||||||
|
embeddings = self.model.encode(text_samples)
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def extract_embeddings(
|
||||||
|
self,
|
||||||
|
data_root: str,
|
||||||
|
subject_path: str = SUBJECT_PATH,
|
||||||
|
batch_size: int = 32,
|
||||||
|
label: Optional[str] = None
|
||||||
|
) -> Dataset:
|
||||||
|
"""
|
||||||
|
Extract embeddings from user metadata for all sessions.
|
||||||
|
|
||||||
|
Iterates through all user/session combinations, loads metadata for each user,
|
||||||
|
generates embeddings from metadata sentences, and aggregates results.
|
||||||
|
Can process a single label or all labels.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_root: Root directory containing user/session data folders
|
||||||
|
subject_path: Path to SC-subjects.xls file with metadata
|
||||||
|
batch_size: Number of samples to process together (for batching embeddings)
|
||||||
|
Larger = faster but more memory.
|
||||||
|
32 is a good balance for most systems.
|
||||||
|
label: Sleep stage label to filter (e.g., "W", "N1", "N2", "N3", "REM").
|
||||||
|
If None or "all", processes all labels.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HuggingFace Dataset with columns:
|
||||||
|
- user_id, session_id, idx, label (metadata from original data)
|
||||||
|
- embedding (384-dim vector from all-MiniLM-L6-v2 based on metadata)
|
||||||
|
"""
|
||||||
|
# Load subject metadata
|
||||||
|
print(f"[INFO] Loading subject metadata from: {subject_path}")
|
||||||
|
subject_metadata = load_subject_metadata(subject_path)
|
||||||
|
print(f"[INFO] Loaded metadata for {len(subject_metadata)} subjects")
|
||||||
|
|
||||||
|
session_paths = self.discover_session_paths(data_root)
|
||||||
|
print(f"[INFO] Discovered {len(session_paths)} sessions")
|
||||||
|
|
||||||
|
all_embeddings = []
|
||||||
|
all_user_ids = []
|
||||||
|
all_session_ids = []
|
||||||
|
all_idxs = []
|
||||||
|
all_labels = []
|
||||||
|
all_ages = []
|
||||||
|
all_sexes = []
|
||||||
|
|
||||||
|
# Collect metadata for all samples first
|
||||||
|
for user_id, session_id, session_path in session_paths:
|
||||||
|
# Load HuggingFace dataset from disk
|
||||||
|
dataset = load_from_disk(session_path)
|
||||||
|
# Shuffle dataset for randomness
|
||||||
|
dataset = dataset.shuffle(seed=0)
|
||||||
|
# Filter by sleep stage label if specified
|
||||||
|
if label is not None and label != "all":
|
||||||
|
dataset = dataset.filter(lambda x: x["label"] == label)
|
||||||
|
num_samples = len(dataset)
|
||||||
|
|
||||||
|
if num_samples == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"[INFO] Processing user={user_id}, session={session_id}, samples={num_samples}")
|
||||||
|
|
||||||
|
# Get metadata for this user
|
||||||
|
# Convert user_id to string format that matches metadata keys
|
||||||
|
user_id_str = str(int(user_id))
|
||||||
|
try:
|
||||||
|
age = subject_metadata[user_id_str]["age"]
|
||||||
|
sex = subject_metadata[user_id_str]["sex"]
|
||||||
|
except KeyError:
|
||||||
|
age = None
|
||||||
|
sex = None
|
||||||
|
print(f"[WARN] No metadata found for user_id: {user_id_str}")
|
||||||
|
|
||||||
|
# Collect all samples for this user/session
|
||||||
|
for i in range(num_samples):
|
||||||
|
all_user_ids.append(str(dataset["user_id"][i]))
|
||||||
|
all_session_ids.append(str(dataset["session_id"][i]))
|
||||||
|
all_idxs.append(int(dataset["idx"][i]))
|
||||||
|
all_labels.append(str(dataset["label"][i]))
|
||||||
|
all_ages.append(age)
|
||||||
|
all_sexes.append(sex)
|
||||||
|
|
||||||
|
# Generate embeddings from metadata in batches
|
||||||
|
print(f"[INFO] Generating embeddings from metadata for {len(all_ages)} samples...")
|
||||||
|
for batch_start in range(0, len(all_ages), batch_size):
|
||||||
|
batch_end = min(batch_start + batch_size, len(all_ages))
|
||||||
|
|
||||||
|
batch_ages = all_ages[batch_start:batch_end]
|
||||||
|
batch_sexes = all_sexes[batch_start:batch_end]
|
||||||
|
|
||||||
|
# Compute embeddings from metadata
|
||||||
|
embeddings = self.compute_embedding_from_metadata(batch_ages, batch_sexes)
|
||||||
|
|
||||||
|
# Collect embeddings
|
||||||
|
for i in range(embeddings.shape[0]):
|
||||||
|
all_embeddings.append(embeddings[i].tolist())
|
||||||
|
|
||||||
|
# Create HuggingFace Dataset
|
||||||
|
result_dataset = Dataset.from_dict({
|
||||||
|
"user_id": all_user_ids,
|
||||||
|
"session_id": all_session_ids,
|
||||||
|
"idx": all_idxs,
|
||||||
|
"label": all_labels,
|
||||||
|
"age": all_ages,
|
||||||
|
"sex": all_sexes,
|
||||||
|
"embedding": all_embeddings,
|
||||||
|
})
|
||||||
|
|
||||||
|
print(f"[INFO] Total samples: {len(result_dataset)}")
|
||||||
|
|
||||||
|
# Print label distribution
|
||||||
|
if "label" in result_dataset.column_names:
|
||||||
|
label_counts = {}
|
||||||
|
for lbl in result_dataset["label"]:
|
||||||
|
label_counts[lbl] = label_counts.get(lbl, 0) + 1
|
||||||
|
print(f"[INFO] Label distribution: {label_counts}")
|
||||||
|
|
||||||
|
return result_dataset
|
||||||
|
|
||||||
|
def save_embeddings(
|
||||||
|
self,
|
||||||
|
dataset: Dataset,
|
||||||
|
output_dir: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Save embeddings dataset to disk in HuggingFace format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: HuggingFace Dataset containing embeddings and metadata
|
||||||
|
output_dir: Directory path to save the dataset
|
||||||
|
"""
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
dataset.save_to_disk(output_dir)
|
||||||
|
|
||||||
|
print(f"[DONE] Saved embeddings dataset: {output_dir}")
|
||||||
|
print(f"[DONE] Total samples: {len(dataset)}, Embedding dim: {len(dataset[0]['embedding'])}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_embeddings(embedding_dir: str) -> Dataset:
|
||||||
|
"""
|
||||||
|
Load saved embeddings dataset from disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_dir: Directory path containing the saved HuggingFace dataset
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HuggingFace Dataset with embeddings and metadata
|
||||||
|
"""
|
||||||
|
dataset = load_from_disk(embedding_dir)
|
||||||
|
print(f"[INFO] Loaded {len(dataset)} samples from {embedding_dir}")
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Data Loading Utilities
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def load_embeddings_from_all_labels(embeddings_root: str) -> Dataset:
|
||||||
|
"""
|
||||||
|
Load embeddings from all label subdirectories and concatenate them.
|
||||||
|
|
||||||
|
Discovers all subdirectories in embeddings_root (e.g., W, REM, N1, N2, N3)
|
||||||
|
and loads embeddings from each, then concatenates them into a single dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings_root: Root directory containing label subdirectories
|
||||||
|
(e.g., "./embeddings" containing W/, REM/, N1/, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Concatenated HuggingFace Dataset with all labels combined
|
||||||
|
"""
|
||||||
|
# Discover all label subdirectories
|
||||||
|
label_dirs = []
|
||||||
|
for item in os.listdir(embeddings_root):
|
||||||
|
item_path = os.path.join(embeddings_root, item)
|
||||||
|
if os.path.isdir(item_path):
|
||||||
|
# Check if it's a valid HuggingFace dataset directory
|
||||||
|
if os.path.exists(os.path.join(item_path, "dataset_info.json")):
|
||||||
|
label_dirs.append((item, item_path))
|
||||||
|
|
||||||
|
if len(label_dirs) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No valid HuggingFace dataset directories found in: {embeddings_root}"
|
||||||
|
)
|
||||||
|
|
||||||
|
label_dirs.sort() # Sort for consistent ordering
|
||||||
|
print(f"[INFO] Discovered {len(label_dirs)} label directories: {[ld[0] for ld in label_dirs]}")
|
||||||
|
|
||||||
|
# Load datasets from each label directory
|
||||||
|
datasets = []
|
||||||
|
for label_name, label_path in label_dirs:
|
||||||
|
print(f"[INFO] Loading embeddings from: {label_path}")
|
||||||
|
dataset = load_from_disk(label_path)
|
||||||
|
print(f"[INFO] Label: {label_name}, Samples: {len(dataset)}")
|
||||||
|
datasets.append(dataset)
|
||||||
|
|
||||||
|
# Concatenate all datasets
|
||||||
|
if len(datasets) == 1:
|
||||||
|
combined_dataset = datasets[0]
|
||||||
|
else:
|
||||||
|
combined_dataset = concatenate_datasets(datasets)
|
||||||
|
|
||||||
|
print(f"[INFO] Combined dataset: {len(combined_dataset)} total samples")
|
||||||
|
|
||||||
|
# Print label distribution
|
||||||
|
if "label" in combined_dataset.column_names:
|
||||||
|
label_counts = {}
|
||||||
|
for label in combined_dataset["label"]:
|
||||||
|
label_counts[label] = label_counts.get(label, 0) + 1
|
||||||
|
print(f"[INFO] Label distribution: {label_counts}")
|
||||||
|
|
||||||
|
return combined_dataset
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Visualization Functions
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def reduce_to_2d_tsne(
|
||||||
|
embeddings: np.ndarray,
|
||||||
|
perplexity: float = 30.0
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Reduce high-dimensional embeddings to 2D using t-SNE.
|
||||||
|
|
||||||
|
t-SNE (t-distributed Stochastic Neighbor Embedding) is a nonlinear
|
||||||
|
dimensionality reduction technique that preserves local structure.
|
||||||
|
Points that are similar in high dimensions stay close in 2D.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings: High-dimensional array of shape (num_samples, embedding_dim)
|
||||||
|
perplexity: t-SNE perplexity parameter (typically 5-50).
|
||||||
|
Higher values consider more neighbors, creating smoother layouts.
|
||||||
|
Rule of thumb: perplexity ~ sqrt(num_samples)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
2D coordinates of shape (num_samples, 2)
|
||||||
|
"""
|
||||||
|
print(f"[INFO] Running t-SNE with perplexity={perplexity}...")
|
||||||
|
|
||||||
|
tsne = TSNE(
|
||||||
|
n_components=2,
|
||||||
|
random_state=0, # For reproducibility
|
||||||
|
perplexity=perplexity,
|
||||||
|
max_iter=1000, # Usually sufficient for convergence
|
||||||
|
init='random',
|
||||||
|
learning_rate='auto', # Let sklearn choose optimal learning rate
|
||||||
|
)
|
||||||
|
|
||||||
|
return tsne.fit_transform(embeddings)
|
||||||
|
|
||||||
|
|
||||||
|
def create_scatter_plot_by_age(
|
||||||
|
coordinates: np.ndarray,
|
||||||
|
ages: np.ndarray,
|
||||||
|
title: str,
|
||||||
|
output_path: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Create and save a 2D scatter plot colored by age (continuous colormap).
|
||||||
|
|
||||||
|
Uses a continuous colormap (viridis) to show age distribution.
|
||||||
|
Ages are mapped to colors on a gradient scale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
coordinates: 2D array of shape (num_points, 2)
|
||||||
|
ages: Age values for each point (array, may contain None values)
|
||||||
|
title: Plot title
|
||||||
|
output_path: File path to save the plot (PDF vector format recommended)
|
||||||
|
"""
|
||||||
|
# Filter out points with missing age data
|
||||||
|
# Convert None values to NaN for proper numpy handling
|
||||||
|
ages_float = np.array([float(a) if a is not None else np.nan for a in ages])
|
||||||
|
valid_mask = ~np.isnan(ages_float)
|
||||||
|
valid_coords = coordinates[valid_mask]
|
||||||
|
valid_ages = ages_float[valid_mask]
|
||||||
|
|
||||||
|
if len(valid_ages) == 0:
|
||||||
|
print(f"[WARN] No valid age data found. Skipping plot: {output_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create figure
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 8))
|
||||||
|
|
||||||
|
# Create scatter plot with continuous colormap
|
||||||
|
scatter = ax.scatter(
|
||||||
|
valid_coords[:, 0],
|
||||||
|
valid_coords[:, 1],
|
||||||
|
c=valid_ages,
|
||||||
|
cmap='viridis', # Continuous colormap for granular age visualization
|
||||||
|
s=15,
|
||||||
|
alpha=0.7,
|
||||||
|
edgecolors='none',
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add colorbar
|
||||||
|
cbar = plt.colorbar(scatter, ax=ax)
|
||||||
|
cbar.set_label('Age (years)', rotation=270, labelpad=20)
|
||||||
|
|
||||||
|
# Configure plot appearance
|
||||||
|
ax.set_title(title)
|
||||||
|
ax.set_xlabel("Dimension 1")
|
||||||
|
ax.set_ylabel("Dimension 2")
|
||||||
|
|
||||||
|
# Save figure as vector PDF
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(output_path, format='pdf', bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
print(f"[DONE] Saved plot: {output_path}")
|
||||||
|
print(f"[INFO] Age range: {valid_ages.min():.0f} - {valid_ages.max():.0f} years")
|
||||||
|
print(f"[INFO] Points with valid age: {len(valid_ages)}/{len(ages)}")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Command Line Interface
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class CLI:
|
||||||
|
"""
|
||||||
|
Command-line interface for SBERT metadata-based embedding extraction and visualization.
|
||||||
|
|
||||||
|
Provides two main commands:
|
||||||
|
- extract: Generate embeddings from user metadata (age, sex)
|
||||||
|
- plot: Visualize embeddings with t-SNE colored by age
|
||||||
|
"""
|
||||||
|
|
||||||
|
def extract(
|
||||||
|
self,
|
||||||
|
data_root: str = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF_full",
|
||||||
|
subject_path: str = SUBJECT_PATH,
|
||||||
|
out_dir: str = "./embeddings/all_labels",
|
||||||
|
batch_size: int = 32,
|
||||||
|
label: str = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Extract embeddings from user metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_root: Root directory containing user/session data folders
|
||||||
|
subject_path: Path to SC-subjects.xls file with metadata
|
||||||
|
out_dir: Output directory for HuggingFace dataset
|
||||||
|
batch_size: Batch size for inference (default: 32)
|
||||||
|
label: Sleep stage label to filter (e.g., 'W', 'N1', 'N2', 'N3', 'REM').
|
||||||
|
If None or 'all', processes all labels (default: None for all labels)
|
||||||
|
"""
|
||||||
|
embedder = SBERT_Metadata()
|
||||||
|
dataset = embedder.extract_embeddings(data_root, subject_path, batch_size, label)
|
||||||
|
embedder.save_embeddings(dataset, out_dir)
|
||||||
|
|
||||||
|
def plot(
|
||||||
|
self,
|
||||||
|
emb_dir: str = None,
|
||||||
|
embeddings_root: str = "./embeddings",
|
||||||
|
out_dir: str = "./plots/all_labels",
|
||||||
|
perplexity: float = 10.0,
|
||||||
|
users: str = None,
|
||||||
|
num_users: int = 0,
|
||||||
|
labels: str = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Visualize embeddings with t-SNE, colored by age.
|
||||||
|
|
||||||
|
Can load from either a single label directory or all label directories.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
emb_dir: Single directory containing the HuggingFace embeddings dataset
|
||||||
|
(e.g., "./embeddings/REM"). If provided, only this directory is used.
|
||||||
|
embeddings_root: Root directory containing label subdirectories
|
||||||
|
(e.g., "./embeddings" containing W/, REM/, N1/, etc.).
|
||||||
|
Used only if emb_dir is not provided.
|
||||||
|
out_dir: Output directory for visualization plots (PDF)
|
||||||
|
perplexity: t-SNE perplexity parameter (default: 10.0)
|
||||||
|
users: Comma-separated user IDs to include (e.g., '00,01,02')
|
||||||
|
num_users: Include only first N users, 0 = all (default: 0)
|
||||||
|
labels: Comma-separated sleep stage labels to include (e.g., 'W,N1,N2')
|
||||||
|
This filters the already-loaded data, not which directories to load.
|
||||||
|
"""
|
||||||
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Load embeddings: either from single directory or all label directories
|
||||||
|
if emb_dir is not None:
|
||||||
|
# Load from single directory
|
||||||
|
print(f"[INFO] Loading embeddings from single directory: {emb_dir}")
|
||||||
|
dataset = SBERT_Metadata.load_embeddings(emb_dir)
|
||||||
|
else:
|
||||||
|
# Load from all label directories
|
||||||
|
print(f"[INFO] Loading embeddings from all label directories in: {embeddings_root}")
|
||||||
|
dataset = load_embeddings_from_all_labels(embeddings_root)
|
||||||
|
|
||||||
|
# Apply user filtering
|
||||||
|
if users:
|
||||||
|
user_list = [u.strip() for u in users.split(",")]
|
||||||
|
dataset = dataset.filter(lambda x: x["user_id"] in user_list)
|
||||||
|
print(f"[INFO] Filtered to users: {user_list}")
|
||||||
|
|
||||||
|
elif num_users > 0:
|
||||||
|
all_users = sorted(set(dataset["user_id"]))
|
||||||
|
selected_users = all_users[:num_users]
|
||||||
|
dataset = dataset.filter(lambda x: x["user_id"] in selected_users)
|
||||||
|
print(f"[INFO] Selected first {num_users} users: {selected_users}")
|
||||||
|
|
||||||
|
# Filter by sleep stage labels
|
||||||
|
if labels:
|
||||||
|
label_list = [l.strip() for l in labels.split(",")]
|
||||||
|
dataset = dataset.filter(lambda x: x["label"] in label_list)
|
||||||
|
print(f"[INFO] Filtered to labels: {label_list}")
|
||||||
|
|
||||||
|
print(f"[INFO] Total samples: {len(dataset)}")
|
||||||
|
|
||||||
|
# Extract embeddings as numpy array for t-SNE
|
||||||
|
embeddings = np.array(dataset["embedding"])
|
||||||
|
|
||||||
|
# Extract ages for coloring
|
||||||
|
ages = np.array(dataset["age"])
|
||||||
|
|
||||||
|
# Reduce to 2D with t-SNE
|
||||||
|
coordinates_2d = reduce_to_2d_tsne(embeddings, perplexity)
|
||||||
|
|
||||||
|
# Generate visualization colored by age
|
||||||
|
create_scatter_plot_by_age(
|
||||||
|
coordinates_2d,
|
||||||
|
ages,
|
||||||
|
"t-SNE Visualization (Colored by Age)",
|
||||||
|
os.path.join(out_dir, "tsne_by_age.pdf")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Fire(CLI)
|
||||||
Reference in New Issue
Block a user