initial codebase for llm-based sensing

This commit is contained in:
Hyungjun Yoon
2025-10-23 16:07:29 +09:00
parent 41024838f0
commit 59b14b5d9f
11 changed files with 1413 additions and 0 deletions

127
analysis/analyze_data.ipynb Normal file

File diff suppressed because one or more lines are too long

9
config/sleepedf.yaml Normal file
View File

@@ -0,0 +1,9 @@
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
num_seeds: 10
models:
- ollama:url:rose.kaist.ac.kr:11437/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11438/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11439/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11440/gpt-oss:20b
- ollama:url:rose.kaist.ac.kr:11441/gpt-oss:20b
log_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/logs/SleepEDF"

208
core/agent.py Normal file
View File

@@ -0,0 +1,208 @@
import os
import re
import json
import tiktoken
from langchain_ollama import ChatOllama
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
class Agent:
def __init__(
self,
name,
model_pool,
log_path,
):
self.name = name
self.model_pool = model_pool
self.log_path = log_path
self.root_log_path = log_path
self.agent_log_path = os.path.join(log_path, name)
os.makedirs(self.agent_log_path, exist_ok=True)
self.long_term_memory = []
self.short_term_memory = []
self.volatile_memory = []
self.total_input_tokens = 0
self.total_output_tokens = 0
self.total_tokens = 0
self.total_calls = 0
def log(self, message, local=True):
path = os.path.join(self.root_log_path, "log.txt")
with open(path, "a", encoding="utf-8") as f:
message_type = "UNKNOWN"
if isinstance(message, SystemMessage):
message_type = "SYSTEM"
if isinstance(message, HumanMessage):
message_type = "HUMAN"
if isinstance(message, AIMessage):
message_type = "AI"
content = message.content.strip()
name = self.name
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
if local:
local_path = os.path.join(self.agent_log_path, "log.txt")
with open(local_path, "a", encoding="utf-8") as f:
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
def log_tokens(self, messages, response):
input_tokens = 0
for msg in messages:
msg_tokens = self.count_tokens(msg.content)
input_tokens += msg_tokens
self.total_tokens += msg_tokens
self.total_input_tokens += msg_tokens
output_tokens = self.count_tokens(response.content)
self.total_tokens += output_tokens
self.total_output_tokens += output_tokens
self.total_calls += 1
path = os.path.join(self.agent_log_path, "tokens.txt")
with open(path, "a", encoding="utf-8") as f:
f.write(f"Input tokens: {input_tokens}\n")
f.write(f"Output tokens: {output_tokens}\n")
f.write(f"Total input tokens: {self.total_input_tokens}\n")
f.write(f"Total output tokens: {self.total_output_tokens}\n")
f.write(f"Total tokens: {self.total_tokens}\n")
f.write(f"Total calls: {self.total_calls}\n")
f.write("\n")
def count_tokens(self, text, model="gpt-3.5-turbo"):
enc = tiktoken.encoding_for_model(model)
return len(enc.encode(text))
def update_memory(self):
self.long_term_memory.extend(self.short_term_memory)
self.clean_short_term_memory()
self.clean_volatile_memory()
def clean_short_term_memory(self):
self.short_term_memory = []
def clean_volatile_memory(self):
self.volatile_memory = []
def clean_long_term_memory(self):
self.long_term_memory = []
def clean_json_text(self, text):
text = text.strip()
text = text.replace("", "'").replace("", "'")
text = text.replace("", "'").replace("", "'")
text = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", text)
text = re.sub(r",\s*}", "}", text)
text = re.sub(r",\s*]", "]", text)
text = "".join(ch for ch in text if ch.isprintable())
text = text.replace("][", ",")
return text
def safe_parse_json(self, text):
text = text.strip()
match = re.search(r"\{.*\}", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
elif not text.endswith("}"):
text += "}"
match = re.search(r"\{.*\}", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
print("[!] JSON parse failed")
return None
def safe_parse_json_list(self, text):
text = text.strip()
match = re.search(r"\[.*\]", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
elif not text.endswith("]"):
text += "]"
match = re.search(r"\[.*\]", text, re.DOTALL)
if match:
text = match.group(0)
text = self.clean_json_text(text)
try:
return json.loads(text)
except json.JSONDecodeError as e:
print(f"[!] JSON parse failed: {e}")
return None
print("[!] JSON parse failed")
return None
async def validate_response(self, response, fields, volatile=False):
if (
not response
or not isinstance(response, dict)
or not all(field in response for field in fields)
):
print("[!] The JSON failed to be parsed. Trying again.")
content = (
"Failed to parse the JSON from the previous response. Please try again."
)
response = await self.invoke(content, volatile=volatile)
response = self.safe_parse_json(response)
if (
not response
or not isinstance(response, dict)
or not all(field in response for field in fields)
):
print("[!] Retry failed.")
return None
return response
def get_last_response(self):
if len(self.long_term_memory) >= 2:
last_msg = self.long_term_memory[-1]
if isinstance(last_msg, AIMessage):
return self.safe_parse_json(last_msg.content)
return None
def set_system_message(self, content, local=True):
system_message = SystemMessage(content=content)
self.log(system_message, local)
self.long_term_memory.append(system_message)
async def invoke(self, content, volatile=False, local=True):
messages = self.long_term_memory.copy()
if volatile:
messages.extend(self.volatile_memory)
else:
messages.extend(self.short_term_memory)
messages.append(HumanMessage(content=content))
try:
response = await self.model_pool.invoke(messages)
self.log_tokens(messages, response)
if volatile:
self.volatile_memory.extend([HumanMessage(content=content), response])
else:
self.short_term_memory.extend([HumanMessage(content=content), response])
local_ = not volatile and local
self.log(HumanMessage(content=content), local=local_)
self.log(response, local=local_)
return response.content.strip()
except Exception as e: # pylint: disable=broad-exception-caught
print(f"[Error] Error occurred while invoking LLM: {e}")

78
core/data_loader.py Normal file
View File

@@ -0,0 +1,78 @@
import os
import json
import datasets
import numpy as np
from glob import glob
class DataLoader:
def __init__(self, data_path, user_id, selection_criteria="out_random", num_examples=1):
self.is_valid = False
if not os.path.exists(os.path.join(data_path, "info.json")):
return
if not os.path.exists(os.path.join(data_path, f"{user_id}", "1")):
return
if not os.path.exists(os.path.join(data_path, f"{user_id}", "2")):
return
self.metadata = json.load(open(os.path.join(data_path, "info.json"), "r", encoding="utf-8"))
self.test_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user_id}", "2"))
self.example_dataset = datasets.Dataset.from_list([])
users = glob(os.path.join(data_path, "*"))
users = [path.split("/")[-1] for path in users]
if "info.json" in users:
users.remove("info.json")
for user in users:
user_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user}", "1"))
self.example_dataset = datasets.concatenate_datasets([self.example_dataset, user_dataset])
self.test_dataset = self.test_dataset.shuffle(seed=0)
self.example_dataset = self.example_dataset.shuffle(seed=0)
self.user_id = user_id
self.selection_criteria = selection_criteria
self.num_examples = num_examples
self.selected_examples = self.sample_examples()
self.is_valid = True
def __len__(self):
return len(self.test_dataset)
def __getitem__(self, idx):
return self.test_dataset[idx], self.selected_examples
def __iter__(self):
for sample in self.test_dataset:
yield sample, self.selected_examples
def sample_examples(self):
classes = self.test_dataset.unique("label")
example_dataset = datasets.Dataset.from_list([])
if self.selection_criteria == "out_random":
filtered_example_dataset = self.example_dataset.filter(lambda x, user_id=self.user_id: x["user_id"] != user_id)
for c in classes:
class_dataset = filtered_example_dataset.filter(lambda x, c_=c: x["label"] == c_)
sampled_examples = class_dataset.select(np.random.choice(len(class_dataset), self.num_examples, replace=False))
example_dataset = datasets.concatenate_datasets([example_dataset, sampled_examples])
elif self.selection_criteria == "in_random":
filtered_example_dataset = self.example_dataset.filter(lambda x, user_id=self.user_id: x["user_id"] == user_id)
for c in classes:
class_dataset = filtered_example_dataset.filter(lambda x, c_=c: x["label"] == c_)
sampled_examples = class_dataset.select(np.random.choice(len(class_dataset), self.num_examples, replace=False))
example_dataset = datasets.concatenate_datasets([example_dataset, sampled_examples])
return example_dataset
def get_metadata(self):
return self.metadata
def get_sensor_info(self):
return self.metadata["feature"]
def get_task_info(self):
task_info = f"**Task**:\n{self.metadata['task']}\n\n"
classes_info = [f" - {k}: {v}" for k, v in self.metadata["class"].items()]
classes_info = "\n".join(classes_info)
task_info += f"**Classes**:\n{classes_info}"
return task_info
def get_classes_info(self):
classes_info = [k for k in self.metadata["class"].keys()]
return classes_info

95
core/model.py Normal file
View File

@@ -0,0 +1,95 @@
import asyncio
from langchain_ollama import ChatOllama
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
def load_models(models):
model_pool = AsyncModelPool()
for model in models:
model_pool.add_model(Model(model))
model_pool.init_models()
return model_pool
class Model:
def __init__(self, model, temperature=0.7):
if model.startswith("ollama:"):
model = model.replace("ollama:", "")
if "url:" in model:
model = model.replace("url:", "")
base_url = model.split("/")[0]
if not base_url.startswith("http"):
base_url = "http://" + base_url
model_type = model.split("/")[1]
self.model = ChatOllama(
model=model_type,
base_url=base_url,
temperature=temperature,
num_ctx=12000,
)
else:
self.model = ChatOllama(
model=model.replace("ollama:", ""),
temperature=temperature,
num_ctx=12000,
)
else:
self.model = init_chat_model(
model=model,
temperature=temperature,
)
def invoke(self, messages):
try:
response = self.model.invoke(messages)
return response
except Exception as e:
print(f"[Error] Error occurred while invoking LLM: {e}")
return e
class AsyncModel:
def __init__(self, model):
self.model = model
async def invoke(self, content):
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(
None,
lambda: self.model.invoke(content),
)
return response
class AsyncModelPool:
def __init__(self):
self.models = []
self._available_models = None
self._model_semaphore = None
def add_model(self, model):
self.models.append(model)
def init_models(self):
# Initialize the queue and semaphore in the current event loop
self._available_models = asyncio.Queue()
for model in self.models:
async_model = AsyncModel(model)
self._available_models.put_nowait(async_model)
self._model_semaphore = asyncio.Semaphore(len(self.models))
# Test each model
for model in self.models:
model.invoke([HumanMessage(content="Hello world!")])
async def invoke(self, content):
if self._available_models is None:
raise RuntimeError("Model pool not initialized. Call init_models() first.")
async_model = await self._available_models.get()
try:
response = await async_model.invoke(content)
return response
finally:
self._available_models.put_nowait(async_model)

186
core/sensing_agent.py Normal file
View File

@@ -0,0 +1,186 @@
import json
import copy
import os
from .agent import Agent
class SensingAgent(Agent):
def __init__(
self,
name,
model_pool,
task_info,
classes_info,
sensor_info,
sample,
examples,
log_path,
):
super().__init__(
name=name,
model_pool=model_pool,
log_path=log_path,
)
self.task_info = task_info
self.classes_info = classes_info
self.sensor_info = sensor_info
self.sample = sample
self.examples = examples
self.init_system_message()
def init_system_message(self):
content = (
f"You are {self.name} agent that interprets sensor data to solve a task.\n"
"You have the following information about the task:\n"
f"{self.task_info}\n\n"
"You have the following information about the sensor data:\n"
f"{self.sensor_info}\n\n"
"Your goal is to analyze the features and "
"provide a reasoned answer using your knowledge."
)
self.set_system_message(content)
def gen_feature_info(self):
feature_info = f"{self.name} features:\n"
if len(self.examples) > 0:
feature_info += f"{self.gen_example_info()}\n\n"
feature_info += "**Current sample features**:\n"
for k, v in self.sample["features"].items():
# print(k, v)
# assert 0
# modalities = self.name.split("+")
# for modality in modalities:
# if modality in k:
feature_info += f" - {k}: {self.format_feature(v)}\n"
feature_info = feature_info.strip()
return feature_info
def gen_example_info(self):
example_info = (
"**Examples**\n"
"Sensor values might not always align with your inherent "
"knowledge due to differences in data collection or processing. "
"So, we included a few labeled examples to help your interpretation:\n"
)
for example in self.examples:
example_info += f"*Example of {example['label']}*:\n"
for k, v in example["features"].items():
# modalities = self.name.split("+")
# for modality in modalities:
# if modality in k:
example_info += f" - {k}: {self.format_feature(v)}\n"
example_info += "\n"
example_info = example_info.strip()
return example_info
def format_feature(self, value):
if isinstance(value, float):
if abs(value) >= 1e4 or abs(value) < 1e-2:
return f"{value:.2e}"
return f"{value:.2f}"
return value
async def solve(self, sample, examples, ground_truth):
self.sample = sample
self.examples = examples
feature_info = self.gen_feature_info()
content = (
f"You have received sensor features from {self.name} modality:\n"
f"{feature_info}\n\n"
f"Please provide your answer for the task among {self.classes_info} "
"and the reasoning for your answer.\n"
"Note that the sensor features might be wrong due to the data collection or processing.\n"
"You can evaluate the quality of the features by checking the examples you have.\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<Reasoning for the answer>",\n'
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(content)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["REASON", "ANSWER"]
)
self.clean_short_term_memory()
self.clean_long_term_memory()
answer = parsed_response["ANSWER"]
if answer == ground_truth:
print(f"Correct answer: {answer}")
else:
print(f"Incorrect answer: {answer} (Ground truth: {ground_truth})")
return parsed_response
async def interpret(self):
feature_info = self.gen_feature_info()
content = (
f"You have received sensor features from {self.name} modality:\n"
f"{feature_info}\n\n"
f"Please provide your answer for the task among {self.classes_info} "
"and the reasoning for your answer.\n"
"Note that the sensor features might be wrong due to the data collection or processing.\n"
"You can evaluate the quality of the features by checking the examples you have.\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<Reasoning for the answer>",\n'
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(content)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["REASON", "ANSWER"]
)
return parsed_response
async def evaluate(self, target_name, initial_response):
initial_response_info = json.dumps(initial_response, indent=2)
content = (
f"Other agent, <{target_name}> provided the following answer for the same task:\n"
f"{initial_response_info}\n\n"
"Please evaluate the given reasoning and answer based on your judgement. "
"You may either support with it or disagree.\n"
"If you agree, explain why the reasoning and answer are valid. "
"If you disagree, explain why the reasoning or answer may be flawed, "
f"and provide constructive feedback on how <{target_name}> can improve its response.\n"
"Respond in the following strict JSON format:\n"
"{\n"
f' "EVALUATION": "<Evaluation to <{target_name}>"\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(content, volatile=True)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["EVALUATION"], volatile=True
)
self.clean_volatile_memory()
return parsed_response
async def reflect(self, evaluations):
evaluations_info = json.dumps(evaluations, indent=2)
content = (
f"Other agents have evaluated your answer for the same task:\n"
f"{evaluations_info}\n\n"
"Please reflect on the evaluations and provide a refined answer for the same task.\n"
"Respond in the following strict JSON format:\n"
"{\n"
' "REASON": "<Reasoning for the answer>",\n'
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
"}\n\n"
"Do not include any additional text outside the JSON."
)
response = await self.invoke(content)
parsed_response = self.safe_parse_json(response)
parsed_response = await self.validate_response(
parsed_response, ["REASON", "ANSWER"]
)
return parsed_response

144
preprocess/dhedfreader.py Normal file
View File

@@ -0,0 +1,144 @@
"""
Credits to: https://github.com/emadeldeen24/TS-TCC/blob/main/data_preprocessing/sleep-edf/dhedfreader.py
"""
import re, datetime, operator, logging, sys
import numpy as np
from collections import namedtuple
EVENT_CHANNEL = "EDF Annotations"
log = logging.getLogger(__name__)
class EDFEndOfData(Exception):
pass
def tal(tal_str):
"""Return a list with (onset, duration, annotation) tuples for an EDF+ TAL
stream.
"""
exp = (
"(?P<onset>[+\-]\d+(?:\.\d*)?)"
+ "(?:\x15(?P<duration>\d+(?:\.\d*)?))?"
+ "(\x14(?P<annotation>[^\x00]*))?"
+ "(?:\x14\x00)"
)
def annotation_to_list(annotation):
return str(annotation.encode("utf-8")).split("\x14") if annotation else []
def parse(dic):
return (
float(dic["onset"]),
float(dic["duration"]) if dic["duration"] else 0.0,
annotation_to_list(dic["annotation"]),
)
return [parse(m.groupdict()) for m in re.finditer(exp, tal_str)]
def edf_header(f):
h = {}
assert f.tell() == 0 # check file position
assert f.read(8) == "0 "
# recording info)
h["local_subject_id"] = f.read(80).strip()
h["local_recording_id"] = f.read(80).strip()
# parse timestamp
(day, month, year) = [int(x) for x in re.findall("(\d+)", f.read(8))]
(hour, minute, sec) = [int(x) for x in re.findall("(\d+)", f.read(8))]
h["date_time"] = str(datetime.datetime(year + 2000, month, day, hour, minute, sec))
# misc
header_nbytes = int(f.read(8))
subtype = f.read(44)[:5]
h["EDF+"] = subtype in ["EDF+C", "EDF+D"]
h["contiguous"] = subtype != "EDF+D"
h["n_records"] = int(f.read(8))
h["record_length"] = float(f.read(8)) # in seconds
nchannels = h["n_channels"] = int(f.read(4))
# read channel info
channels = range(h["n_channels"])
h["label"] = [f.read(16).strip() for n in channels]
h["transducer_type"] = [f.read(80).strip() for n in channels]
h["units"] = [f.read(8).strip() for n in channels]
h["physical_min"] = np.asarray([float(f.read(8)) for n in channels])
h["physical_max"] = np.asarray([float(f.read(8)) for n in channels])
h["digital_min"] = np.asarray([float(f.read(8)) for n in channels])
h["digital_max"] = np.asarray([float(f.read(8)) for n in channels])
h["prefiltering"] = [f.read(80).strip() for n in channels]
h["n_samples_per_record"] = [int(f.read(8)) for n in channels]
f.read(32 * nchannels) # reserved
# assert f.tell() == header_nbytes
return h
class BaseEDFReader:
def __init__(self, file):
self.file = file
def read_header(self):
self.header = h = edf_header(self.file)
# calculate ranges for rescaling
self.dig_min = h["digital_min"]
self.phys_min = h["physical_min"]
phys_range = h["physical_max"] - h["physical_min"]
dig_range = h["digital_max"] - h["digital_min"]
# assert np.all(phys_range > 0)
# assert np.all(dig_range > 0)
self.gain = phys_range / dig_range
def read_raw_record(self):
"""Read a record with data_2013 and return a list containing arrays with raw
bytes.
"""
result = []
for nsamp in self.header["n_samples_per_record"]:
samples = self.file.read(nsamp * 2)
if len(samples) != nsamp * 2:
raise EDFEndOfData
result.append(samples)
return result
def convert_record(self, raw_record):
"""Convert a raw record to a (time, signals, events) tuple based on
information in the header.
"""
h = self.header
dig_min, phys_min, gain = self.dig_min, self.phys_min, self.gain
time = float("nan")
signals = []
events = []
for i, samples in enumerate(raw_record):
if h["label"][i] == EVENT_CHANNEL:
ann = tal(samples)
time = ann[0][0]
events.extend(ann[1:])
# print(i, samples)
# exit()
else:
# 2-byte little-endian integers
dig = np.fromstring(samples, "<i2").astype(np.float32)
phys = (dig - dig_min[i]) * gain[i] + phys_min[i]
signals.append(phys)
return time, signals, events
def read_record(self):
return self.convert_record(self.read_raw_record())
def records(self):
"""
Record generator.
"""
try:
while True:
yield self.read_record()
except EDFEndOfData:
pass

View File

@@ -0,0 +1,479 @@
import os
import math
import json
import warnings
import numpy as np
import neurokit2 as nk
from tqdm import tqdm
from fire import Fire
from glob import glob
from mne.io import read_raw_edf # pylint: disable=no-name-in-module
from scipy import signal
from datasets import Dataset, concatenate_datasets
from datetime import datetime
from multiprocessing import Pool
from dhedfreader import BaseEDFReader
from neurokit2.misc._warnings import NeuroKitWarning
warnings.simplefilter("ignore", NeuroKitWarning)
warnings.filterwarnings("ignore", message="All-NaN slice encountered")
SLEEPEDF_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/sleep-cassette/"
OUT_DIR = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
EPOCH_SEC_SIZE = 30
SAMPLING_RATE = 100
W = 0
N1 = 1
N2 = 2
N3 = 3
REM = 4
UNKNOWN = 5
class_dict = {0: "W", 1: "N1", 2: "N2", 3: "N3", 4: "REM", 5: "UNKNOWN"}
ann2label = {
"Sleep stage W": 0,
"Sleep stage 1": 1,
"Sleep stage 2": 2,
"Sleep stage 3": 3,
"Sleep stage 4": 3,
"Sleep stage R": 4,
"Sleep stage ?": 5,
"Movement time": 5,
}
def store_info(info_path):
info = {
"task": 'Classify the user\'s sleep stage: ["W", "N1", "N2", "N3", "REM"], based on physiological signals collected from wearable sensors.',
"class": {
"W": "Wakefulness. This includes periods before sleep onset or after final awakening, and short awakenings during the night.",
"N1": "Stage N1 (light sleep). This is the transition stage between wakefulness and deeper sleep.",
"N2": "Stage N2 (intermediate sleep). This stage typically follows N1 and occurs multiple times throughout the night.",
"N3": "Stage N3 (deep sleep). This stage is associated with the most restful sleep and occurs mostly in the first half of the night.",
"REM": "Rapid Eye Movement (REM) sleep. This stage is associated with dreaming and typically occurs cyclically later in the night.",
},
"data": (
"Data were recorded during overnight sleep using multiple sensors. \n"
"EEG signals were recorded from two electrode pairs, Fpz-Cz and Pz-Oz. "
"Both EEG signals (in micro volt) were sampled at 100 Hz.\n"
"All signals were time-synchronized, and features were computed using a fixed 30-second window.\n"
"Each feature is named using the format 'modality_featurename', where modality refers to the type of signal "
"(e.g., EEG-Fpz-Cz, EEG-Pz-Oz) and featurename describes the extracted characteristic (e.g., delta_power)."
),
"feature": (
"For EEG signals, a bandpass filter was applied to extract key frequency components: delta (0.5-4 Hz), "
"theta (4-8 Hz), alpha (8-12 Hz), beta (12-30 Hz), spindle (12-14 Hz), k-complex (0.5-1.5 Hz), and sawtooth (2-6 Hz). "
"For each band, time-domain features such as mean, standard deviation (std), variance, dynamic range, number of peaks, "
"number of zero-crossings, and variance of the first-order difference were extracted.\n"
"Additionally, power spectral density was estimated using Welch's method to compute absolute power for each band. "
"Ratio features such as delta/theta, theta/alpha, alpha/beta, and (delta+theta)/(alpha+beta) were also included."
),
}
with open(info_path, "w", encoding="utf-8") as f:
json.dump(info, f, indent=2)
def lowpass_filter(data, cutoff=50, fs=1000, order=4):
nyq = 0.5 * fs
normal_cutoff = cutoff / nyq
b, a = signal.butter(order, normal_cutoff, btype="low", analog=False)
return signal.filtfilt(b, a, data)
def count_large_eye_movements(
eog_cleaned,
sr,
amp_thresh=120,
time_thresh=1.5,
):
time_limit = int(time_thresh * sr)
peaks, _ = signal.find_peaks(eog_cleaned)
troughs, _ = signal.find_peaks(-eog_cleaned)
events = np.sort(np.concatenate((peaks, troughs)))
count = 0
for i in range(len(events) - 1):
t1, t2 = events[i], events[i + 1]
if abs(t2 - t1) <= time_limit:
amp = abs(eog_cleaned[t2] - eog_cleaned[t1])
if amp >= amp_thresh:
count += 1
return count
def remove_large_eye_movements(
eog_cleaned,
fs=100,
amp_thresh=120,
time_thresh=1.5,
pad=0.75,
):
time_limit = int(time_thresh * fs)
pad_samples = int(pad * fs)
peaks, _ = signal.find_peaks(eog_cleaned)
troughs, _ = signal.find_peaks(-eog_cleaned)
events = np.sort(np.concatenate((peaks, troughs)))
mask = np.ones_like(eog_cleaned, dtype=bool)
for i in range(len(events) - 1):
t1, t2 = events[i], events[i + 1]
if abs(t2 - t1) <= time_limit:
amp = abs(eog_cleaned[t2] - eog_cleaned[t1])
if amp >= amp_thresh:
start = max(0, min(t1, t2) - pad_samples)
end = min(len(eog_cleaned), max(t1, t2) + pad_samples)
mask[start:end] = False
cleaned_signal = eog_cleaned.copy()
cleaned_signal[~mask] = np.mean(eog_cleaned[mask])
return cleaned_signal
def bandpass_filter(data, fs, band, order=4):
nyq = 0.5 * fs
low = band[0] / nyq
high = band[1] / nyq
b, a = signal.butter(order, [low, high], btype="band")
return signal.filtfilt(b, a, data)
def process_by_mod(modality, data):
sr = SAMPLING_RATE
features = {}
modality = "-".join(modality.split())
if "EEG" in modality:
band_data = data
signal_modality = modality
mean_val = np.mean(band_data)
std_val = np.std(band_data)
var_val = np.var(band_data)
features[f"{signal_modality}_mean"] = mean_val
features[f"{signal_modality}_std"] = std_val
features[f"{signal_modality}_variance"] = var_val
dynamic_range = np.max(band_data) - np.min(band_data)
features[f"{signal_modality}_dynamic_range"] = dynamic_range
peaks = signal.find_peaks(band_data - mean_val, height=3 * std_val)[0]
features[f"{signal_modality}_num_peaks"] = len(peaks)
zero_crossings = np.where(np.diff(np.sign(band_data - mean_val)))[0]
features[f"{signal_modality}_num_zero_crossings"] = len(zero_crossings)
differences = band_data[1:] - band_data[:-1]
difference_variance = np.var(differences)
features[f"{signal_modality}_difference_variance"] = difference_variance
freqs, psd = signal.welch(data, fs=sr, nperseg=sr * 2)
delta_idx = np.logical_and(freqs >= 0.5, freqs <= 4)
theta_idx = np.logical_and(freqs >= 4, freqs <= 8)
alpha_idx = np.logical_and(freqs >= 8, freqs <= 12)
beta_idx = np.logical_and(freqs >= 12, freqs <= 30)
spindle_idx = np.logical_and(freqs >= 12, freqs <= 14)
kcomplex_idx = np.logical_and(freqs >= 0.5, freqs <= 1.5)
sawtooth_idx = np.logical_and(freqs >= 2, freqs <= 6)
delta_power = np.trapezoid(psd[delta_idx], freqs[delta_idx])
theta_power = np.trapezoid(psd[theta_idx], freqs[theta_idx])
alpha_power = np.trapezoid(psd[alpha_idx], freqs[alpha_idx])
beta_power = np.trapezoid(psd[beta_idx], freqs[beta_idx])
spindle_power = np.trapezoid(psd[spindle_idx], freqs[spindle_idx])
kcomplex_power = np.trapezoid(psd[kcomplex_idx], freqs[kcomplex_idx])
sawtooth_power = np.trapezoid(psd[sawtooth_idx], freqs[sawtooth_idx])
delta_theta_ratio = delta_power / theta_power if theta_power > 0 else 0
theta_alpha_ratio = theta_power / alpha_power if alpha_power > 0 else 0
alpha_beta_ratio = alpha_power / beta_power if beta_power > 0 else 0
slow_fast_ratio = (
(delta_power + theta_power) / (alpha_power + beta_power)
if (alpha_power + beta_power) > 0
else 0
)
features[f"{modality}_delta_power"] = delta_power
features[f"{modality}_theta_power"] = theta_power
features[f"{modality}_alpha_power"] = alpha_power
features[f"{modality}_beta_power"] = beta_power
features[f"{modality}_spindle_power"] = spindle_power
features[f"{modality}_kcomplex_power"] = kcomplex_power
features[f"{modality}_sawtooth_power"] = sawtooth_power
features[f"{modality}_delta/theta_ratio"] = delta_theta_ratio
features[f"{modality}_theta/alpha_ratio"] = theta_alpha_ratio
features[f"{modality}_alpha/beta_ratio"] = alpha_beta_ratio
features[f"{modality}_(delta+theta)/(alpha+beta)_ratio"] = slow_fast_ratio
# elif "EOG" in modality:
# eog_cleaned = data
# try:
# eog_cleaned = nk.eog_clean(data, sr)
# except IndexError as e:
# print(f"Error processing EOG data for {modality}: {e}")
# return None
# eog_mean = np.mean(eog_cleaned)
# eog_std = np.std(eog_cleaned)
# eog_var = np.var(eog_cleaned)
# features[f"{modality}_mean"] = eog_mean
# features[f"{modality}_std"] = eog_std
# features[f"{modality}_variance"] = eog_var
# dynamic_range = np.max(eog_cleaned) - np.min(eog_cleaned)
# features[f"{modality}_dynamic_range"] = dynamic_range
# peaks = signal.find_peaks(eog_cleaned - eog_mean, height=3 * eog_std)[0]
# features[f"{modality}_num_peaks"] = len(peaks)
# zero_crossings = np.where(np.diff(np.sign(eog_cleaned - eog_mean)))[0]
# features[f"{modality}_num_zero_crossings"] = len(zero_crossings)
# differences = eog_cleaned[1:] - eog_cleaned[:-1]
# difference_variance = np.var(differences)
# features[f"{modality}_difference_variance"] = difference_variance
# features[f"{modality}_num_large_eye_movements"] = count_large_eye_movements(
# eog_cleaned, sr, amp_thresh=120, time_thresh=1.5
# )
# eog_large_movement_removed = remove_large_eye_movements(
# eog_cleaned, fs=sr, amp_thresh=120, time_thresh=1.5, pad=0.75
# )
# differences = eog_large_movement_removed[1:] - eog_large_movement_removed[:-1]
# difference_variance = np.var(differences)
# features[f"{modality}_difference_variance_without_large_movements"] = (
# difference_variance
# )
# freqs, psd = signal.welch(eog_cleaned, fs=sr, nperseg=sr * 2)
# total_idx = np.logical_and(freqs >= 0.5, freqs <= 30)
# total_power = np.trapezoid(psd[total_idx], freqs[total_idx])
# slow_idx = np.logical_and(freqs >= 0.5, freqs <= 2)
# rapid_idx = np.logical_and(freqs >= 2, freqs <= 5)
# slow_power = np.trapezoid(psd[slow_idx], freqs[slow_idx])
# rapid_power = np.trapezoid(psd[rapid_idx], freqs[rapid_idx])
# slow_power_ratio = slow_power / total_power if total_power > 0 else 0
# rapid_power_ratio = rapid_power / total_power if total_power > 0 else 0
# features[f"{modality}_slow_movement_power_ratio"] = slow_power_ratio
# features[f"{modality}_rapid_movement_power_ratio"] = rapid_power_ratio
# elif "Resp" in modality:
# rsp_signals = data
# try:
# rsp_signals, _ = nk.rsp_process(data, sampling_rate=sr, method="biosppy")
# except IndexError as e:
# print(f"Error processing respiration data for {modality}: {e}")
# return None
# clean = rsp_signals["RSP_Clean"]
# phase = rsp_signals["RSP_Phase"]
# rate = rsp_signals["RSP_Rate"]
# amplitude = rsp_signals["RSP_Amplitude"]
# peaks = np.where(rsp_signals["RSP_Peaks"] == 1)[0]
# troughs = np.where(rsp_signals["RSP_Troughs"] == 1)[0]
# inhale_durations = []
# for t in troughs:
# next_peaks = peaks[peaks > t]
# if len(next_peaks) == 0:
# continue
# inhale_durations.append((next_peaks[0] - t) / sr)
# inhale_durations = np.array(inhale_durations)
# exhale_durations = []
# for p in peaks:
# next_troughs = troughs[troughs > p]
# if len(next_troughs) == 0:
# continue
# exhale_durations.append((next_troughs[0] - p) / sr)
# exhale_durations = np.array(exhale_durations)
# features[f"{modality}_inhale_duration_mean"] = np.mean(inhale_durations)
# features[f"{modality}_inhale_duration_std"] = np.std(inhale_durations)
# features[f"{modality}_exhale_duration_mean"] = np.mean(exhale_durations)
# features[f"{modality}_exhale_duration_std"] = np.std(exhale_durations)
# features[f"{modality}_inhale_exhale_ratio"] = (
# np.mean(inhale_durations) / np.mean(exhale_durations)
# if np.mean(exhale_durations) > 0
# else np.nan
# )
# features[f"{modality}_stretch"] = np.max(clean) - np.min(clean)
# inhale_mask = phase == 1
# features[f"{modality}_inspiration_volume"] = np.trapezoid(
# amplitude[inhale_mask], dx=1 / sr
# )
# features[f"{modality}_respiration_rate"] = np.mean(rate)
# resp_durations = np.diff(troughs) / sr
# features[f"{modality}_respiration_duration"] = np.mean(resp_durations)
# elif "EMG" in modality:
# emg_mean = np.mean(data)
# emg_std = np.std(data)
# features[f"{modality}_mean"] = emg_mean
# features[f"{modality}_std"] = emg_std
# features[f"{modality}_dynamic_range"] = np.max(data) - np.min(data)
# features[f"{modality}_absolute_integral"] = np.sum(np.abs(data)) / sr
# features[f"{modality}_median"] = np.median(data)
# features[f"{modality}_10th_percentile"] = np.percentile(data, 10)
# features[f"{modality}_90th_percentile"] = np.percentile(data, 90)
# peaks, _ = signal.find_peaks(data, height=3 * emg_std)
# peak_values = data[peaks]
# features[f"{modality}_num_peaks"] = len(peaks)
# features[f"{modality}_peak_amplitude_mean"] = (
# np.mean(peak_values) if len(peak_values) > 0 else 0
# )
# features[f"{modality}_peak_amplitude_std"] = (
# np.std(peak_values) if len(peak_values) > 0 else 0
# )
# features[f"{modality}_peak_amplitude_sum"] = (
# np.sum(peak_values) if len(peak_values) > 0 else 0
# )
# features[f"{modality}_peak_amplitude_norm_sum"] = (
# np.sum(peak_values) / np.sum(np.abs(data))
# if np.sum(np.abs(data)) > 0
# else 0
# )
return features
def preprocess(file_path):
psg_file_path = file_path[0]
ann_file_path = file_path[1]
raw = read_raw_edf(psg_file_path, preload=True, stim_channel=None)
sr = raw.info["sfreq"]
curr_user_id = psg_file_path.split("/")[-1].split("-")[0][3:5]
curr_session_id = psg_file_path.split("/")[-1].split("-")[0][5]
raw_ch_df = raw.to_data_frame()
raw_ch_df.set_index(np.arange(len(raw_ch_df)))
raw_ch_df.drop(columns=["time", "Temp rectal", "Event marker"], inplace=True)
column_list = raw_ch_df.columns.tolist()
column_list = [c.replace("_", "-") for c in column_list]
f = open(psg_file_path, "r", encoding="iso-8859-1")
reader_raw = BaseEDFReader(f)
reader_raw.read_header()
h_raw = reader_raw.header
f.close()
raw_start_dt = datetime.strptime(h_raw["date_time"], "%Y-%m-%d %H:%M:%S")
f = open(ann_file_path, "r", encoding="iso-8859-1")
reader_ann = BaseEDFReader(f)
reader_ann.read_header()
h_ann = reader_ann.header
_, _, ann = list(zip(*reader_ann.records()))
f.close()
ann_start_dt = datetime.strptime(h_ann["date_time"], "%Y-%m-%d %H:%M:%S")
assert raw_start_dt == ann_start_dt
remove_idx = []
labels = []
label_idx = []
for a in ann[0]:
onset_sec, duration_sec, ann_char = a
ann_str = "".join(ann_char)
ann_str = ann_str.strip("b'\"")
label = ann2label[ann_str]
if label != UNKNOWN:
if duration_sec % EPOCH_SEC_SIZE != 0:
assert False, "Duration should be multiple of 30 seconds"
duration_epoch = int(duration_sec / EPOCH_SEC_SIZE)
label_epoch = np.ones(duration_epoch, dtype=int) * label
labels.append(label_epoch)
idx = int(onset_sec * sr) + np.arange(duration_sec * sr, dtype=int)
label_idx.append(idx)
else:
idx = int(onset_sec * sr) + np.arange(duration_sec * sr, dtype=int)
remove_idx.append(idx)
labels = np.hstack(labels)
if len(remove_idx) > 0:
remove_idx = np.hstack(remove_idx)
select_idx = np.setdiff1d(np.arange(len(raw_ch_df)), remove_idx)
else:
select_idx = np.arange(len(raw_ch_df))
label_idx = np.hstack(label_idx)
select_idx = np.intersect1d(select_idx, label_idx)
if len(label_idx) > len(select_idx):
extra_idx = np.setdiff1d(label_idx, select_idx)
if np.all(extra_idx > select_idx[-1]):
n_trims = len(select_idx) % int(EPOCH_SEC_SIZE * sr)
n_label_trims = int(math.ceil(n_trims / (EPOCH_SEC_SIZE * sr)))
select_idx = select_idx[:-n_trims]
labels = labels[:-n_label_trims]
if len(labels) == 0:
print("No labels left after removing extra labels.")
return []
raw_ch = raw_ch_df.values[select_idx]
if len(raw_ch) % (EPOCH_SEC_SIZE * sr) != 0:
assert False, "Data length is not multiple of 30 seconds."
n_epochs = len(raw_ch) / (EPOCH_SEC_SIZE * sr)
x = np.asarray(np.split(raw_ch, n_epochs)).astype(np.float32)
y = labels.astype(np.int32)
assert len(x) == len(y)
w_edge_mins = 30
nw_idx = np.where(y != W)[0]
start_idx = nw_idx[0] - (w_edge_mins * 2)
end_idx = nw_idx[-1] + (w_edge_mins * 2)
if start_idx < 0:
start_idx = 0
if end_idx >= len(y):
end_idx = len(y) - 1
select_idx = np.arange(start_idx, end_idx + 1)
x = x[select_idx]
y = y[select_idx]
num_samples = x.shape[0]
num_modalities = x.shape[2]
data = []
pbar = tqdm(total=5 * num_samples, desc=f"Processing {curr_user_id}", leave=False)
idx = 0
indices = np.arange(num_samples)[::10]
for i in indices:
if y[i] == 5:
assert 0
else:
pbar.update(1)
w_features = {}
missing_w_features = {}
flag = False
for mod in range(num_modalities):
mod_features = process_by_mod(column_list[mod], x[i, :, mod])
if mod_features is None:
flag = True
break
for k, v in mod_features.items():
w_features[k] = v
if flag:
continue
data.append(
dict(
user_id=curr_user_id,
session_id=curr_session_id,
idx=idx,
label=class_dict[y[i]],
features=w_features,
)
)
idx += 1
pbar.close()
return data
def run(path=SLEEPEDF_PATH, out_dir=OUT_DIR, num_examples=1, num_workers=32, seed=0):
np.random.seed(seed)
if not os.path.exists(out_dir):
os.makedirs(out_dir)
info_path = os.path.join(out_dir, "info.json")
store_info(info_path)
print(f"Saved info to {info_path}")
psg_file_paths = glob(os.path.join(path, "*PSG.edf"))
ann_file_paths = glob(os.path.join(path, "*Hypnogram.edf"))
psg_file_paths.sort()
ann_file_paths.sort()
file_paths = list(zip(psg_file_paths, ann_file_paths))
filtered_2013_file_paths = []
for file_path in file_paths:
basename = os.path.basename(file_path[0])
if basename.startswith("SC40"):
filtered_2013_file_paths.append(file_path)
elif basename.startswith("SC41"):
filtered_2013_file_paths.append(file_path)
with Pool(processes=num_workers) as pool:
for data in pool.imap_unordered(preprocess, filtered_2013_file_paths):
if len(data) == 0:
continue
user_id = data[0]["user_id"]
session_id = data[0]["session_id"]
dataset = Dataset.from_list(data)
test_dir = os.path.join(out_dir, f"{user_id}", f"{session_id}")
dataset.save_to_disk(test_dir)
print(f"Saved dataset to {test_dir}")
if __name__ == "__main__":
Fire(run)

49
run.py Normal file
View File

@@ -0,0 +1,49 @@
import os
import re
import asyncio
import yaml
import json
import numpy as np
from glob import glob
from fire import Fire
from core.model import load_models
from core.data_loader import DataLoader
from core.sensing_agent import SensingAgent
async def run_parallel(config, model_pool):
tasks = []
# for seed in range(config["num_seeds"]):
# users = glob(os.path.join(config["data_path"], "*"))
# users = [path for path in users if os.path.isdir(path)]
# for user in users:
data_loader = DataLoader(config["data_path"], "00", selection_criteria="in_random", num_examples=1)
# if not data_loader.is_valid:
# continue
for sample, examples in data_loader:
agent = SensingAgent(
name="EEG sensing",
model_pool=model_pool,
task_info=data_loader.get_task_info(),
classes_info=data_loader.get_classes_info(),
sensor_info=data_loader.get_sensor_info(),
sample=sample,
examples=examples,
log_path=config["log_path"],
)
ground_truth = sample["label"]
task = asyncio.create_task(agent.solve(sample, examples, ground_truth))
tasks.append(task)
await asyncio.gather(*tasks)
def run(config_path):
config = yaml.load(open(config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
model_pool = load_models(config["models"])
asyncio.run(run_parallel(config, model_pool))
if __name__ == "__main__":
Fire(run)

16
utils/kill_ollamas.sh Executable file
View File

@@ -0,0 +1,16 @@
#!/bin/bash
echo "Killing all tmux sessions starting with 'ollama'..."
# Get all tmux sessions starting with "ollama"
SESSIONS=$(tmux list-sessions -F '#S' | grep '^ollama')
if [ -z "$SESSIONS" ]; then
echo "No 'ollama' sessions found."
else
for SESSION in $SESSIONS; do
echo "Killing session: $SESSION"
tmux kill-session -t "$SESSION"
done
echo "All 'ollama' sessions have been terminated."
fi

22
utils/launch_ollamas.sh Executable file
View File

@@ -0,0 +1,22 @@
#!/bin/bash
PORTS=(11437 11438 11439 11440 11441 11442 11443 11444)
for i in "${!PORTS[@]}"; do
PORT="${PORTS[$i]}"
SESSION="ollama$i"
# Run a login shell (-l) and then source .bashrc just in case, then start Ollama
CMD="bash -lc 'OLLAMA_MODELS=/mnt/sting/hjyoon/projects/llm/ollama OLLAMA_HOST=0.0.0.0:${PORT} ollama serve'"
if tmux has-session -t "$SESSION" 2>/dev/null; then
echo "Session $SESSION exists. Restarting ollama serve..."
tmux send-keys -t "$SESSION" C-c
sleep 1
tmux send-keys -t "$SESSION" "$CMD" C-m
else
echo "Creating session $SESSION and starting ollama serve on port $PORT..."
tmux new-session -d -s "$SESSION" "$CMD"
fi
done
echo "All Ollama servers have been started or restarted in tmux sessions."