201 lines
7.9 KiB
Python
201 lines
7.9 KiB
Python
import os
|
|
import json
|
|
import datasets
|
|
import numpy as np
|
|
from glob import glob
|
|
from typing import Optional, TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from .embedding_index import EmbeddingIndex
|
|
|
|
|
|
class DataLoader:
    """Pair one user's test samples with in-context examples from a shared pool.

    Expected layout under ``data_path``::

        info.json     task metadata; the "task", "class" and "feature" keys are read
        <user>/1      saved HuggingFace dataset contributing to the example pool
        <user>/2      saved HuggingFace dataset holding that user's test samples

    ``selection_criteria`` picks the example strategy:

    * ``"out_random"`` / ``"in_random"``: ``num_examples`` random examples per
      class, drawn once at construction, from users other than / equal to
      ``user_id``.
    * ``"out_similar"`` / ``"in_similar"``: examples retrieved per test sample
      through ``embedding_index``.

    Callers must check ``is_valid`` after construction. It stays ``False`` when
    a required path is missing or when random sampling finds too few examples
    for some class. When a required path is missing, most attributes are left
    unset.
    """

    # Criteria that retrieve examples per test sample via the embedding index.
    _SIMILAR_CRITERIA = ("out_similar", "in_similar")

    def __init__(
        self,
        data_path,
        user_id,
        selection_criteria="out_random",
        num_examples=1,
        embedding_index: Optional["EmbeddingIndex"] = None,
    ):
        """Load metadata and datasets for ``user_id`` and prepare examples.

        Args:
            data_path: Root directory containing ``info.json`` and one
                sub-directory per user.
            user_id: Name of this user's sub-directory. It is also compared,
                as a string, against the ``user_id`` column of example rows.
            selection_criteria: One of ``"out_random"``, ``"in_random"``,
                ``"out_similar"`` or ``"in_similar"``. The ``*_similar``
                options fall back to ``*_random`` when ``embedding_index``
                is ``None``.
            num_examples: Number of examples requested per class.
            embedding_index: Index used by the ``*_similar`` criteria.
        """
        self.is_valid = False
        self.embedding_index = embedding_index
        self.data_path = data_path

        info_path = os.path.join(data_path, "info.json")
        user_dir = os.path.join(data_path, f"{user_id}")
        required = (info_path, os.path.join(user_dir, "1"), os.path.join(user_dir, "2"))
        if not all(os.path.exists(path) for path in required):
            return

        if selection_criteria in self._SIMILAR_CRITERIA and embedding_index is None:
            print(f"[WARNING] {selection_criteria} requires embedding_index, falling back to random")
            selection_criteria = self._random_counterpart(selection_criteria)

        # A context manager closes the handle; the bare open() used to leak it.
        with open(info_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)

        # Fixed seed so test order and example-pool order are reproducible.
        self.test_dataset = datasets.load_from_disk(os.path.join(user_dir, "2")).shuffle(seed=0)
        self.example_dataset = self._load_example_pool(data_path).shuffle(seed=0)

        self.user_id = user_id
        self.selection_criteria = selection_criteria
        self.num_examples = num_examples

        self.classes = sorted(self.metadata["class"])

        # Maps (str(user_id), int(idx)) to the row position in example_dataset,
        # giving O(1) retrieval. Reading only the two needed columns avoids
        # decoding every full row (features included).
        self._example_lookup = {
            (str(uid), int(idx)): row
            for row, (uid, idx) in enumerate(zip(self._column("user_id"), self._column("idx")))
        }

        if selection_criteria in self._SIMILAR_CRITERIA:
            # Examples are retrieved per test sample in sample_similar_examples().
            self.selected_examples = None
        else:
            self.selected_examples = self.sample_examples()
            if self.selected_examples is None:
                return

        self.is_valid = True

    @staticmethod
    def _load_example_pool(data_path):
        """Concatenate the ``<user>/1`` datasets of every user under ``data_path``."""
        # Sorting keeps the concatenation, and therefore the seeded shuffle,
        # independent of filesystem listing order. basename is separator-agnostic.
        # Entries without a "1" sub-directory (such as info.json) are skipped.
        users = sorted(
            os.path.basename(path)
            for path in glob(os.path.join(data_path, "*"))
            if os.path.isdir(os.path.join(path, "1"))
        )
        user_datasets = [datasets.load_from_disk(os.path.join(data_path, user, "1")) for user in users]
        if not user_datasets:
            return datasets.Dataset.from_list([])
        # Concatenate once instead of re-concatenating a growing dataset per user.
        return datasets.concatenate_datasets(user_datasets)

    def _column(self, name):
        """Return column ``name`` of ``example_dataset`` as a list ([] for an empty pool)."""
        # An empty pool built by Dataset.from_list([]) has no columns at all.
        if len(self.example_dataset) == 0:
            return []
        return list(self.example_dataset[name])

    @classmethod
    def _random_counterpart(cls, selection_criteria):
        """Map ``*_similar`` criteria to the matching ``*_random`` one; pass others through."""
        if selection_criteria in cls._SIMILAR_CRITERIA:
            return "out_random" if "out" in selection_criteria else "in_random"
        return selection_criteria

    @staticmethod
    def _class_pools(user_ids, labels, user_id, same_user, classes):
        """Return, for each class in ``classes``, the row indices eligible for sampling.

        A row is eligible when ``str(row_user) == str(user_id)`` equals
        ``same_user`` and its label ``==`` the class. Indices keep ascending
        row order.
        """
        target = str(user_id)
        eligible = [i for i, uid in enumerate(user_ids) if (str(uid) == target) == same_user]
        return [[i for i in eligible if labels[i] == c] for c in classes]

    def __len__(self):
        """Return the number of test samples."""
        return len(self.test_dataset)

    def _examples_for(self, sample):
        """Return the in-context examples paired with ``sample``."""
        if self.selection_criteria in self._SIMILAR_CRITERIA:
            return self.sample_similar_examples(sample)
        return self.selected_examples

    def __getitem__(self, idx):
        """Return ``(test_sample, examples)`` for test position ``idx``."""
        sample = self.test_dataset[idx]
        return sample, self._examples_for(sample)

    def __iter__(self):
        """Yield ``(test_sample, examples)`` for every test sample, in order."""
        for sample in self.test_dataset:
            yield sample, self._examples_for(sample)

    def sample_similar_examples(self, sample):
        """Retrieve examples similar to ``sample`` through the embedding index.

        Queries ``embedding_index`` (Chronos-2 embeddings, per the original
        documentation) for ``num_examples`` neighbours per class. The call
        passes ``filter_session="1"`` together with ``exclude_user=user_id``
        (for ``out_similar``) or ``include_user=user_id`` (for ``in_similar``).

        Args:
            sample: A test row with ``user_id``, ``session_id`` and ``idx``.

        Returns:
            A Dataset of the retrieved examples. Falls back to
            ``sample_examples()`` (random sampling, which may return ``None``)
            when there is no index, no embedding for ``sample``, or no match.
        """
        if self.embedding_index is None:
            return self.sample_examples()

        query_embedding = self.embedding_index.get_embedding_by_key(
            user_id=str(sample["user_id"]),
            session_id=str(sample["session_id"]),
            idx=int(sample["idx"]),
        )
        if query_embedding is None:
            print(f"[WARNING] No embedding found for sample {sample['user_id']}/{sample['session_id']}/{sample['idx']}")
            return self.sample_examples()

        own_user = str(self.user_id)
        out_of_user = self.selection_criteria == "out_similar"
        similar_per_class = self.embedding_index.find_similar_per_class(
            query_embedding=query_embedding,
            classes=self.classes,
            k_per_class=self.num_examples,
            exclude_user=own_user if out_of_user else None,
            include_user=None if out_of_user else own_user,
            filter_session="1",
        )

        example_list = []
        for similar_samples in similar_per_class.values():
            # Each hit is (global_idx, similarity, metadata); only metadata is used.
            for _global_idx, _similarity, metadata in similar_samples:
                example = self._find_example_by_metadata(metadata)
                if example is not None:
                    example_list.append(example)

        if not example_list:
            print("[WARNING] No similar examples found, falling back to random")
            return self.sample_examples()

        return datasets.Dataset.from_list(example_list)

    def _find_example_by_metadata(self, metadata):
        """Return the example row matching ``metadata`` as a dict, or ``None``.

        Matching uses ``(str(user_id), int(idx))`` through the prebuilt
        ``_example_lookup``; ``session_id`` in ``metadata`` is not consulted.
        The returned dict carries ``user_id``, ``session_id``, ``idx``,
        ``label`` and ``features``, plus ``data`` (``{}`` when the row has none).
        """
        dataset_idx = self._example_lookup.get((str(metadata["user_id"]), int(metadata["idx"])))
        if dataset_idx is None:
            return None

        example = self.example_dataset[dataset_idx]
        return {
            "user_id": example["user_id"],
            "session_id": example["session_id"],
            "idx": example["idx"],
            "label": example["label"],
            "features": example["features"],
            "data": example.get("data", {}),
        }

    def sample_examples(self, selection_criteria=None):
        """Draw ``num_examples`` random examples per class without replacement.

        Args:
            selection_criteria: Overrides ``self.selection_criteria`` when given.
                The ``*_similar`` criteria map to their ``*_random``
                counterpart, so the similarity fallbacks really sample at random.

        Returns:
            A Dataset of the sampled rows, grouped by class in ``self.classes``
            order. Returns ``None`` if some class has fewer than
            ``num_examples`` candidates, and an empty Dataset for an
            unrecognised criterion. Draws come from numpy's global RNG.
        """
        criteria = self._random_counterpart(
            self.selection_criteria if selection_criteria is None else selection_criteria
        )
        if criteria == "out_random":
            same_user = False
        elif criteria == "in_random":
            same_user = True
        else:
            return datasets.Dataset.from_list([])

        pools = self._class_pools(
            self._column("user_id"), self._column("label"), self.user_id, same_user, self.classes
        )
        selected = []
        for pool in pools:
            if len(pool) < self.num_examples:
                return None
            # Draw positions within the class subset, as the per-class filter did.
            picks = np.random.choice(len(pool), self.num_examples, replace=False)
            selected.extend(pool[p] for p in picks)
        return self.example_dataset.select(selected)

    def get_metadata(self):
        """Return the parsed ``info.json`` metadata."""
        return self.metadata

    def get_sensor_info(self):
        """Return the ``"feature"`` entry of the metadata."""
        return self.metadata["feature"]

    def get_task_info(self):
        """Return a markdown description of the task and its classes."""
        task_info = f"**Task**:\n{self.metadata['task']}\n\n"
        classes_info = "\n".join(f" - {k}: {v}" for k, v in self.metadata["class"].items())
        task_info += f"**Classes**:\n{classes_info}"
        return task_info

    def get_classes_info(self):
        """Return the class names in metadata (insertion) order."""
        return list(self.metadata["class"])
|