[seq2seq] Don't copy self.source in sortishsampler (#5818)

This commit is contained in:
Sam Shleifer
2020-07-17 01:53:25 -04:00
committed by GitHub
parent 2e4624b415
commit e238e3d55a

View File

@@ -144,16 +144,9 @@ class SummarizationDataset(Dataset):
batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
return batch
@property
def src_lens(self): # Can delete?
return lmap(len, self.source)
@property
def tgt_lens(self):
return lmap(len, self.target)
def make_sortish_sampler(self, batch_size):
return SortishSampler(self.source, batch_size)
lens = [x["input_ids"].ne(self.pad_token_id).sum() for x in self.source]
return SortishSampler(lens, batch_size)
class SortishSampler(Sampler):
@@ -163,7 +156,7 @@ class SortishSampler(Sampler):
self.data, self.bs = data, batch_size
def key(self, i):
return len(self.data[i])
return self.data[i]
def __len__(self) -> int:
return len(self.data)