[seq2seq] Don't copy self.source in sortishsampler (#5818)
This commit is contained in:
@@ -144,16 +144,9 @@ class SummarizationDataset(Dataset):
|
|||||||
batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
|
batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
@property
|
|
||||||
def src_lens(self): # Can delete?
|
|
||||||
return lmap(len, self.source)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tgt_lens(self):
|
|
||||||
return lmap(len, self.target)
|
|
||||||
|
|
||||||
def make_sortish_sampler(self, batch_size):
|
def make_sortish_sampler(self, batch_size):
|
||||||
return SortishSampler(self.source, batch_size)
|
lens = [x["input_ids"].ne(self.pad_token_id).sum() for x in self.source]
|
||||||
|
return SortishSampler(lens, batch_size)
|
||||||
|
|
||||||
|
|
||||||
class SortishSampler(Sampler):
|
class SortishSampler(Sampler):
|
||||||
@@ -163,7 +156,7 @@ class SortishSampler(Sampler):
|
|||||||
self.data, self.bs = data, batch_size
|
self.data, self.bs = data, batch_size
|
||||||
|
|
||||||
def key(self, i):
|
def key(self, i):
|
||||||
return len(self.data[i])
|
return self.data[i]
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|||||||
Reference in New Issue
Block a user