Files
tsllm_personalization_icl/analysis/user_similarity/chronos2/simple_user_classifier.py
2026-01-09 13:52:03 +09:00

26 lines
832 B
Python

import os
from fire import Fire
import datasets
# Root directory of the preprocessed SleepEDF data; run() expects one
# sub-directory per user under it. Defaults to the original machine-specific
# mount; set the SLEEPEDF_PATH environment variable to use another location.
SleepEDF_PATH = os.environ.get(
    "SLEEPEDF_PATH",
    "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF",
)
def _load_subset(path, num_examples, seed):
    """Load the dataset saved at ``path``, shuffle it, and keep up to ``num_examples`` rows."""
    # NOTE(review): assumes ``path`` holds a single ``Dataset`` (not a
    # ``DatasetDict``), as the original code did — confirm.
    dataset = datasets.load_from_disk(path).shuffle(seed=seed)
    # Clamp so asking for more rows than the split has does not raise.
    return dataset.select(range(min(num_examples, len(dataset))))


def run(
    data_path=SleepEDF_PATH,
    user_id=None,
    session_id=None,
    num_examples=1,
    num_workers=32,
    seed=0,
):
    """Load and subsample one SleepEDF user's train and test splits.

    Reads the datasets saved under ``<data_path>/<user_id>/1`` (used as the
    train split) and ``<data_path>/<user_id>/2`` (used as the test split),
    shuffles each with ``seed`` and keeps at most ``num_examples`` rows of each.

    This is currently a stub: the subsampled datasets are not used further and
    the function returns ``None``.

    Args:
        data_path: Root directory containing one sub-directory per user.
        user_id: Name of the user's sub-directory. Required. Fire may parse a
            numeric CLI value as ``int``, so it is converted with ``str()``.
        session_id: Currently unused.
        num_examples: Maximum number of rows kept from each split; if a split
            is smaller, all of its rows are kept.
        num_workers: Currently unused.
        seed: Seed passed to ``shuffle`` for both splits.

    Raises:
        TypeError: If ``user_id`` is not provided.
    """
    if user_id is None:
        raise TypeError("run() requires user_id (e.g. --user_id=<user directory name>)")

    # Fire turns numeric CLI tokens into int; os.path.join needs a str.
    user_dir = os.path.join(data_path, str(user_id))

    # NOTE(review): sessions "1" and "2" are hard-coded as train/test and
    # session_id is ignored — confirm whether it should select a session.
    user_train_dataset = _load_subset(os.path.join(user_dir, "1"), num_examples, seed)
    user_test_dataset = _load_subset(os.path.join(user_dir, "2"), num_examples, seed)

    # TODO: consume user_train_dataset / user_test_dataset (classifier logic).
    del user_train_dataset, user_test_dataset
if __name__ == "__main__":
    # Expose run() as a command-line interface, e.g.
    #   python simple_user_classifier.py --user_id=<id> --num_examples=4
    Fire(component=run)