adding key pair dataset (#19765)

This commit is contained in:
Rohit Gupta
2022-10-20 18:35:49 +05:30
committed by GitHub
parent 17d7aec895
commit 2dd1b8f0c5
2 changed files with 14 additions and 1 deletions

View File

@@ -91,7 +91,7 @@ pipe = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-96
dataset = datasets.load_dataset("superb", name="asr", split="test") dataset = datasets.load_dataset("superb", name="asr", split="test")
# KeyDataset (only *pt*) will simply return the item in the dict returned by the dataset item # KeyDataset (only *pt*) will simply return the item in the dict returned by the dataset item
# as we're not interested in the *target* part of the dataset. # as we're not interested in the *target* part of the dataset. For sentence pair use KeyPairDataset
for out in tqdm(pipe(KeyDataset(dataset, "file"))): for out in tqdm(pipe(KeyDataset(dataset, "file"))):
print(out) print(out)
# {"text": "NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD NIGHT HUSBAND"} # {"text": "NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD NIGHT HUSBAND"}

View File

@@ -293,3 +293,16 @@ class KeyDataset(Dataset):
def __getitem__(self, i): def __getitem__(self, i):
return self.dataset[i][self.key] return self.dataset[i][self.key]
class KeyPairDataset(Dataset):
def __init__(self, dataset: Dataset, key1: str, key2: str):
self.dataset = dataset
self.key1 = key1
self.key2 = key2
def __len__(self):
return len(self.dataset)
def __getitem__(self, i):
return {"text": self.dataset[i][self.key1], "text_pair": self.dataset[i][self.key2]}