implemented placeholders

This commit is contained in:
Hyungjun Yoon
2026-01-09 13:52:03 +09:00
parent ee711d6034
commit 636874f0f8
7 changed files with 79 additions and 1 deletions

View File

@@ -0,0 +1,49 @@
import os
import numpy as np
import datasets
from glob import glob
from fire import Fire
from chronos2 import Chronos2
SleepEDF_PATH = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
CHRONOS2_EMB_PATH = "."
def load_data(data_path):
return None
# def emb_chronos2(data):
# return None
def plot_embs(embs):
return None
def run(
data_path=SleepEDF_PATH,
is_chached=False,
cached_path=CHRONOS2_EMB_PATH,
):
chronos2 = Chronos2()
user_test_paths = glob(os.path.join(data_path, "*", "2")) # session 2
user_test_data = load_data(user_test_paths)
if not is_chached:
embs = []
for data in user_test_data:
emb = chronos2.emb_chronos2(data)
embs.append(emb)
embs = np.array(embs)
np.save(cached_path, embs)
else:
embs = np.load(cached_path)
plot_embs(embs)
# save plot to pdf in the same directory
# tsne speed issue
# tsne visualization issue
if __name__ == "__main__":
Fire(run)

View File

@@ -0,0 +1,26 @@
import os
from fire import Fire
import datasets
SleepEDF_PATH = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
def run(
data_path=SleepEDF_PATH,
user_id=None,
session_id=None,
num_examples=1,
num_workers=32,
seed=0,
):
user_train_path = os.path.join(data_path, user_id, "1")
user_test_path = os.path.join(data_path, user_id, "2")
user_train_dataset = datasets.load_from_disk(user_train_path)
user_test_dataset = datasets.load_from_disk(user_test_path)
user_train_dataset = user_train_dataset.shuffle(seed=seed)
user_test_dataset = user_test_dataset.shuffle(seed=seed)
user_train_dataset = user_train_dataset.select(range(num_examples))
user_test_dataset = user_test_dataset.select(range(num_examples))
pass
if __name__ == "__main__":
Fire(run)

View File

@@ -418,7 +418,7 @@ def preprocess(file_path):
else:
pbar.update(1)
w_features = {}
missing_w_features = {}
raw_data = {}
flag = False
for mod in range(num_modalities):
mod_features = process_by_mod(column_list[mod], x[i, :, mod])
@@ -427,6 +427,8 @@ def preprocess(file_path):
break
for k, v in mod_features.items():
w_features[k] = v
if "EEG" in column_list[mod]:
raw_data[column_list[mod]] = x[i, :, mod]
if flag:
continue
data.append(
@@ -436,6 +438,7 @@ def preprocess(file_path):
idx=idx,
label=class_dict[y[i]],
features=w_features,
data=raw_data,
)
)
idx += 1