# Code-viewer metadata: 26 lines, 832 B, Python
import os
|
|
from fire import Fire
|
|
import datasets
|
|
|
|
# Root directory of the SleepEDF data; used as the default `data_path` of `run`.
# NOTE(review): machine-specific absolute path -- override with --data_path elsewhere.
SleepEDF_PATH = "/mnt/sting/hjyoon/projects/tsllm_personalization_icl/data/SleepEDF"
|
|
|
|
def run(
    data_path=SleepEDF_PATH,
    user_id=None,
    session_id=None,
    num_examples=1,
    num_workers=32,
    seed=0,
):
    """Load one user's train/test datasets from disk and sample examples.

    The train split is read from ``<data_path>/<user_id>/1`` and the test
    split from ``<data_path>/<user_id>/2`` (presumably session ids -- TODO
    confirm). Both are shuffled with ``seed`` and truncated to the first
    ``num_examples`` rows.

    Args:
        data_path: Root directory containing one sub-directory per user.
        user_id: Name of the user's sub-directory. Required; ints are
            accepted and converted to ``str``.
        session_id: Currently unused by this function.
        num_examples: Number of rows kept from each shuffled dataset.
        num_workers: Currently unused by this function.
        seed: Seed passed to both ``shuffle`` calls.

    Raises:
        ValueError: If ``user_id`` is not provided.
    """
    if user_id is None:
        raise ValueError("user_id is required (e.g. --user_id=<id>)")

    # Fire turns numeric-looking CLI values into int, and os.path.join only
    # accepts str/path-like components, so normalise the id to str first.
    user_dir = os.path.join(data_path, str(user_id))
    user_train_path = os.path.join(user_dir, "1")
    user_test_path = os.path.join(user_dir, "2")

    user_train_dataset = datasets.load_from_disk(user_train_path)
    user_test_dataset = datasets.load_from_disk(user_test_path)

    user_train_dataset = user_train_dataset.shuffle(seed=seed)
    user_test_dataset = user_test_dataset.shuffle(seed=seed)

    # NOTE(review): the sampled subsets are neither used nor returned yet.
    user_train_dataset = user_train_dataset.select(range(num_examples))
    user_test_dataset = user_test_dataset.select(range(num_examples))
|
|
|
|
# Script entry point: expose `run` as a command-line interface via Fire,
# so each keyword argument of `run` becomes a `--flag`.
if __name__ == "__main__":
    Fire(component=run)