implemented placeholders
This commit is contained in:
49
analysis/user_similarity/chronos2/gen_plot.py
Normal file
49
analysis/user_similarity/chronos2/gen_plot.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import datasets
|
||||
from glob import glob
|
||||
from fire import Fire
|
||||
from chronos2 import Chronos2
|
||||
|
||||
SleepEDF_PATH = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
||||
CHRONOS2_EMB_PATH = "."
|
||||
|
||||
|
||||
def load_data(data_path):
|
||||
return None
|
||||
|
||||
|
||||
# def emb_chronos2(data):
|
||||
# return None
|
||||
|
||||
def plot_embs(embs):
|
||||
return None
|
||||
|
||||
|
||||
def run(
|
||||
data_path=SleepEDF_PATH,
|
||||
is_chached=False,
|
||||
cached_path=CHRONOS2_EMB_PATH,
|
||||
):
|
||||
chronos2 = Chronos2()
|
||||
user_test_paths = glob(os.path.join(data_path, "*", "2")) # session 2
|
||||
user_test_data = load_data(user_test_paths)
|
||||
|
||||
if not is_chached:
|
||||
embs = []
|
||||
for data in user_test_data:
|
||||
emb = chronos2.emb_chronos2(data)
|
||||
embs.append(emb)
|
||||
embs = np.array(embs)
|
||||
np.save(cached_path, embs)
|
||||
else:
|
||||
embs = np.load(cached_path)
|
||||
|
||||
plot_embs(embs)
|
||||
# save plot to pdf in the same directory
|
||||
# tsne speed issue
|
||||
# tsne visualization issue
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(run)
|
||||
26
analysis/user_similarity/chronos2/simple_user_classifier.py
Normal file
26
analysis/user_similarity/chronos2/simple_user_classifier.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import os
|
||||
from fire import Fire
|
||||
import datasets
|
||||
|
||||
SleepEDF_PATH = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
||||
|
||||
def run(
|
||||
data_path=SleepEDF_PATH,
|
||||
user_id=None,
|
||||
session_id=None,
|
||||
num_examples=1,
|
||||
num_workers=32,
|
||||
seed=0,
|
||||
):
|
||||
user_train_path = os.path.join(data_path, user_id, "1")
|
||||
user_test_path = os.path.join(data_path, user_id, "2")
|
||||
user_train_dataset = datasets.load_from_disk(user_train_path)
|
||||
user_test_dataset = datasets.load_from_disk(user_test_path)
|
||||
user_train_dataset = user_train_dataset.shuffle(seed=seed)
|
||||
user_test_dataset = user_test_dataset.shuffle(seed=seed)
|
||||
user_train_dataset = user_train_dataset.select(range(num_examples))
|
||||
user_test_dataset = user_test_dataset.select(range(num_examples))
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(run)
|
||||
0
analysis/user_similarity/labram/gen_plot.py
Normal file
0
analysis/user_similarity/labram/gen_plot.py
Normal file
0
analysis/user_similarity/sbert/gen_plot.py
Normal file
0
analysis/user_similarity/sbert/gen_plot.py
Normal file
@@ -418,7 +418,7 @@ def preprocess(file_path):
|
||||
else:
|
||||
pbar.update(1)
|
||||
w_features = {}
|
||||
missing_w_features = {}
|
||||
raw_data = {}
|
||||
flag = False
|
||||
for mod in range(num_modalities):
|
||||
mod_features = process_by_mod(column_list[mod], x[i, :, mod])
|
||||
@@ -427,6 +427,8 @@ def preprocess(file_path):
|
||||
break
|
||||
for k, v in mod_features.items():
|
||||
w_features[k] = v
|
||||
if "EEG" in column_list[mod]:
|
||||
raw_data[column_list[mod]] = x[i, :, mod]
|
||||
if flag:
|
||||
continue
|
||||
data.append(
|
||||
@@ -436,6 +438,7 @@ def preprocess(file_path):
|
||||
idx=idx,
|
||||
label=class_dict[y[i]],
|
||||
features=w_features,
|
||||
data=raw_data,
|
||||
)
|
||||
)
|
||||
idx += 1
|
||||
|
||||
Reference in New Issue
Block a user