[seq2seq] Don't copy self.source in sortishsampler (#5818)
This commit is contained in:
@@ -144,16 +144,9 @@ class SummarizationDataset(Dataset):
|
||||
batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
|
||||
return batch
|
||||
|
||||
@property
|
||||
def src_lens(self): # Can delete?
|
||||
return lmap(len, self.source)
|
||||
|
||||
@property
|
||||
def tgt_lens(self):
|
||||
return lmap(len, self.target)
|
||||
|
||||
def make_sortish_sampler(self, batch_size):
|
||||
return SortishSampler(self.source, batch_size)
|
||||
lens = [x["input_ids"].ne(self.pad_token_id).sum() for x in self.source]
|
||||
return SortishSampler(lens, batch_size)
|
||||
|
||||
|
||||
class SortishSampler(Sampler):
|
||||
@@ -163,7 +156,7 @@ class SortishSampler(Sampler):
|
||||
self.data, self.bs = data, batch_size
|
||||
|
||||
def key(self, i):
|
||||
return len(self.data[i])
|
||||
return self.data[i]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
Reference in New Issue
Block a user