From 2dd1b8f0c5b670da2d44d87d6c4938313941829e Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 20 Oct 2022 18:35:49 +0530 Subject: [PATCH] adding key pair dataset (#19765) --- docs/source/en/main_classes/pipelines.mdx | 2 +- src/transformers/pipelines/pt_utils.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/docs/source/en/main_classes/pipelines.mdx b/docs/source/en/main_classes/pipelines.mdx index ef6adc4810..daed2f42dc 100644 --- a/docs/source/en/main_classes/pipelines.mdx +++ b/docs/source/en/main_classes/pipelines.mdx @@ -91,7 +91,7 @@ pipe = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-96 dataset = datasets.load_dataset("superb", name="asr", split="test") # KeyDataset (only *pt*) will simply return the item in the dict returned by the dataset item -# as we're not interested in the *target* part of the dataset. +# as we're not interested in the *target* part of the dataset. For sentence pairs, use KeyPairDataset. for out in tqdm(pipe(KeyDataset(dataset, "file"))): print(out) # {"text": "NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD NIGHT HUSBAND"} diff --git a/src/transformers/pipelines/pt_utils.py b/src/transformers/pipelines/pt_utils.py index 455fdc34f7..a194c155ea 100644 --- a/src/transformers/pipelines/pt_utils.py +++ b/src/transformers/pipelines/pt_utils.py @@ -293,3 +293,16 @@ class KeyDataset(Dataset): def __getitem__(self, i): return self.dataset[i][self.key] + + +class KeyPairDataset(Dataset): + def __init__(self, dataset: Dataset, key1: str, key2: str): + self.dataset = dataset + self.key1 = key1 + self.key2 = key2 + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, i): + return {"text": self.dataset[i][self.key1], "text_pair": self.dataset[i][self.key2]}