Files
tsllm_personalization_icl/core/data_loader.py
2026-01-19 19:56:35 +09:00

201 lines
7.9 KiB
Python

import json
import os
from glob import glob
from typing import TYPE_CHECKING, Optional

import datasets
import numpy as np

if TYPE_CHECKING:
    from .embedding_index import EmbeddingIndex
class DataLoader:
    """Per-user test set plus a pool of candidate in-context examples.

    Layout read under ``data_path``::

        info.json     # metadata; keys read here: "task", "class", "feature"
        <user>/1/     # loaded with datasets.load_from_disk -> example pool
        <user>/2/     # loaded with datasets.load_from_disk -> test set

    The test set is ``<user_id>/2``. The example pool is session ``1`` of
    every (non-hidden) user directory, concatenated in sorted directory-name
    order. Both datasets are shuffled with seed 0.

    ``selection_criteria`` controls which examples accompany a test sample:

    * ``"out_random"`` / ``"in_random"``: one fixed random draw of
      ``num_examples`` per class from other users / from the target user.
    * ``"out_similar"`` / ``"in_similar"``: per-sample retrieval through
      ``embedding_index`` (replaced by the random variant when no index is
      supplied).
    * anything else: an empty example dataset.

    ``is_valid`` stays ``False`` when required files are missing or some
    class lacks enough candidates for the random draw.
    """

    # Criteria whose examples are retrieved per test sample via the embedding index.
    _SIMILAR_CRITERIA = ("out_similar", "in_similar")

    def __init__(
        self,
        data_path,
        user_id,
        selection_criteria: str = "out_random",
        num_examples: int = 1,
        embedding_index: Optional["EmbeddingIndex"] = None,
    ):
        """Load metadata, the target user's test set and the shared example pool.

        Args:
            data_path: Root directory of the dataset (layout in the class docstring).
            user_id: Target user; ``f"{user_id}"`` names its directory.
            selection_criteria: How in-context examples are chosen.
            num_examples: Number of examples per class.
            embedding_index: Needed by the ``*_similar`` criteria.

        A missing ``info.json`` or a missing ``1``/``2`` directory for
        ``user_id`` leaves ``is_valid`` False instead of raising.
        """
        self.is_valid = False
        self.embedding_index = embedding_index
        self.data_path = data_path

        info_path = os.path.join(data_path, "info.json")
        user_dir = os.path.join(data_path, f"{user_id}")
        required = (info_path, os.path.join(user_dir, "1"), os.path.join(user_dir, "2"))
        if not all(os.path.exists(p) for p in required):
            return

        if selection_criteria in self._SIMILAR_CRITERIA and embedding_index is None:
            print(f"[WARNING] {selection_criteria} requires embedding_index, falling back to random")
            selection_criteria = "out_random" if "out" in selection_criteria else "in_random"

        with open(info_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)
        self.test_dataset = datasets.load_from_disk(os.path.join(user_dir, "2")).shuffle(seed=0)
        self.example_dataset = self._load_example_pool(data_path).shuffle(seed=0)

        self.user_id = user_id
        self.selection_criteria = selection_criteria
        self.num_examples = num_examples
        self.classes = sorted(self.metadata["class"])
        # (str(user_id), int(idx)) -> row position in example_dataset, for O(1) retrieval.
        self._example_lookup = self._build_example_lookup(self.example_dataset)

        if selection_criteria in self._SIMILAR_CRITERIA:
            # Examples are chosen per test sample in __getitem__ / __iter__.
            self.selected_examples = None
        else:
            # One fixed draw shared by every test sample.
            self.selected_examples = self.sample_examples()
            if self.selected_examples is None:
                return
        self.is_valid = True

    @staticmethod
    def _load_example_pool(data_path):
        """Concatenate session ``1`` of every user directory under ``data_path``.

        Entries that are hidden (mirroring ``glob("*")``) or have no ``1``
        subdirectory (e.g. ``info.json``) are skipped. Names are sorted so the
        seed-0 shuffle applied by the caller is reproducible across filesystems.
        """
        user_names = sorted(
            name
            for name in os.listdir(data_path)
            if not name.startswith(".") and os.path.isdir(os.path.join(data_path, name, "1"))
        )
        pools = [datasets.load_from_disk(os.path.join(data_path, name, "1")) for name in user_names]
        if not pools:
            return datasets.Dataset.from_list([])
        return datasets.concatenate_datasets(pools)

    @staticmethod
    def _build_example_lookup(dataset):
        """Map ``(str(user_id), int(idx))`` to row position; later duplicates win."""
        if len(dataset) == 0:
            # An empty pool may have no columns at all.
            return {}
        return {
            (str(uid), int(idx)): pos
            for pos, (uid, idx) in enumerate(zip(dataset["user_id"], dataset["idx"]))
        }

    @staticmethod
    def _random_scope(selection_criteria):
        """Map a criterion to ``"out"`` (other users), ``"in"`` (target user) or ``None``."""
        if selection_criteria in ("out_random", "out_similar"):
            return "out"
        if selection_criteria in ("in_random", "in_similar"):
            return "in"
        return None

    def _examples_for(self, sample):
        """Return the in-context examples to pair with ``sample``."""
        if self.selection_criteria in self._SIMILAR_CRITERIA:
            return self.sample_similar_examples(sample)
        return self.selected_examples

    def __len__(self) -> int:
        return len(self.test_dataset)

    def __getitem__(self, idx):
        """Return ``(test_sample, examples)`` for the ``idx``-th test sample."""
        sample = self.test_dataset[idx]
        return sample, self._examples_for(sample)

    def __iter__(self):
        """Yield ``(test_sample, examples)`` for every test sample in order."""
        for sample in self.test_dataset:
            yield sample, self._examples_for(sample)

    def sample_similar_examples(self, sample):
        """Retrieve the in-context examples most similar to ``sample``.

        The sample's embedding (Chronos-2, per the original design) is looked
        up in ``embedding_index``. The index is then asked for up to
        ``num_examples`` similar session-"1" examples per class. Retrieval is
        restricted to other users (``out_similar``) or to the target user
        (``in_similar``). Each hit is resolved to a row of ``example_dataset``.

        Falls back to :meth:`sample_examples`, i.e. random examples from the
        same user pool, when:

        * there is no index,
        * the sample has no embedding, or
        * no hit can be resolved.

        Args:
            sample: A test row. Its ``user_id``, ``session_id`` and ``idx``
                form the embedding key.

        Returns:
            A HuggingFace Dataset of the resolved examples, or the fallback's
            result (which may be ``None`` if a class lacks candidates).
        """
        if self.embedding_index is None:
            return self.sample_examples()

        query_embedding = self.embedding_index.get_embedding_by_key(
            user_id=str(sample["user_id"]),
            session_id=str(sample["session_id"]),
            idx=int(sample["idx"]),
        )
        if query_embedding is None:
            print(f"[WARNING] No embedding found for sample {sample['user_id']}/{sample['session_id']}/{sample['idx']}")
            return self.sample_examples()

        target = str(self.user_id)
        if self.selection_criteria == "out_similar":
            exclude_user, include_user = target, None
        else:
            exclude_user, include_user = None, target

        similar_per_class = self.embedding_index.find_similar_per_class(
            query_embedding=query_embedding,
            classes=self.classes,
            k_per_class=self.num_examples,
            exclude_user=exclude_user,
            include_user=include_user,
            filter_session="1",
        )

        example_list = []
        for similar_samples in similar_per_class.values():
            for _global_idx, _similarity, metadata in similar_samples:
                example = self._find_example_by_metadata(metadata)
                if example is not None:
                    example_list.append(example)

        if not example_list:
            print("[WARNING] No similar examples found, falling back to random")
            return self.sample_examples()
        return datasets.Dataset.from_list(example_list)

    def _find_example_by_metadata(self, metadata):
        """Resolve index metadata to an example dict via the O(1) lookup.

        Returns ``None`` when ``(user_id, idx)`` is not in the example pool.
        A missing ``"data"`` field in the row becomes ``{}``.
        """
        key = (str(metadata["user_id"]), int(metadata["idx"]))
        dataset_idx = self._example_lookup.get(key)
        if dataset_idx is None:
            return None
        example = self.example_dataset[dataset_idx]
        return {
            "user_id": example["user_id"],
            "session_id": example["session_id"],
            "idx": example["idx"],
            "label": example["label"],
            "features": example["features"],
            "data": example.get("data", {}),
        }

    def sample_examples(self):
        """Randomly draw ``num_examples`` examples per class from the example pool.

        Which users' examples are eligible depends on ``selection_criteria``:

        * ``out_*`` keeps other users only.
        * ``in_*`` keeps the target user only.
        * Any other criterion returns an empty dataset.

        The ``*_similar`` criteria share these pools, so their "falling back
        to random" paths yield random examples rather than nothing. User ids
        are compared as strings, consistent with the similarity path.

        Sampling uses NumPy's global RNG; seed ``np.random`` for repeatable draws.

        Returns:
            A dataset of the sampled rows, grouped in ``self.classes`` order,
            or ``None`` when some class has fewer than ``num_examples``
            candidates.
        """
        scope = self._random_scope(self.selection_criteria)
        if scope is None:
            return datasets.Dataset.from_list([])

        pool = self.example_dataset
        if len(pool) == 0:
            # An empty pool may have no columns at all.
            user_ids, labels = [], []
        else:
            user_ids, labels = pool["user_id"], pool["label"]

        target = str(self.user_id)
        want_target = scope == "in"
        # Row positions of eligible examples, per label, in pool order.
        by_class = {}
        for pos, (uid, label) in enumerate(zip(user_ids, labels)):
            if (str(uid) == target) == want_target:
                by_class.setdefault(label, []).append(pos)

        chosen = []
        for c in self.classes:
            candidates = by_class.get(c, [])
            if len(candidates) < self.num_examples:
                return None
            picks = np.random.choice(len(candidates), self.num_examples, replace=False)
            chosen.extend(candidates[p] for p in picks)
        return pool.select(chosen)

    def get_metadata(self) -> dict:
        return self.metadata

    def get_sensor_info(self):
        return self.metadata["feature"]

    def get_task_info(self) -> str:
        """Format the task description and class list as markdown-style text."""
        classes_info = "\n".join(f" - {k}: {v}" for k, v in self.metadata["class"].items())
        return f"**Task**:\n{self.metadata['task']}\n\n**Classes**:\n{classes_info}"

    def get_classes_info(self) -> list:
        """Return class names in metadata order (unsorted, unlike ``self.classes``)."""
        return list(self.metadata["class"])