initial codebase for llm-based sensing
This commit is contained in:
127
analysis/analyze_data.ipynb
Normal file
127
analysis/analyze_data.ipynb
Normal file
File diff suppressed because one or more lines are too long
9
config/sleepedf.yaml
Normal file
9
config/sleepedf.yaml
Normal file
@@ -0,0 +1,9 @@
|
||||
data_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
||||
num_seeds: 10
|
||||
models:
|
||||
- ollama:url:rose.kaist.ac.kr:11437/gpt-oss:20b
|
||||
- ollama:url:rose.kaist.ac.kr:11438/gpt-oss:20b
|
||||
- ollama:url:rose.kaist.ac.kr:11439/gpt-oss:20b
|
||||
- ollama:url:rose.kaist.ac.kr:11440/gpt-oss:20b
|
||||
- ollama:url:rose.kaist.ac.kr:11441/gpt-oss:20b
|
||||
log_path: "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/logs/SleepEDF"
|
||||
208
core/agent.py
Normal file
208
core/agent.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import tiktoken
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
model_pool,
|
||||
log_path,
|
||||
):
|
||||
self.name = name
|
||||
self.model_pool = model_pool
|
||||
self.log_path = log_path
|
||||
self.root_log_path = log_path
|
||||
self.agent_log_path = os.path.join(log_path, name)
|
||||
os.makedirs(self.agent_log_path, exist_ok=True)
|
||||
|
||||
self.long_term_memory = []
|
||||
self.short_term_memory = []
|
||||
self.volatile_memory = []
|
||||
|
||||
self.total_input_tokens = 0
|
||||
self.total_output_tokens = 0
|
||||
self.total_tokens = 0
|
||||
self.total_calls = 0
|
||||
|
||||
def log(self, message, local=True):
|
||||
path = os.path.join(self.root_log_path, "log.txt")
|
||||
with open(path, "a", encoding="utf-8") as f:
|
||||
message_type = "UNKNOWN"
|
||||
if isinstance(message, SystemMessage):
|
||||
message_type = "SYSTEM"
|
||||
if isinstance(message, HumanMessage):
|
||||
message_type = "HUMAN"
|
||||
if isinstance(message, AIMessage):
|
||||
message_type = "AI"
|
||||
content = message.content.strip()
|
||||
name = self.name
|
||||
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
|
||||
if local:
|
||||
local_path = os.path.join(self.agent_log_path, "log.txt")
|
||||
with open(local_path, "a", encoding="utf-8") as f:
|
||||
f.write(f"[{name}] [{message_type}]\n{content}\n\n\n")
|
||||
|
||||
def log_tokens(self, messages, response):
|
||||
input_tokens = 0
|
||||
for msg in messages:
|
||||
msg_tokens = self.count_tokens(msg.content)
|
||||
input_tokens += msg_tokens
|
||||
self.total_tokens += msg_tokens
|
||||
self.total_input_tokens += msg_tokens
|
||||
output_tokens = self.count_tokens(response.content)
|
||||
self.total_tokens += output_tokens
|
||||
self.total_output_tokens += output_tokens
|
||||
self.total_calls += 1
|
||||
|
||||
path = os.path.join(self.agent_log_path, "tokens.txt")
|
||||
with open(path, "a", encoding="utf-8") as f:
|
||||
f.write(f"Input tokens: {input_tokens}\n")
|
||||
f.write(f"Output tokens: {output_tokens}\n")
|
||||
f.write(f"Total input tokens: {self.total_input_tokens}\n")
|
||||
f.write(f"Total output tokens: {self.total_output_tokens}\n")
|
||||
f.write(f"Total tokens: {self.total_tokens}\n")
|
||||
f.write(f"Total calls: {self.total_calls}\n")
|
||||
f.write("\n")
|
||||
|
||||
def count_tokens(self, text, model="gpt-3.5-turbo"):
|
||||
enc = tiktoken.encoding_for_model(model)
|
||||
return len(enc.encode(text))
|
||||
|
||||
def update_memory(self):
|
||||
self.long_term_memory.extend(self.short_term_memory)
|
||||
self.clean_short_term_memory()
|
||||
self.clean_volatile_memory()
|
||||
|
||||
def clean_short_term_memory(self):
|
||||
self.short_term_memory = []
|
||||
|
||||
def clean_volatile_memory(self):
|
||||
self.volatile_memory = []
|
||||
|
||||
def clean_long_term_memory(self):
|
||||
self.long_term_memory = []
|
||||
|
||||
def clean_json_text(self, text):
|
||||
text = text.strip()
|
||||
text = text.replace("‘", "'").replace("’", "'")
|
||||
text = text.replace("“", "'").replace("”", "'")
|
||||
text = re.sub(r'\\(?!["\\/bfnrtu])', r"\\\\", text)
|
||||
text = re.sub(r",\s*}", "}", text)
|
||||
text = re.sub(r",\s*]", "]", text)
|
||||
text = "".join(ch for ch in text if ch.isprintable())
|
||||
text = text.replace("][", ",")
|
||||
return text
|
||||
|
||||
def safe_parse_json(self, text):
|
||||
text = text.strip()
|
||||
match = re.search(r"\{.*\}", text, re.DOTALL)
|
||||
if match:
|
||||
text = match.group(0)
|
||||
text = self.clean_json_text(text)
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"[!] JSON parse failed: {e}")
|
||||
return None
|
||||
elif not text.endswith("}"):
|
||||
text += "}"
|
||||
match = re.search(r"\{.*\}", text, re.DOTALL)
|
||||
if match:
|
||||
text = match.group(0)
|
||||
text = self.clean_json_text(text)
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"[!] JSON parse failed: {e}")
|
||||
return None
|
||||
print("[!] JSON parse failed")
|
||||
return None
|
||||
|
||||
def safe_parse_json_list(self, text):
|
||||
text = text.strip()
|
||||
match = re.search(r"\[.*\]", text, re.DOTALL)
|
||||
if match:
|
||||
text = match.group(0)
|
||||
text = self.clean_json_text(text)
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"[!] JSON parse failed: {e}")
|
||||
return None
|
||||
elif not text.endswith("]"):
|
||||
text += "]"
|
||||
match = re.search(r"\[.*\]", text, re.DOTALL)
|
||||
if match:
|
||||
text = match.group(0)
|
||||
text = self.clean_json_text(text)
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"[!] JSON parse failed: {e}")
|
||||
return None
|
||||
print("[!] JSON parse failed")
|
||||
return None
|
||||
|
||||
async def validate_response(self, response, fields, volatile=False):
|
||||
if (
|
||||
not response
|
||||
or not isinstance(response, dict)
|
||||
or not all(field in response for field in fields)
|
||||
):
|
||||
print("[!] The JSON failed to be parsed. Trying again.")
|
||||
content = (
|
||||
"Failed to parse the JSON from the previous response. Please try again."
|
||||
)
|
||||
response = await self.invoke(content, volatile=volatile)
|
||||
response = self.safe_parse_json(response)
|
||||
if (
|
||||
not response
|
||||
or not isinstance(response, dict)
|
||||
or not all(field in response for field in fields)
|
||||
):
|
||||
print("[!] Retry failed.")
|
||||
return None
|
||||
return response
|
||||
|
||||
def get_last_response(self):
|
||||
if len(self.long_term_memory) >= 2:
|
||||
last_msg = self.long_term_memory[-1]
|
||||
if isinstance(last_msg, AIMessage):
|
||||
return self.safe_parse_json(last_msg.content)
|
||||
return None
|
||||
|
||||
def set_system_message(self, content, local=True):
|
||||
system_message = SystemMessage(content=content)
|
||||
self.log(system_message, local)
|
||||
self.long_term_memory.append(system_message)
|
||||
|
||||
async def invoke(self, content, volatile=False, local=True):
|
||||
messages = self.long_term_memory.copy()
|
||||
if volatile:
|
||||
messages.extend(self.volatile_memory)
|
||||
else:
|
||||
messages.extend(self.short_term_memory)
|
||||
messages.append(HumanMessage(content=content))
|
||||
|
||||
try:
|
||||
response = await self.model_pool.invoke(messages)
|
||||
self.log_tokens(messages, response)
|
||||
|
||||
if volatile:
|
||||
self.volatile_memory.extend([HumanMessage(content=content), response])
|
||||
else:
|
||||
self.short_term_memory.extend([HumanMessage(content=content), response])
|
||||
local_ = not volatile and local
|
||||
self.log(HumanMessage(content=content), local=local_)
|
||||
self.log(response, local=local_)
|
||||
|
||||
return response.content.strip()
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
print(f"[Error] Error occurred while invoking LLM: {e}")
|
||||
78
core/data_loader.py
Normal file
78
core/data_loader.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import os
|
||||
import json
|
||||
import datasets
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
|
||||
|
||||
class DataLoader:
|
||||
def __init__(self, data_path, user_id, selection_criteria="out_random", num_examples=1):
|
||||
self.is_valid = False
|
||||
if not os.path.exists(os.path.join(data_path, "info.json")):
|
||||
return
|
||||
if not os.path.exists(os.path.join(data_path, f"{user_id}", "1")):
|
||||
return
|
||||
if not os.path.exists(os.path.join(data_path, f"{user_id}", "2")):
|
||||
return
|
||||
|
||||
self.metadata = json.load(open(os.path.join(data_path, "info.json"), "r", encoding="utf-8"))
|
||||
self.test_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user_id}", "2"))
|
||||
self.example_dataset = datasets.Dataset.from_list([])
|
||||
users = glob(os.path.join(data_path, "*"))
|
||||
users = [path.split("/")[-1] for path in users]
|
||||
if "info.json" in users:
|
||||
users.remove("info.json")
|
||||
for user in users:
|
||||
user_dataset = datasets.load_from_disk(os.path.join(data_path, f"{user}", "1"))
|
||||
self.example_dataset = datasets.concatenate_datasets([self.example_dataset, user_dataset])
|
||||
self.test_dataset = self.test_dataset.shuffle(seed=0)
|
||||
self.example_dataset = self.example_dataset.shuffle(seed=0)
|
||||
self.user_id = user_id
|
||||
self.selection_criteria = selection_criteria
|
||||
self.num_examples = num_examples
|
||||
self.selected_examples = self.sample_examples()
|
||||
self.is_valid = True
|
||||
|
||||
def __len__(self):
|
||||
return len(self.test_dataset)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.test_dataset[idx], self.selected_examples
|
||||
|
||||
def __iter__(self):
|
||||
for sample in self.test_dataset:
|
||||
yield sample, self.selected_examples
|
||||
|
||||
def sample_examples(self):
|
||||
classes = self.test_dataset.unique("label")
|
||||
example_dataset = datasets.Dataset.from_list([])
|
||||
if self.selection_criteria == "out_random":
|
||||
filtered_example_dataset = self.example_dataset.filter(lambda x, user_id=self.user_id: x["user_id"] != user_id)
|
||||
for c in classes:
|
||||
class_dataset = filtered_example_dataset.filter(lambda x, c_=c: x["label"] == c_)
|
||||
sampled_examples = class_dataset.select(np.random.choice(len(class_dataset), self.num_examples, replace=False))
|
||||
example_dataset = datasets.concatenate_datasets([example_dataset, sampled_examples])
|
||||
elif self.selection_criteria == "in_random":
|
||||
filtered_example_dataset = self.example_dataset.filter(lambda x, user_id=self.user_id: x["user_id"] == user_id)
|
||||
for c in classes:
|
||||
class_dataset = filtered_example_dataset.filter(lambda x, c_=c: x["label"] == c_)
|
||||
sampled_examples = class_dataset.select(np.random.choice(len(class_dataset), self.num_examples, replace=False))
|
||||
example_dataset = datasets.concatenate_datasets([example_dataset, sampled_examples])
|
||||
return example_dataset
|
||||
|
||||
def get_metadata(self):
|
||||
return self.metadata
|
||||
|
||||
def get_sensor_info(self):
|
||||
return self.metadata["feature"]
|
||||
|
||||
def get_task_info(self):
|
||||
task_info = f"**Task**:\n{self.metadata['task']}\n\n"
|
||||
classes_info = [f" - {k}: {v}" for k, v in self.metadata["class"].items()]
|
||||
classes_info = "\n".join(classes_info)
|
||||
task_info += f"**Classes**:\n{classes_info}"
|
||||
return task_info
|
||||
|
||||
def get_classes_info(self):
|
||||
classes_info = [k for k in self.metadata["class"].keys()]
|
||||
return classes_info
|
||||
95
core/model.py
Normal file
95
core/model.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import asyncio
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
|
||||
def load_models(models):
|
||||
model_pool = AsyncModelPool()
|
||||
for model in models:
|
||||
model_pool.add_model(Model(model))
|
||||
model_pool.init_models()
|
||||
return model_pool
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, model, temperature=0.7):
|
||||
if model.startswith("ollama:"):
|
||||
model = model.replace("ollama:", "")
|
||||
if "url:" in model:
|
||||
model = model.replace("url:", "")
|
||||
base_url = model.split("/")[0]
|
||||
if not base_url.startswith("http"):
|
||||
base_url = "http://" + base_url
|
||||
model_type = model.split("/")[1]
|
||||
self.model = ChatOllama(
|
||||
model=model_type,
|
||||
base_url=base_url,
|
||||
temperature=temperature,
|
||||
num_ctx=12000,
|
||||
)
|
||||
else:
|
||||
self.model = ChatOllama(
|
||||
model=model.replace("ollama:", ""),
|
||||
temperature=temperature,
|
||||
num_ctx=12000,
|
||||
)
|
||||
else:
|
||||
self.model = init_chat_model(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
def invoke(self, messages):
|
||||
try:
|
||||
response = self.model.invoke(messages)
|
||||
return response
|
||||
except Exception as e:
|
||||
print(f"[Error] Error occurred while invoking LLM: {e}")
|
||||
return e
|
||||
|
||||
|
||||
class AsyncModel:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
async def invoke(self, content):
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.model.invoke(content),
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class AsyncModelPool:
|
||||
def __init__(self):
|
||||
self.models = []
|
||||
self._available_models = None
|
||||
self._model_semaphore = None
|
||||
|
||||
def add_model(self, model):
|
||||
self.models.append(model)
|
||||
|
||||
def init_models(self):
|
||||
# Initialize the queue and semaphore in the current event loop
|
||||
self._available_models = asyncio.Queue()
|
||||
for model in self.models:
|
||||
async_model = AsyncModel(model)
|
||||
self._available_models.put_nowait(async_model)
|
||||
self._model_semaphore = asyncio.Semaphore(len(self.models))
|
||||
|
||||
# Test each model
|
||||
for model in self.models:
|
||||
model.invoke([HumanMessage(content="Hello world!")])
|
||||
|
||||
async def invoke(self, content):
|
||||
if self._available_models is None:
|
||||
raise RuntimeError("Model pool not initialized. Call init_models() first.")
|
||||
async_model = await self._available_models.get()
|
||||
try:
|
||||
response = await async_model.invoke(content)
|
||||
return response
|
||||
finally:
|
||||
self._available_models.put_nowait(async_model)
|
||||
186
core/sensing_agent.py
Normal file
186
core/sensing_agent.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import json
|
||||
import copy
|
||||
import os
|
||||
from .agent import Agent
|
||||
|
||||
|
||||
class SensingAgent(Agent):
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
model_pool,
|
||||
task_info,
|
||||
classes_info,
|
||||
sensor_info,
|
||||
sample,
|
||||
examples,
|
||||
log_path,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
model_pool=model_pool,
|
||||
log_path=log_path,
|
||||
)
|
||||
|
||||
self.task_info = task_info
|
||||
self.classes_info = classes_info
|
||||
self.sensor_info = sensor_info
|
||||
self.sample = sample
|
||||
self.examples = examples
|
||||
self.init_system_message()
|
||||
|
||||
def init_system_message(self):
|
||||
content = (
|
||||
f"You are {self.name} agent that interprets sensor data to solve a task.\n"
|
||||
"You have the following information about the task:\n"
|
||||
f"{self.task_info}\n\n"
|
||||
"You have the following information about the sensor data:\n"
|
||||
f"{self.sensor_info}\n\n"
|
||||
"Your goal is to analyze the features and "
|
||||
"provide a reasoned answer using your knowledge."
|
||||
)
|
||||
self.set_system_message(content)
|
||||
|
||||
def gen_feature_info(self):
|
||||
feature_info = f"{self.name} features:\n"
|
||||
if len(self.examples) > 0:
|
||||
feature_info += f"{self.gen_example_info()}\n\n"
|
||||
feature_info += "**Current sample features**:\n"
|
||||
for k, v in self.sample["features"].items():
|
||||
# print(k, v)
|
||||
# assert 0
|
||||
# modalities = self.name.split("+")
|
||||
# for modality in modalities:
|
||||
# if modality in k:
|
||||
feature_info += f" - {k}: {self.format_feature(v)}\n"
|
||||
feature_info = feature_info.strip()
|
||||
return feature_info
|
||||
|
||||
def gen_example_info(self):
|
||||
example_info = (
|
||||
"**Examples**\n"
|
||||
"Sensor values might not always align with your inherent "
|
||||
"knowledge due to differences in data collection or processing. "
|
||||
"So, we included a few labeled examples to help your interpretation:\n"
|
||||
)
|
||||
for example in self.examples:
|
||||
example_info += f"*Example of {example['label']}*:\n"
|
||||
for k, v in example["features"].items():
|
||||
# modalities = self.name.split("+")
|
||||
# for modality in modalities:
|
||||
# if modality in k:
|
||||
example_info += f" - {k}: {self.format_feature(v)}\n"
|
||||
example_info += "\n"
|
||||
example_info = example_info.strip()
|
||||
return example_info
|
||||
|
||||
def format_feature(self, value):
|
||||
if isinstance(value, float):
|
||||
if abs(value) >= 1e4 or abs(value) < 1e-2:
|
||||
return f"{value:.2e}"
|
||||
return f"{value:.2f}"
|
||||
return value
|
||||
|
||||
|
||||
async def solve(self, sample, examples, ground_truth):
|
||||
self.sample = sample
|
||||
self.examples = examples
|
||||
feature_info = self.gen_feature_info()
|
||||
content = (
|
||||
f"You have received sensor features from {self.name} modality:\n"
|
||||
f"{feature_info}\n\n"
|
||||
f"Please provide your answer for the task among {self.classes_info} "
|
||||
"and the reasoning for your answer.\n"
|
||||
"Note that the sensor features might be wrong due to the data collection or processing.\n"
|
||||
"You can evaluate the quality of the features by checking the examples you have.\n"
|
||||
"Respond in the following strict JSON format:\n"
|
||||
"{\n"
|
||||
' "REASON": "<Reasoning for the answer>",\n'
|
||||
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
|
||||
"}\n\n"
|
||||
"Do not include any additional text outside the JSON."
|
||||
)
|
||||
|
||||
response = await self.invoke(content)
|
||||
parsed_response = self.safe_parse_json(response)
|
||||
parsed_response = await self.validate_response(
|
||||
parsed_response, ["REASON", "ANSWER"]
|
||||
)
|
||||
self.clean_short_term_memory()
|
||||
self.clean_long_term_memory()
|
||||
answer = parsed_response["ANSWER"]
|
||||
if answer == ground_truth:
|
||||
print(f"Correct answer: {answer}")
|
||||
else:
|
||||
print(f"Incorrect answer: {answer} (Ground truth: {ground_truth})")
|
||||
return parsed_response
|
||||
|
||||
async def interpret(self):
|
||||
feature_info = self.gen_feature_info()
|
||||
content = (
|
||||
f"You have received sensor features from {self.name} modality:\n"
|
||||
f"{feature_info}\n\n"
|
||||
f"Please provide your answer for the task among {self.classes_info} "
|
||||
"and the reasoning for your answer.\n"
|
||||
"Note that the sensor features might be wrong due to the data collection or processing.\n"
|
||||
"You can evaluate the quality of the features by checking the examples you have.\n"
|
||||
"Respond in the following strict JSON format:\n"
|
||||
"{\n"
|
||||
' "REASON": "<Reasoning for the answer>",\n'
|
||||
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
|
||||
"}\n\n"
|
||||
"Do not include any additional text outside the JSON."
|
||||
)
|
||||
|
||||
response = await self.invoke(content)
|
||||
parsed_response = self.safe_parse_json(response)
|
||||
parsed_response = await self.validate_response(
|
||||
parsed_response, ["REASON", "ANSWER"]
|
||||
)
|
||||
return parsed_response
|
||||
|
||||
async def evaluate(self, target_name, initial_response):
|
||||
initial_response_info = json.dumps(initial_response, indent=2)
|
||||
content = (
|
||||
f"Other agent, <{target_name}> provided the following answer for the same task:\n"
|
||||
f"{initial_response_info}\n\n"
|
||||
"Please evaluate the given reasoning and answer based on your judgement. "
|
||||
"You may either support with it or disagree.\n"
|
||||
"If you agree, explain why the reasoning and answer are valid. "
|
||||
"If you disagree, explain why the reasoning or answer may be flawed, "
|
||||
f"and provide constructive feedback on how <{target_name}> can improve its response.\n"
|
||||
"Respond in the following strict JSON format:\n"
|
||||
"{\n"
|
||||
f' "EVALUATION": "<Evaluation to <{target_name}>"\n'
|
||||
"}\n\n"
|
||||
"Do not include any additional text outside the JSON."
|
||||
)
|
||||
|
||||
response = await self.invoke(content, volatile=True)
|
||||
parsed_response = self.safe_parse_json(response)
|
||||
parsed_response = await self.validate_response(
|
||||
parsed_response, ["EVALUATION"], volatile=True
|
||||
)
|
||||
self.clean_volatile_memory()
|
||||
return parsed_response
|
||||
|
||||
async def reflect(self, evaluations):
|
||||
evaluations_info = json.dumps(evaluations, indent=2)
|
||||
content = (
|
||||
f"Other agents have evaluated your answer for the same task:\n"
|
||||
f"{evaluations_info}\n\n"
|
||||
"Please reflect on the evaluations and provide a refined answer for the same task.\n"
|
||||
"Respond in the following strict JSON format:\n"
|
||||
"{\n"
|
||||
' "REASON": "<Reasoning for the answer>",\n'
|
||||
f' "ANSWER": "<Answer among {self.classes_info}>",\n'
|
||||
"}\n\n"
|
||||
"Do not include any additional text outside the JSON."
|
||||
)
|
||||
|
||||
response = await self.invoke(content)
|
||||
parsed_response = self.safe_parse_json(response)
|
||||
parsed_response = await self.validate_response(
|
||||
parsed_response, ["REASON", "ANSWER"]
|
||||
)
|
||||
return parsed_response
|
||||
144
preprocess/dhedfreader.py
Normal file
144
preprocess/dhedfreader.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Credits to: https://github.com/emadeldeen24/TS-TCC/blob/main/data_preprocessing/sleep-edf/dhedfreader.py
|
||||
"""
|
||||
|
||||
import re, datetime, operator, logging, sys
|
||||
import numpy as np
|
||||
from collections import namedtuple
|
||||
|
||||
EVENT_CHANNEL = "EDF Annotations"
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EDFEndOfData(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def tal(tal_str):
|
||||
"""Return a list with (onset, duration, annotation) tuples for an EDF+ TAL
|
||||
stream.
|
||||
"""
|
||||
exp = (
|
||||
"(?P<onset>[+\-]\d+(?:\.\d*)?)"
|
||||
+ "(?:\x15(?P<duration>\d+(?:\.\d*)?))?"
|
||||
+ "(\x14(?P<annotation>[^\x00]*))?"
|
||||
+ "(?:\x14\x00)"
|
||||
)
|
||||
|
||||
def annotation_to_list(annotation):
|
||||
return str(annotation.encode("utf-8")).split("\x14") if annotation else []
|
||||
|
||||
def parse(dic):
|
||||
return (
|
||||
float(dic["onset"]),
|
||||
float(dic["duration"]) if dic["duration"] else 0.0,
|
||||
annotation_to_list(dic["annotation"]),
|
||||
)
|
||||
|
||||
return [parse(m.groupdict()) for m in re.finditer(exp, tal_str)]
|
||||
|
||||
|
||||
def edf_header(f):
|
||||
h = {}
|
||||
assert f.tell() == 0 # check file position
|
||||
assert f.read(8) == "0 "
|
||||
|
||||
# recording info)
|
||||
h["local_subject_id"] = f.read(80).strip()
|
||||
h["local_recording_id"] = f.read(80).strip()
|
||||
|
||||
# parse timestamp
|
||||
(day, month, year) = [int(x) for x in re.findall("(\d+)", f.read(8))]
|
||||
(hour, minute, sec) = [int(x) for x in re.findall("(\d+)", f.read(8))]
|
||||
h["date_time"] = str(datetime.datetime(year + 2000, month, day, hour, minute, sec))
|
||||
|
||||
# misc
|
||||
header_nbytes = int(f.read(8))
|
||||
subtype = f.read(44)[:5]
|
||||
h["EDF+"] = subtype in ["EDF+C", "EDF+D"]
|
||||
h["contiguous"] = subtype != "EDF+D"
|
||||
h["n_records"] = int(f.read(8))
|
||||
h["record_length"] = float(f.read(8)) # in seconds
|
||||
nchannels = h["n_channels"] = int(f.read(4))
|
||||
|
||||
# read channel info
|
||||
channels = range(h["n_channels"])
|
||||
h["label"] = [f.read(16).strip() for n in channels]
|
||||
h["transducer_type"] = [f.read(80).strip() for n in channels]
|
||||
h["units"] = [f.read(8).strip() for n in channels]
|
||||
h["physical_min"] = np.asarray([float(f.read(8)) for n in channels])
|
||||
h["physical_max"] = np.asarray([float(f.read(8)) for n in channels])
|
||||
h["digital_min"] = np.asarray([float(f.read(8)) for n in channels])
|
||||
h["digital_max"] = np.asarray([float(f.read(8)) for n in channels])
|
||||
h["prefiltering"] = [f.read(80).strip() for n in channels]
|
||||
h["n_samples_per_record"] = [int(f.read(8)) for n in channels]
|
||||
f.read(32 * nchannels) # reserved
|
||||
|
||||
# assert f.tell() == header_nbytes
|
||||
return h
|
||||
|
||||
|
||||
class BaseEDFReader:
|
||||
def __init__(self, file):
|
||||
self.file = file
|
||||
|
||||
def read_header(self):
|
||||
self.header = h = edf_header(self.file)
|
||||
|
||||
# calculate ranges for rescaling
|
||||
self.dig_min = h["digital_min"]
|
||||
self.phys_min = h["physical_min"]
|
||||
phys_range = h["physical_max"] - h["physical_min"]
|
||||
dig_range = h["digital_max"] - h["digital_min"]
|
||||
# assert np.all(phys_range > 0)
|
||||
# assert np.all(dig_range > 0)
|
||||
self.gain = phys_range / dig_range
|
||||
|
||||
def read_raw_record(self):
|
||||
"""Read a record with data_2013 and return a list containing arrays with raw
|
||||
bytes.
|
||||
"""
|
||||
result = []
|
||||
for nsamp in self.header["n_samples_per_record"]:
|
||||
samples = self.file.read(nsamp * 2)
|
||||
if len(samples) != nsamp * 2:
|
||||
raise EDFEndOfData
|
||||
result.append(samples)
|
||||
return result
|
||||
|
||||
def convert_record(self, raw_record):
|
||||
"""Convert a raw record to a (time, signals, events) tuple based on
|
||||
information in the header.
|
||||
"""
|
||||
h = self.header
|
||||
dig_min, phys_min, gain = self.dig_min, self.phys_min, self.gain
|
||||
time = float("nan")
|
||||
signals = []
|
||||
events = []
|
||||
for i, samples in enumerate(raw_record):
|
||||
if h["label"][i] == EVENT_CHANNEL:
|
||||
ann = tal(samples)
|
||||
time = ann[0][0]
|
||||
events.extend(ann[1:])
|
||||
# print(i, samples)
|
||||
# exit()
|
||||
else:
|
||||
# 2-byte little-endian integers
|
||||
dig = np.fromstring(samples, "<i2").astype(np.float32)
|
||||
phys = (dig - dig_min[i]) * gain[i] + phys_min[i]
|
||||
signals.append(phys)
|
||||
|
||||
return time, signals, events
|
||||
|
||||
def read_record(self):
|
||||
return self.convert_record(self.read_raw_record())
|
||||
|
||||
def records(self):
|
||||
"""
|
||||
Record generator.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
yield self.read_record()
|
||||
except EDFEndOfData:
|
||||
pass
|
||||
479
preprocess/preprocess_SleepEDF.py
Normal file
479
preprocess/preprocess_SleepEDF.py
Normal file
@@ -0,0 +1,479 @@
|
||||
import os
|
||||
import math
|
||||
import json
|
||||
import warnings
|
||||
import numpy as np
|
||||
import neurokit2 as nk
|
||||
|
||||
from tqdm import tqdm
|
||||
from fire import Fire
|
||||
from glob import glob
|
||||
from mne.io import read_raw_edf # pylint: disable=no-name-in-module
|
||||
from scipy import signal
|
||||
from datasets import Dataset, concatenate_datasets
|
||||
from datetime import datetime
|
||||
from multiprocessing import Pool
|
||||
from dhedfreader import BaseEDFReader
|
||||
|
||||
from neurokit2.misc._warnings import NeuroKitWarning
|
||||
|
||||
warnings.simplefilter("ignore", NeuroKitWarning)
|
||||
warnings.filterwarnings("ignore", message="All-NaN slice encountered")
|
||||
|
||||
SLEEPEDF_PATH = "/mnt/sting/hjyoon/projects/bymyeyes/dataset/SleepEDF/raw/sleep-edf-database-expanded-1.0.0/sleep-cassette/"
|
||||
OUT_DIR = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
||||
EPOCH_SEC_SIZE = 30
|
||||
SAMPLING_RATE = 100
|
||||
|
||||
W = 0
|
||||
N1 = 1
|
||||
N2 = 2
|
||||
N3 = 3
|
||||
REM = 4
|
||||
UNKNOWN = 5
|
||||
|
||||
class_dict = {0: "W", 1: "N1", 2: "N2", 3: "N3", 4: "REM", 5: "UNKNOWN"}
|
||||
ann2label = {
|
||||
"Sleep stage W": 0,
|
||||
"Sleep stage 1": 1,
|
||||
"Sleep stage 2": 2,
|
||||
"Sleep stage 3": 3,
|
||||
"Sleep stage 4": 3,
|
||||
"Sleep stage R": 4,
|
||||
"Sleep stage ?": 5,
|
||||
"Movement time": 5,
|
||||
}
|
||||
|
||||
|
||||
def store_info(info_path):
|
||||
info = {
|
||||
"task": 'Classify the user\'s sleep stage: ["W", "N1", "N2", "N3", "REM"], based on physiological signals collected from wearable sensors.',
|
||||
"class": {
|
||||
"W": "Wakefulness. This includes periods before sleep onset or after final awakening, and short awakenings during the night.",
|
||||
"N1": "Stage N1 (light sleep). This is the transition stage between wakefulness and deeper sleep.",
|
||||
"N2": "Stage N2 (intermediate sleep). This stage typically follows N1 and occurs multiple times throughout the night.",
|
||||
"N3": "Stage N3 (deep sleep). This stage is associated with the most restful sleep and occurs mostly in the first half of the night.",
|
||||
"REM": "Rapid Eye Movement (REM) sleep. This stage is associated with dreaming and typically occurs cyclically later in the night.",
|
||||
},
|
||||
"data": (
|
||||
"Data were recorded during overnight sleep using multiple sensors. \n"
|
||||
"EEG signals were recorded from two electrode pairs, Fpz-Cz and Pz-Oz. "
|
||||
"Both EEG signals (in micro volt) were sampled at 100 Hz.\n"
|
||||
"All signals were time-synchronized, and features were computed using a fixed 30-second window.\n"
|
||||
"Each feature is named using the format 'modality_featurename', where modality refers to the type of signal "
|
||||
"(e.g., EEG-Fpz-Cz, EEG-Pz-Oz) and featurename describes the extracted characteristic (e.g., delta_power)."
|
||||
),
|
||||
"feature": (
|
||||
"For EEG signals, a bandpass filter was applied to extract key frequency components: delta (0.5-4 Hz), "
|
||||
"theta (4-8 Hz), alpha (8-12 Hz), beta (12-30 Hz), spindle (12-14 Hz), k-complex (0.5-1.5 Hz), and sawtooth (2-6 Hz). "
|
||||
"For each band, time-domain features such as mean, standard deviation (std), variance, dynamic range, number of peaks, "
|
||||
"number of zero-crossings, and variance of the first-order difference were extracted.\n"
|
||||
"Additionally, power spectral density was estimated using Welch's method to compute absolute power for each band. "
|
||||
"Ratio features such as delta/theta, theta/alpha, alpha/beta, and (delta+theta)/(alpha+beta) were also included."
|
||||
),
|
||||
}
|
||||
with open(info_path, "w", encoding="utf-8") as f:
|
||||
json.dump(info, f, indent=2)
|
||||
|
||||
|
||||
def lowpass_filter(data, cutoff=50, fs=1000, order=4):
|
||||
nyq = 0.5 * fs
|
||||
normal_cutoff = cutoff / nyq
|
||||
b, a = signal.butter(order, normal_cutoff, btype="low", analog=False)
|
||||
return signal.filtfilt(b, a, data)
|
||||
|
||||
|
||||
def count_large_eye_movements(
|
||||
eog_cleaned,
|
||||
sr,
|
||||
amp_thresh=120,
|
||||
time_thresh=1.5,
|
||||
):
|
||||
time_limit = int(time_thresh * sr)
|
||||
peaks, _ = signal.find_peaks(eog_cleaned)
|
||||
troughs, _ = signal.find_peaks(-eog_cleaned)
|
||||
events = np.sort(np.concatenate((peaks, troughs)))
|
||||
count = 0
|
||||
for i in range(len(events) - 1):
|
||||
t1, t2 = events[i], events[i + 1]
|
||||
if abs(t2 - t1) <= time_limit:
|
||||
amp = abs(eog_cleaned[t2] - eog_cleaned[t1])
|
||||
if amp >= amp_thresh:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
def remove_large_eye_movements(
|
||||
eog_cleaned,
|
||||
fs=100,
|
||||
amp_thresh=120,
|
||||
time_thresh=1.5,
|
||||
pad=0.75,
|
||||
):
|
||||
time_limit = int(time_thresh * fs)
|
||||
pad_samples = int(pad * fs)
|
||||
peaks, _ = signal.find_peaks(eog_cleaned)
|
||||
troughs, _ = signal.find_peaks(-eog_cleaned)
|
||||
events = np.sort(np.concatenate((peaks, troughs)))
|
||||
mask = np.ones_like(eog_cleaned, dtype=bool)
|
||||
for i in range(len(events) - 1):
|
||||
t1, t2 = events[i], events[i + 1]
|
||||
if abs(t2 - t1) <= time_limit:
|
||||
amp = abs(eog_cleaned[t2] - eog_cleaned[t1])
|
||||
if amp >= amp_thresh:
|
||||
start = max(0, min(t1, t2) - pad_samples)
|
||||
end = min(len(eog_cleaned), max(t1, t2) + pad_samples)
|
||||
mask[start:end] = False
|
||||
cleaned_signal = eog_cleaned.copy()
|
||||
cleaned_signal[~mask] = np.mean(eog_cleaned[mask])
|
||||
return cleaned_signal
|
||||
|
||||
|
||||
def bandpass_filter(data, fs, band, order=4):
|
||||
nyq = 0.5 * fs
|
||||
low = band[0] / nyq
|
||||
high = band[1] / nyq
|
||||
b, a = signal.butter(order, [low, high], btype="band")
|
||||
return signal.filtfilt(b, a, data)
|
||||
|
||||
|
||||
def process_by_mod(modality, data):
|
||||
sr = SAMPLING_RATE
|
||||
features = {}
|
||||
modality = "-".join(modality.split())
|
||||
if "EEG" in modality:
|
||||
band_data = data
|
||||
signal_modality = modality
|
||||
mean_val = np.mean(band_data)
|
||||
std_val = np.std(band_data)
|
||||
var_val = np.var(band_data)
|
||||
features[f"{signal_modality}_mean"] = mean_val
|
||||
features[f"{signal_modality}_std"] = std_val
|
||||
features[f"{signal_modality}_variance"] = var_val
|
||||
dynamic_range = np.max(band_data) - np.min(band_data)
|
||||
features[f"{signal_modality}_dynamic_range"] = dynamic_range
|
||||
peaks = signal.find_peaks(band_data - mean_val, height=3 * std_val)[0]
|
||||
features[f"{signal_modality}_num_peaks"] = len(peaks)
|
||||
zero_crossings = np.where(np.diff(np.sign(band_data - mean_val)))[0]
|
||||
features[f"{signal_modality}_num_zero_crossings"] = len(zero_crossings)
|
||||
differences = band_data[1:] - band_data[:-1]
|
||||
difference_variance = np.var(differences)
|
||||
features[f"{signal_modality}_difference_variance"] = difference_variance
|
||||
|
||||
freqs, psd = signal.welch(data, fs=sr, nperseg=sr * 2)
|
||||
delta_idx = np.logical_and(freqs >= 0.5, freqs <= 4)
|
||||
theta_idx = np.logical_and(freqs >= 4, freqs <= 8)
|
||||
alpha_idx = np.logical_and(freqs >= 8, freqs <= 12)
|
||||
beta_idx = np.logical_and(freqs >= 12, freqs <= 30)
|
||||
spindle_idx = np.logical_and(freqs >= 12, freqs <= 14)
|
||||
kcomplex_idx = np.logical_and(freqs >= 0.5, freqs <= 1.5)
|
||||
sawtooth_idx = np.logical_and(freqs >= 2, freqs <= 6)
|
||||
|
||||
delta_power = np.trapezoid(psd[delta_idx], freqs[delta_idx])
|
||||
theta_power = np.trapezoid(psd[theta_idx], freqs[theta_idx])
|
||||
alpha_power = np.trapezoid(psd[alpha_idx], freqs[alpha_idx])
|
||||
beta_power = np.trapezoid(psd[beta_idx], freqs[beta_idx])
|
||||
spindle_power = np.trapezoid(psd[spindle_idx], freqs[spindle_idx])
|
||||
kcomplex_power = np.trapezoid(psd[kcomplex_idx], freqs[kcomplex_idx])
|
||||
sawtooth_power = np.trapezoid(psd[sawtooth_idx], freqs[sawtooth_idx])
|
||||
|
||||
delta_theta_ratio = delta_power / theta_power if theta_power > 0 else 0
|
||||
theta_alpha_ratio = theta_power / alpha_power if alpha_power > 0 else 0
|
||||
alpha_beta_ratio = alpha_power / beta_power if beta_power > 0 else 0
|
||||
slow_fast_ratio = (
|
||||
(delta_power + theta_power) / (alpha_power + beta_power)
|
||||
if (alpha_power + beta_power) > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
features[f"{modality}_delta_power"] = delta_power
|
||||
features[f"{modality}_theta_power"] = theta_power
|
||||
features[f"{modality}_alpha_power"] = alpha_power
|
||||
features[f"{modality}_beta_power"] = beta_power
|
||||
features[f"{modality}_spindle_power"] = spindle_power
|
||||
features[f"{modality}_kcomplex_power"] = kcomplex_power
|
||||
features[f"{modality}_sawtooth_power"] = sawtooth_power
|
||||
features[f"{modality}_delta/theta_ratio"] = delta_theta_ratio
|
||||
features[f"{modality}_theta/alpha_ratio"] = theta_alpha_ratio
|
||||
features[f"{modality}_alpha/beta_ratio"] = alpha_beta_ratio
|
||||
features[f"{modality}_(delta+theta)/(alpha+beta)_ratio"] = slow_fast_ratio
|
||||
# elif "EOG" in modality:
|
||||
# eog_cleaned = data
|
||||
# try:
|
||||
# eog_cleaned = nk.eog_clean(data, sr)
|
||||
# except IndexError as e:
|
||||
# print(f"Error processing EOG data for {modality}: {e}")
|
||||
# return None
|
||||
# eog_mean = np.mean(eog_cleaned)
|
||||
# eog_std = np.std(eog_cleaned)
|
||||
# eog_var = np.var(eog_cleaned)
|
||||
# features[f"{modality}_mean"] = eog_mean
|
||||
# features[f"{modality}_std"] = eog_std
|
||||
# features[f"{modality}_variance"] = eog_var
|
||||
# dynamic_range = np.max(eog_cleaned) - np.min(eog_cleaned)
|
||||
# features[f"{modality}_dynamic_range"] = dynamic_range
|
||||
# peaks = signal.find_peaks(eog_cleaned - eog_mean, height=3 * eog_std)[0]
|
||||
# features[f"{modality}_num_peaks"] = len(peaks)
|
||||
# zero_crossings = np.where(np.diff(np.sign(eog_cleaned - eog_mean)))[0]
|
||||
# features[f"{modality}_num_zero_crossings"] = len(zero_crossings)
|
||||
# differences = eog_cleaned[1:] - eog_cleaned[:-1]
|
||||
# difference_variance = np.var(differences)
|
||||
# features[f"{modality}_difference_variance"] = difference_variance
|
||||
# features[f"{modality}_num_large_eye_movements"] = count_large_eye_movements(
|
||||
# eog_cleaned, sr, amp_thresh=120, time_thresh=1.5
|
||||
# )
|
||||
# eog_large_movement_removed = remove_large_eye_movements(
|
||||
# eog_cleaned, fs=sr, amp_thresh=120, time_thresh=1.5, pad=0.75
|
||||
# )
|
||||
# differences = eog_large_movement_removed[1:] - eog_large_movement_removed[:-1]
|
||||
# difference_variance = np.var(differences)
|
||||
# features[f"{modality}_difference_variance_without_large_movements"] = (
|
||||
# difference_variance
|
||||
# )
|
||||
# freqs, psd = signal.welch(eog_cleaned, fs=sr, nperseg=sr * 2)
|
||||
# total_idx = np.logical_and(freqs >= 0.5, freqs <= 30)
|
||||
# total_power = np.trapezoid(psd[total_idx], freqs[total_idx])
|
||||
# slow_idx = np.logical_and(freqs >= 0.5, freqs <= 2)
|
||||
# rapid_idx = np.logical_and(freqs >= 2, freqs <= 5)
|
||||
# slow_power = np.trapezoid(psd[slow_idx], freqs[slow_idx])
|
||||
# rapid_power = np.trapezoid(psd[rapid_idx], freqs[rapid_idx])
|
||||
# slow_power_ratio = slow_power / total_power if total_power > 0 else 0
|
||||
# rapid_power_ratio = rapid_power / total_power if total_power > 0 else 0
|
||||
# features[f"{modality}_slow_movement_power_ratio"] = slow_power_ratio
|
||||
# features[f"{modality}_rapid_movement_power_ratio"] = rapid_power_ratio
|
||||
# elif "Resp" in modality:
|
||||
# rsp_signals = data
|
||||
# try:
|
||||
# rsp_signals, _ = nk.rsp_process(data, sampling_rate=sr, method="biosppy")
|
||||
# except IndexError as e:
|
||||
# print(f"Error processing respiration data for {modality}: {e}")
|
||||
# return None
|
||||
# clean = rsp_signals["RSP_Clean"]
|
||||
# phase = rsp_signals["RSP_Phase"]
|
||||
# rate = rsp_signals["RSP_Rate"]
|
||||
# amplitude = rsp_signals["RSP_Amplitude"]
|
||||
# peaks = np.where(rsp_signals["RSP_Peaks"] == 1)[0]
|
||||
# troughs = np.where(rsp_signals["RSP_Troughs"] == 1)[0]
|
||||
# inhale_durations = []
|
||||
# for t in troughs:
|
||||
# next_peaks = peaks[peaks > t]
|
||||
# if len(next_peaks) == 0:
|
||||
# continue
|
||||
# inhale_durations.append((next_peaks[0] - t) / sr)
|
||||
# inhale_durations = np.array(inhale_durations)
|
||||
# exhale_durations = []
|
||||
# for p in peaks:
|
||||
# next_troughs = troughs[troughs > p]
|
||||
# if len(next_troughs) == 0:
|
||||
# continue
|
||||
# exhale_durations.append((next_troughs[0] - p) / sr)
|
||||
# exhale_durations = np.array(exhale_durations)
|
||||
# features[f"{modality}_inhale_duration_mean"] = np.mean(inhale_durations)
|
||||
# features[f"{modality}_inhale_duration_std"] = np.std(inhale_durations)
|
||||
# features[f"{modality}_exhale_duration_mean"] = np.mean(exhale_durations)
|
||||
# features[f"{modality}_exhale_duration_std"] = np.std(exhale_durations)
|
||||
# features[f"{modality}_inhale_exhale_ratio"] = (
|
||||
# np.mean(inhale_durations) / np.mean(exhale_durations)
|
||||
# if np.mean(exhale_durations) > 0
|
||||
# else np.nan
|
||||
# )
|
||||
# features[f"{modality}_stretch"] = np.max(clean) - np.min(clean)
|
||||
# inhale_mask = phase == 1
|
||||
# features[f"{modality}_inspiration_volume"] = np.trapezoid(
|
||||
# amplitude[inhale_mask], dx=1 / sr
|
||||
# )
|
||||
# features[f"{modality}_respiration_rate"] = np.mean(rate)
|
||||
# resp_durations = np.diff(troughs) / sr
|
||||
# features[f"{modality}_respiration_duration"] = np.mean(resp_durations)
|
||||
# elif "EMG" in modality:
|
||||
# emg_mean = np.mean(data)
|
||||
# emg_std = np.std(data)
|
||||
# features[f"{modality}_mean"] = emg_mean
|
||||
# features[f"{modality}_std"] = emg_std
|
||||
# features[f"{modality}_dynamic_range"] = np.max(data) - np.min(data)
|
||||
# features[f"{modality}_absolute_integral"] = np.sum(np.abs(data)) / sr
|
||||
# features[f"{modality}_median"] = np.median(data)
|
||||
# features[f"{modality}_10th_percentile"] = np.percentile(data, 10)
|
||||
# features[f"{modality}_90th_percentile"] = np.percentile(data, 90)
|
||||
|
||||
# peaks, _ = signal.find_peaks(data, height=3 * emg_std)
|
||||
# peak_values = data[peaks]
|
||||
# features[f"{modality}_num_peaks"] = len(peaks)
|
||||
# features[f"{modality}_peak_amplitude_mean"] = (
|
||||
# np.mean(peak_values) if len(peak_values) > 0 else 0
|
||||
# )
|
||||
# features[f"{modality}_peak_amplitude_std"] = (
|
||||
# np.std(peak_values) if len(peak_values) > 0 else 0
|
||||
# )
|
||||
# features[f"{modality}_peak_amplitude_sum"] = (
|
||||
# np.sum(peak_values) if len(peak_values) > 0 else 0
|
||||
# )
|
||||
# features[f"{modality}_peak_amplitude_norm_sum"] = (
|
||||
# np.sum(peak_values) / np.sum(np.abs(data))
|
||||
# if np.sum(np.abs(data)) > 0
|
||||
# else 0
|
||||
# )
|
||||
return features
|
||||
|
||||
|
||||
def preprocess(file_path):
|
||||
psg_file_path = file_path[0]
|
||||
ann_file_path = file_path[1]
|
||||
raw = read_raw_edf(psg_file_path, preload=True, stim_channel=None)
|
||||
sr = raw.info["sfreq"]
|
||||
curr_user_id = psg_file_path.split("/")[-1].split("-")[0][3:5]
|
||||
curr_session_id = psg_file_path.split("/")[-1].split("-")[0][5]
|
||||
|
||||
raw_ch_df = raw.to_data_frame()
|
||||
raw_ch_df.set_index(np.arange(len(raw_ch_df)))
|
||||
raw_ch_df.drop(columns=["time", "Temp rectal", "Event marker"], inplace=True)
|
||||
column_list = raw_ch_df.columns.tolist()
|
||||
column_list = [c.replace("_", "-") for c in column_list]
|
||||
|
||||
f = open(psg_file_path, "r", encoding="iso-8859-1")
|
||||
reader_raw = BaseEDFReader(f)
|
||||
reader_raw.read_header()
|
||||
h_raw = reader_raw.header
|
||||
f.close()
|
||||
raw_start_dt = datetime.strptime(h_raw["date_time"], "%Y-%m-%d %H:%M:%S")
|
||||
f = open(ann_file_path, "r", encoding="iso-8859-1")
|
||||
reader_ann = BaseEDFReader(f)
|
||||
reader_ann.read_header()
|
||||
h_ann = reader_ann.header
|
||||
_, _, ann = list(zip(*reader_ann.records()))
|
||||
f.close()
|
||||
ann_start_dt = datetime.strptime(h_ann["date_time"], "%Y-%m-%d %H:%M:%S")
|
||||
assert raw_start_dt == ann_start_dt
|
||||
|
||||
remove_idx = []
|
||||
labels = []
|
||||
label_idx = []
|
||||
for a in ann[0]:
|
||||
onset_sec, duration_sec, ann_char = a
|
||||
ann_str = "".join(ann_char)
|
||||
ann_str = ann_str.strip("b'\"")
|
||||
label = ann2label[ann_str]
|
||||
if label != UNKNOWN:
|
||||
if duration_sec % EPOCH_SEC_SIZE != 0:
|
||||
assert False, "Duration should be multiple of 30 seconds"
|
||||
duration_epoch = int(duration_sec / EPOCH_SEC_SIZE)
|
||||
label_epoch = np.ones(duration_epoch, dtype=int) * label
|
||||
labels.append(label_epoch)
|
||||
idx = int(onset_sec * sr) + np.arange(duration_sec * sr, dtype=int)
|
||||
label_idx.append(idx)
|
||||
else:
|
||||
idx = int(onset_sec * sr) + np.arange(duration_sec * sr, dtype=int)
|
||||
remove_idx.append(idx)
|
||||
labels = np.hstack(labels)
|
||||
|
||||
if len(remove_idx) > 0:
|
||||
remove_idx = np.hstack(remove_idx)
|
||||
select_idx = np.setdiff1d(np.arange(len(raw_ch_df)), remove_idx)
|
||||
else:
|
||||
select_idx = np.arange(len(raw_ch_df))
|
||||
label_idx = np.hstack(label_idx)
|
||||
select_idx = np.intersect1d(select_idx, label_idx)
|
||||
|
||||
if len(label_idx) > len(select_idx):
|
||||
extra_idx = np.setdiff1d(label_idx, select_idx)
|
||||
if np.all(extra_idx > select_idx[-1]):
|
||||
n_trims = len(select_idx) % int(EPOCH_SEC_SIZE * sr)
|
||||
n_label_trims = int(math.ceil(n_trims / (EPOCH_SEC_SIZE * sr)))
|
||||
select_idx = select_idx[:-n_trims]
|
||||
labels = labels[:-n_label_trims]
|
||||
if len(labels) == 0:
|
||||
print("No labels left after removing extra labels.")
|
||||
return []
|
||||
|
||||
raw_ch = raw_ch_df.values[select_idx]
|
||||
if len(raw_ch) % (EPOCH_SEC_SIZE * sr) != 0:
|
||||
assert False, "Data length is not multiple of 30 seconds."
|
||||
n_epochs = len(raw_ch) / (EPOCH_SEC_SIZE * sr)
|
||||
|
||||
x = np.asarray(np.split(raw_ch, n_epochs)).astype(np.float32)
|
||||
y = labels.astype(np.int32)
|
||||
assert len(x) == len(y)
|
||||
|
||||
w_edge_mins = 30
|
||||
nw_idx = np.where(y != W)[0]
|
||||
start_idx = nw_idx[0] - (w_edge_mins * 2)
|
||||
end_idx = nw_idx[-1] + (w_edge_mins * 2)
|
||||
if start_idx < 0:
|
||||
start_idx = 0
|
||||
if end_idx >= len(y):
|
||||
end_idx = len(y) - 1
|
||||
select_idx = np.arange(start_idx, end_idx + 1)
|
||||
x = x[select_idx]
|
||||
y = y[select_idx]
|
||||
|
||||
num_samples = x.shape[0]
|
||||
num_modalities = x.shape[2]
|
||||
data = []
|
||||
pbar = tqdm(total=5 * num_samples, desc=f"Processing {curr_user_id}", leave=False)
|
||||
idx = 0
|
||||
indices = np.arange(num_samples)[::10]
|
||||
for i in indices:
|
||||
if y[i] == 5:
|
||||
assert 0
|
||||
else:
|
||||
pbar.update(1)
|
||||
w_features = {}
|
||||
missing_w_features = {}
|
||||
flag = False
|
||||
for mod in range(num_modalities):
|
||||
mod_features = process_by_mod(column_list[mod], x[i, :, mod])
|
||||
if mod_features is None:
|
||||
flag = True
|
||||
break
|
||||
for k, v in mod_features.items():
|
||||
w_features[k] = v
|
||||
if flag:
|
||||
continue
|
||||
data.append(
|
||||
dict(
|
||||
user_id=curr_user_id,
|
||||
session_id=curr_session_id,
|
||||
idx=idx,
|
||||
label=class_dict[y[i]],
|
||||
features=w_features,
|
||||
)
|
||||
)
|
||||
idx += 1
|
||||
pbar.close()
|
||||
return data
|
||||
|
||||
|
||||
def run(path=SLEEPEDF_PATH, out_dir=OUT_DIR, num_examples=1, num_workers=32, seed=0):
|
||||
np.random.seed(seed)
|
||||
if not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir)
|
||||
info_path = os.path.join(out_dir, "info.json")
|
||||
store_info(info_path)
|
||||
print(f"Saved info to {info_path}")
|
||||
psg_file_paths = glob(os.path.join(path, "*PSG.edf"))
|
||||
ann_file_paths = glob(os.path.join(path, "*Hypnogram.edf"))
|
||||
psg_file_paths.sort()
|
||||
ann_file_paths.sort()
|
||||
file_paths = list(zip(psg_file_paths, ann_file_paths))
|
||||
filtered_2013_file_paths = []
|
||||
for file_path in file_paths:
|
||||
basename = os.path.basename(file_path[0])
|
||||
if basename.startswith("SC40"):
|
||||
filtered_2013_file_paths.append(file_path)
|
||||
elif basename.startswith("SC41"):
|
||||
filtered_2013_file_paths.append(file_path)
|
||||
|
||||
with Pool(processes=num_workers) as pool:
|
||||
for data in pool.imap_unordered(preprocess, filtered_2013_file_paths):
|
||||
if len(data) == 0:
|
||||
continue
|
||||
user_id = data[0]["user_id"]
|
||||
session_id = data[0]["session_id"]
|
||||
dataset = Dataset.from_list(data)
|
||||
test_dir = os.path.join(out_dir, f"{user_id}", f"{session_id}")
|
||||
dataset.save_to_disk(test_dir)
|
||||
print(f"Saved dataset to {test_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(run)
|
||||
49
run.py
Normal file
49
run.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
import re
|
||||
import asyncio
|
||||
import yaml
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from glob import glob
|
||||
from fire import Fire
|
||||
|
||||
from core.model import load_models
|
||||
from core.data_loader import DataLoader
|
||||
from core.sensing_agent import SensingAgent
|
||||
|
||||
|
||||
async def run_parallel(config, model_pool):
|
||||
tasks = []
|
||||
# for seed in range(config["num_seeds"]):
|
||||
# users = glob(os.path.join(config["data_path"], "*"))
|
||||
# users = [path for path in users if os.path.isdir(path)]
|
||||
# for user in users:
|
||||
data_loader = DataLoader(config["data_path"], "00", selection_criteria="in_random", num_examples=1)
|
||||
# if not data_loader.is_valid:
|
||||
# continue
|
||||
for sample, examples in data_loader:
|
||||
agent = SensingAgent(
|
||||
name="EEG sensing",
|
||||
model_pool=model_pool,
|
||||
task_info=data_loader.get_task_info(),
|
||||
classes_info=data_loader.get_classes_info(),
|
||||
sensor_info=data_loader.get_sensor_info(),
|
||||
sample=sample,
|
||||
examples=examples,
|
||||
log_path=config["log_path"],
|
||||
)
|
||||
ground_truth = sample["label"]
|
||||
task = asyncio.create_task(agent.solve(sample, examples, ground_truth))
|
||||
tasks.append(task)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
def run(config_path):
|
||||
config = yaml.load(open(config_path, "r", encoding="utf-8"), Loader=yaml.SafeLoader)
|
||||
model_pool = load_models(config["models"])
|
||||
asyncio.run(run_parallel(config, model_pool))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(run)
|
||||
16
utils/kill_ollamas.sh
Executable file
16
utils/kill_ollamas.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo "Killing all tmux sessions starting with 'ollama'..."
|
||||
|
||||
# Get all tmux sessions starting with "ollama"
|
||||
SESSIONS=$(tmux list-sessions -F '#S' | grep '^ollama')
|
||||
|
||||
if [ -z "$SESSIONS" ]; then
|
||||
echo "No 'ollama' sessions found."
|
||||
else
|
||||
for SESSION in $SESSIONS; do
|
||||
echo "Killing session: $SESSION"
|
||||
tmux kill-session -t "$SESSION"
|
||||
done
|
||||
echo "All 'ollama' sessions have been terminated."
|
||||
fi
|
||||
22
utils/launch_ollamas.sh
Executable file
22
utils/launch_ollamas.sh
Executable file
@@ -0,0 +1,22 @@
|
||||
#!/bin/bash
|
||||
PORTS=(11437 11438 11439 11440 11441 11442 11443 11444)
|
||||
|
||||
for i in "${!PORTS[@]}"; do
|
||||
PORT="${PORTS[$i]}"
|
||||
SESSION="ollama$i"
|
||||
|
||||
# Run a login shell (-l) and then source .bashrc just in case, then start Ollama
|
||||
CMD="bash -lc 'OLLAMA_MODELS=/mnt/sting/hjyoon/projects/llm/ollama OLLAMA_HOST=0.0.0.0:${PORT} ollama serve'"
|
||||
|
||||
if tmux has-session -t "$SESSION" 2>/dev/null; then
|
||||
echo "Session $SESSION exists. Restarting ollama serve..."
|
||||
tmux send-keys -t "$SESSION" C-c
|
||||
sleep 1
|
||||
tmux send-keys -t "$SESSION" "$CMD" C-m
|
||||
else
|
||||
echo "Creating session $SESSION and starting ollama serve on port $PORT..."
|
||||
tmux new-session -d -s "$SESSION" "$CMD"
|
||||
fi
|
||||
done
|
||||
|
||||
echo "All Ollama servers have been started or restarted in tmux sessions."
|
||||
Reference in New Issue
Block a user