implemented placeholders
This commit is contained in:
49
analysis/user_similarity/chronos2/gen_plot.py
Normal file
49
analysis/user_similarity/chronos2/gen_plot.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import datasets
|
||||||
|
from glob import glob
|
||||||
|
from fire import Fire
|
||||||
|
from chronos2 import Chronos2
|
||||||
|
|
||||||
|
SleepEDF_PATH = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
||||||
|
CHRONOS2_EMB_PATH = "."
|
||||||
|
|
||||||
|
|
||||||
|
def load_data(data_path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# def emb_chronos2(data):
|
||||||
|
# return None
|
||||||
|
|
||||||
|
def plot_embs(embs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def run(
|
||||||
|
data_path=SleepEDF_PATH,
|
||||||
|
is_chached=False,
|
||||||
|
cached_path=CHRONOS2_EMB_PATH,
|
||||||
|
):
|
||||||
|
chronos2 = Chronos2()
|
||||||
|
user_test_paths = glob(os.path.join(data_path, "*", "2")) # session 2
|
||||||
|
user_test_data = load_data(user_test_paths)
|
||||||
|
|
||||||
|
if not is_chached:
|
||||||
|
embs = []
|
||||||
|
for data in user_test_data:
|
||||||
|
emb = chronos2.emb_chronos2(data)
|
||||||
|
embs.append(emb)
|
||||||
|
embs = np.array(embs)
|
||||||
|
np.save(cached_path, embs)
|
||||||
|
else:
|
||||||
|
embs = np.load(cached_path)
|
||||||
|
|
||||||
|
plot_embs(embs)
|
||||||
|
# save plot to pdf in the same directory
|
||||||
|
# tsne speed issue
|
||||||
|
# tsne visualization issue
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Fire(run)
|
||||||
26
analysis/user_similarity/chronos2/simple_user_classifier.py
Normal file
26
analysis/user_similarity/chronos2/simple_user_classifier.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import os
|
||||||
|
from fire import Fire
|
||||||
|
import datasets
|
||||||
|
|
||||||
|
SleepEDF_PATH = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
||||||
|
|
||||||
|
def run(
|
||||||
|
data_path=SleepEDF_PATH,
|
||||||
|
user_id=None,
|
||||||
|
session_id=None,
|
||||||
|
num_examples=1,
|
||||||
|
num_workers=32,
|
||||||
|
seed=0,
|
||||||
|
):
|
||||||
|
user_train_path = os.path.join(data_path, user_id, "1")
|
||||||
|
user_test_path = os.path.join(data_path, user_id, "2")
|
||||||
|
user_train_dataset = datasets.load_from_disk(user_train_path)
|
||||||
|
user_test_dataset = datasets.load_from_disk(user_test_path)
|
||||||
|
user_train_dataset = user_train_dataset.shuffle(seed=seed)
|
||||||
|
user_test_dataset = user_test_dataset.shuffle(seed=seed)
|
||||||
|
user_train_dataset = user_train_dataset.select(range(num_examples))
|
||||||
|
user_test_dataset = user_test_dataset.select(range(num_examples))
|
||||||
|
pass
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Fire(run)
|
||||||
0
analysis/user_similarity/labram/gen_plot.py
Normal file
0
analysis/user_similarity/labram/gen_plot.py
Normal file
0
analysis/user_similarity/sbert/gen_plot.py
Normal file
0
analysis/user_similarity/sbert/gen_plot.py
Normal file
@@ -418,7 +418,7 @@ def preprocess(file_path):
|
|||||||
else:
|
else:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
w_features = {}
|
w_features = {}
|
||||||
missing_w_features = {}
|
raw_data = {}
|
||||||
flag = False
|
flag = False
|
||||||
for mod in range(num_modalities):
|
for mod in range(num_modalities):
|
||||||
mod_features = process_by_mod(column_list[mod], x[i, :, mod])
|
mod_features = process_by_mod(column_list[mod], x[i, :, mod])
|
||||||
@@ -427,6 +427,8 @@ def preprocess(file_path):
|
|||||||
break
|
break
|
||||||
for k, v in mod_features.items():
|
for k, v in mod_features.items():
|
||||||
w_features[k] = v
|
w_features[k] = v
|
||||||
|
if "EEG" in column_list[mod]:
|
||||||
|
raw_data[column_list[mod]] = x[i, :, mod]
|
||||||
if flag:
|
if flag:
|
||||||
continue
|
continue
|
||||||
data.append(
|
data.append(
|
||||||
@@ -436,6 +438,7 @@ def preprocess(file_path):
|
|||||||
idx=idx,
|
idx=idx,
|
||||||
label=class_dict[y[i]],
|
label=class_dict[y[i]],
|
||||||
features=w_features,
|
features=w_features,
|
||||||
|
data=raw_data,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|||||||
Reference in New Issue
Block a user